Commit ff3be8e3 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.4 support: toIntVector -> to IntList

parent 7ef77d92
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
......@@ -58,7 +59,7 @@ public:
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
......@@ -100,7 +101,7 @@ public:
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count);
......@@ -134,7 +135,7 @@ public:
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
......@@ -169,7 +170,7 @@ public:
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
......
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/segment_coo_cuda.h"
......@@ -57,7 +58,7 @@ public:
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
return {grad_in, Variable(), Variable(), Variable()};
......@@ -85,7 +86,7 @@ public:
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
count = gather_coo_fw(count, index, torch::nullopt);
......@@ -118,7 +119,7 @@ public:
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
......@@ -150,7 +151,7 @@ public:
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
......@@ -177,7 +178,7 @@ public:
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::zeros(src_shape, grad_out.options());
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
......
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "cuda/segment_csr_cuda.h"
......@@ -55,7 +56,7 @@ public:
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
......@@ -79,7 +80,7 @@ public:
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
......@@ -114,7 +115,7 @@ public:
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
......@@ -145,7 +146,7 @@ public:
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
......@@ -172,7 +173,7 @@ public:
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum");
......
#pragma once
#include <torch/script.h>
#include <vector>
inline std::vector<int64_t> list2vec(const c10::List<int64_t> list) {
std::vector<int64_t> result;
result.reserve(list.size());
for (size_t i = 0; i < list.size(); i++)
result.push_back(list[i]);
return result;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment