Commit ff3be8e3 authored by rusty1s's avatar rusty1s
Browse files

pytorch 1.4 support: toIntVector -> to IntList

parent 7ef77d92
#include <torch/script.h> #include <torch/script.h>
#include "cpu/scatter_cpu.h" #include "cpu/scatter_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/scatter_cuda.h" #include "cuda/scatter_cuda.h"
...@@ -58,7 +59,7 @@ public: ...@@ -58,7 +59,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt(); auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false); auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()}; return {grad_in, Variable(), Variable(), Variable(), Variable()};
} }
...@@ -100,7 +101,7 @@ public: ...@@ -100,7 +101,7 @@ public:
auto index = saved[0]; auto index = saved[0];
auto count = saved[1]; auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt(); auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false); count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false); auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count); grad_in.div_(count);
...@@ -134,7 +135,7 @@ public: ...@@ -134,7 +135,7 @@ public:
auto index = saved[0]; auto index = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt(); auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1; src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out); grad_in.scatter_(dim, arg_out, grad_out);
...@@ -169,7 +170,7 @@ public: ...@@ -169,7 +170,7 @@ public:
auto index = saved[0]; auto index = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt(); auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1; src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out); grad_in.scatter_(dim, arg_out, grad_out);
......
#include <torch/script.h> #include <torch/script.h>
#include "cpu/segment_coo_cpu.h" #include "cpu/segment_coo_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/segment_coo_cuda.h" #include "cuda/segment_coo_cuda.h"
...@@ -57,7 +58,7 @@ public: ...@@ -57,7 +58,7 @@ public:
auto grad_out = grad_outs[0]; auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in); gather_coo_fw(grad_out, index, grad_in);
return {grad_in, Variable(), Variable(), Variable()}; return {grad_in, Variable(), Variable(), Variable()};
...@@ -85,7 +86,7 @@ public: ...@@ -85,7 +86,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto count = saved[1]; auto count = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in); gather_coo_fw(grad_out, index, grad_in);
count = gather_coo_fw(count, index, torch::nullopt); count = gather_coo_fw(count, index, torch::nullopt);
...@@ -118,7 +119,7 @@ public: ...@@ -118,7 +119,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1; src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out); grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
...@@ -150,7 +151,7 @@ public: ...@@ -150,7 +151,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1; src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out); grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
...@@ -177,7 +178,7 @@ public: ...@@ -177,7 +178,7 @@ public:
auto grad_out = grad_outs[0]; auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto index = saved[0]; auto index = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum"); segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
......
#include <torch/script.h> #include <torch/script.h>
#include "cpu/segment_csr_cpu.h" #include "cpu/segment_csr_cpu.h"
#include "utils.h"
#ifdef WITH_CUDA #ifdef WITH_CUDA
#include "cuda/segment_csr_cuda.h" #include "cuda/segment_csr_cuda.h"
...@@ -55,7 +56,7 @@ public: ...@@ -55,7 +56,7 @@ public:
auto grad_out = grad_outs[0]; auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto indptr = saved[0]; auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in); gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()}; return {grad_in, Variable(), Variable()};
...@@ -79,7 +80,7 @@ public: ...@@ -79,7 +80,7 @@ public:
auto grad_out = grad_outs[0]; auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto indptr = saved[0]; auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in); gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1); auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
...@@ -114,7 +115,7 @@ public: ...@@ -114,7 +115,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto indptr = saved[0]; auto indptr = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1; src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
...@@ -145,7 +146,7 @@ public: ...@@ -145,7 +146,7 @@ public:
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto indptr = saved[0]; auto indptr = saved[0];
auto arg_out = saved[1]; auto arg_out = saved[1];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1; src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options()); auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
...@@ -172,7 +173,7 @@ public: ...@@ -172,7 +173,7 @@ public:
auto grad_out = grad_outs[0]; auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables(); auto saved = ctx->get_saved_variables();
auto indptr = saved[0]; auto indptr = saved[0];
auto src_shape = ctx->saved_data["src_shape"].toIntVector(); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum"); segment_csr_fw(grad_out, indptr, grad_in, "sum");
......
#pragma once
#include <torch/script.h>
#include <vector>
inline std::vector<int64_t> list2vec(const c10::List<int64_t> list) {
std::vector<int64_t> result;
result.reserve(list.size());
for (size_t i = 0; i < list.size(); i++)
result.push_back(list[i]);
return result;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment