Commit 62815576 authored by rusty1s

moved extensions to torch.ops

parent 0a221ab8
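In short: each C++/CUDA extension now registers its kernels as custom TorchScript operators via torch::RegisterOperators instead of exposing them through a pybind11 module, and the Python side reaches them through the torch.ops namespace after loading the shared libraries. A minimal sketch of the new calling convention, assuming the CPU extension has been built as torch_scatter/scatter_cpu.so as in the __init__.py hunk below (tensor values and the in-place behaviour shown in the comments are illustrative only):

import torch

# Loading the library makes its registered operators visible under torch.ops.
torch.ops.load_library('torch_scatter/scatter_cpu.so')

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 0, 1, 1])
out = torch.ones(2)

# Registered in C++ as "torch_scatter_cpu::scatter_mul"; it multiplies the
# src entries into out at the positions given by index, in-place along dim 0.
torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, 0)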
@@ -19,7 +19,7 @@
   auto TENSOR3##_stride = TENSOR3.stride(DIM); \
   \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
   auto counter = zeros.DATA_PTR<int64_t>(); \
   bool has_finished = false; \
   \
@@ -78,7 +78,7 @@
   auto TENSOR4##_stride = TENSOR4.stride(DIM); \
   \
   auto dims = TENSOR1.dim(); \
-  auto zeros = at::zeros(dims, TENSOR1.options().dtype(at::kLong)); \
+  auto zeros = torch::zeros(dims, TENSOR1.options().dtype(torch::kLong)); \
   auto counter = zeros.DATA_PTR<int64_t>(); \
   bool has_finished = false; \
   \
...
-#include <torch/extension.h>
+#include <torch/script.h>
 #include "compat.h"
 #include "index_info.h"
 #include <vector>
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -23,7 +23,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -32,7 +32,7 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
   auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
@@ -68,8 +68,8 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return out;
 }
-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   if (out_opt.has_value())
@@ -82,7 +82,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -92,7 +92,7 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
   auto E_1 = index.numel() / out.size(gather_dim);
@@ -139,7 +139,6 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   return out;
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CPU)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::gather_csr", &gather_csr)
+        .op("torch_scatter_cpu::gather_coo", &gather_coo);
@@ -26,7 +26,7 @@ template <typename scalar_t> struct TensorInfo {
 };
 template <typename scalar_t>
-TensorInfo<scalar_t> getTensorInfo(const at::Tensor &tensor) {
+TensorInfo<scalar_t> getTensorInfo(const torch::Tensor &tensor) {
   int sizes[MAX_TENSORINFO_DIMS];
   int strides[MAX_TENSORINFO_DIMS];
...
-#include <torch/extension.h>
+#include <torch/script.h>
 #include "dim_apply.h"
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
-void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -20,7 +20,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
-void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -36,8 +36,8 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
-void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -56,8 +56,8 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
-void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CPU(src);
   CHECK_CPU(index);
   CHECK_CPU(out);
@@ -77,9 +77,8 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
   });
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CPU)");
-  m.def("scatter_div", &scatter_div, "Scatter Div (CPU)");
-  m.def("scatter_max", &scatter_max, "Scatter Max (CPU)");
-  m.def("scatter_min", &scatter_min, "Scatter Min (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::scatter_mul", &scatter_mul)
+        .op("torch_scatter_cpu::scatter_div", &scatter_div)
+        .op("torch_scatter_cpu::scatter_max", &scatter_max)
+        .op("torch_scatter_cpu::scatter_min", &scatter_min);
-#include <torch/extension.h>
+#include <torch/script.h>
 #include "compat.h"
 #include "index_info.h"
 #include <vector>
-#define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")
+#define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor")
 enum ReductionType { SUM, MEAN, MIN, MAX };
@@ -74,9 +74,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
   }
 };
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
-            std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr(torch::Tensor src, torch::Tensor indptr,
+            torch::optional<torch::Tensor> out_opt, std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(indptr);
   if (out_opt.has_value())
@@ -94,7 +94,7 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   src = src.contiguous();
   auto reduce_dim = indptr.dim() - 1;
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -105,13 +105,13 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   } else {
     sizes = src.sizes().vec();
     sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -156,8 +156,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   return std::make_tuple(out, arg_out);
 }
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
             std::string reduce) {
   CHECK_CPU(src);
   CHECK_CPU(index);
@@ -180,10 +180,10 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -251,7 +251,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   return std::make_tuple(out, arg_out);
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_csr", &segment_csr, "Segment CSR (CPU)");
-  m.def("segment_coo", &segment_coo, "Segment COO (CPU)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cpu::segment_csr", &segment_csr)
        .op("torch_scatter_cpu::segment_coo", &segment_coo);
-#include <torch/extension.h>
+#include <torch/script.h>
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x) \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt);
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt);
-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt);
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt);
-at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
@@ -16,8 +17,8 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   return gather_csr_cuda(src, indptr, out_opt);
 }
-at::Tensor gather_coo(at::Tensor src, at::Tensor index,
-                      at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
+                         torch::optional<torch::Tensor> out_opt) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   if (out_opt.has_value())
@@ -25,7 +26,6 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   return gather_coo_cuda(src, index, out_opt);
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("gather_csr", &gather_csr, "Gather CSR (CUDA)");
-  m.def("gather_coo", &gather_coo, "Gather COO (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::gather_csr", &gather_csr)
+        .op("torch_scatter_cuda::gather_coo", &gather_coo);
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 #include "compat.cuh"
 #include "indptr.cuh"
@@ -58,9 +58,10 @@ __global__ void gather_csr_broadcast_kernel(
   }
 }
-at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                              torch::optional<torch::Tensor> out_opt) {
+  cudaSetDevice(src.get_device());
   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
   for (int i = 0; i < indptr.dim() - 1; i++)
     AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
@@ -70,7 +71,7 @@ at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
   AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
              "Input mismatch");
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -152,8 +153,10 @@ __global__ void gather_coo_broadcast_kernel(
   }
 }
-at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
-                           at::optional<at::Tensor> out_opt) {
+torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
+                              torch::optional<torch::Tensor> out_opt) {
+  cudaSetDevice(src.get_device());
   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
   for (int i = 0; i < index.dim() - 1; i++)
@@ -162,7 +165,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   src = src.contiguous();
   auto gather_dim = index.dim() - 1;
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < index.dim(); i++)
@@ -172,7 +175,7 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
   } else {
     auto sizes = src.sizes().vec();
     sizes[gather_dim] = index.size(gather_dim);
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
   auto E = index.numel();
...
 #pragma once
-#include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 template <typename scalar1, typename scalar2, int64_t Dims>
 struct IndexToScatterOffsets3 {
...
 #pragma once
-#include <ATen/ATen.h>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 // We need our own `IndexToOffset` implementation since we do not want to
 // access the last element of the `indexptr`.
...
-#include <torch/extension.h>
+#include <torch/script.h>
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x) \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
-void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim);
-void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim);
-void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim);
+void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim);
-void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim);
+void scatter_min_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim);
-void index_backward_cuda(at::Tensor grad, at::Tensor index, at::Tensor arg,
-                         at::Tensor out, int64_t dim);
+void index_backward_cuda(torch::Tensor grad, torch::Tensor index,
+                         torch::Tensor arg, torch::Tensor out, int64_t dim);
-void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -21,7 +22,7 @@ void scatter_mul(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_mul_cuda(src, index, out, dim);
 }
-void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -29,8 +30,8 @@ void scatter_div(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_div_cuda(src, index, out, dim);
 }
-void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_max(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
@@ -38,8 +39,8 @@ void scatter_max(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_max_cuda(src, index, out, arg, dim);
 }
-void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
-                 at::Tensor arg, int64_t dim) {
+void scatter_min(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                 torch::Tensor arg, int64_t dim) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
   CHECK_CUDA(out);
@@ -47,9 +48,8 @@ void scatter_min(at::Tensor src, at::Tensor index, at::Tensor out,
   scatter_min_cuda(src, index, out, arg, dim);
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("scatter_mul", &scatter_mul, "Scatter Mul (CUDA)");
-  m.def("scatter_div", &scatter_div, "Scatter Div (CUDA)");
-  m.def("scatter_max", &scatter_max, "Scatter Max (CUDA)");
-  m.def("scatter_min", &scatter_min, "Scatter Min (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::scatter_mul", &scatter_mul)
+        .op("torch_scatter_cuda::scatter_div", &scatter_div)
+        .op("torch_scatter_cuda::scatter_max", &scatter_max)
+        .op("torch_scatter_cuda::scatter_min", &scatter_min);
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 #include "atomics.cuh"
 #include "index.cuh"
@@ -9,8 +9,6 @@
 #define THREADS 1024
 #define BLOCKS(N) (N + THREADS - 1) / THREADS
-auto stream = at::cuda::getCurrentCUDAStream();
 #define KERNEL_RUN(NAME, DIMS, N, ...) \
   [&] { \
     auto stream = at::cuda::getCurrentCUDAStream(); \
@@ -45,8 +43,9 @@ scatter_mul_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }
-void scatter_mul_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_mul_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim) {
+  cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_mul_kernel", [&] {
     KERNEL_RUN(scatter_mul_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -71,8 +70,9 @@ scatter_div_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }
-void scatter_div_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+void scatter_div_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                       int64_t dim) {
+  cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_div_kernel", [&] {
     KERNEL_RUN(scatter_div_kernel, index.dim(), index.numel(),
                at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src),
@@ -116,8 +116,9 @@ scatter_max_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
   }
 }
-void scatter_max_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
-                      at::Tensor arg, int64_t dim) {
+void scatter_max_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
+                      torch::Tensor arg, int64_t dim) {
+  cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_max_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
@@ -148,6 +149,7 @@ scatter_min_kernel(at::cuda::detail::TensorInfo<scalar_t, int64_t> src,
 void scatter_min_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
                       at::Tensor arg, int64_t dim) {
+  cudaSetDevice(src.get_device());
   AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter_min_kernel", [&] {
     auto src_info = at::cuda::detail::getTensorInfo<scalar_t, int64_t>(src);
     auto index_info = at::cuda::detail::getTensorInfo<int64_t, int64_t>(index);
...
-#include <torch/extension.h>
+#include <torch/script.h>
-#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be CUDA tensor")
+#define CHECK_CUDA(x) \
+  AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor")
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr_cuda(at::Tensor src, at::Tensor indptr,
-                 at::optional<at::Tensor> out_opt, std::string reduce);
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                 torch::optional<torch::Tensor> out_opt, std::string reduce);
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  std::string reduce);
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
-            std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr(torch::Tensor src, torch::Tensor indptr,
+            torch::optional<torch::Tensor> out_opt, std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(indptr);
   if (out_opt.has_value())
@@ -19,8 +20,8 @@ segment_csr(at::Tensor src, at::Tensor indptr, at::optional<at::Tensor> out_opt,
   return segment_csr_cuda(src, indptr, out_opt, reduce);
 }
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo(torch::Tensor src, torch::Tensor index, torch::Tensor out,
             std::string reduce) {
   CHECK_CUDA(src);
   CHECK_CUDA(index);
@@ -28,7 +29,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
   return segment_coo_cuda(src, index, out, reduce);
 }
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("segment_csr", &segment_csr, "Segment CSR (CUDA)");
-  m.def("segment_coo", &segment_coo, "Segment COO (CUDA)");
-}
+static auto registry =
+    torch::RegisterOperators("torch_scatter_cuda::segment_csr", &segment_csr)
+        .op("torch_scatter_cuda::segment_coo", &segment_coo);
-#include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/IndexUtils.cuh>
 #include <ATen/cuda/detail/TensorInfo.cuh>
+#include <torch/extension.h>
 #include "atomics.cuh"
 #include "compat.cuh"
@@ -181,9 +181,11 @@ __global__ void segment_csr_broadcast_kernel(
   }
 }
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_csr_cuda(at::Tensor src, at::Tensor indptr,
-                 at::optional<at::Tensor> out_opt, std::string reduce) {
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
+                 torch::optional<torch::Tensor> out_opt, std::string reduce) {
+  cudaSetDevice(src.get_device());
   AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
@@ -197,7 +199,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   src = src.contiguous();
   auto reduce_dim = indptr.dim() - 1;
-  at::Tensor out;
+  torch::Tensor out;
   if (out_opt.has_value()) {
     out = out_opt.value().contiguous();
     for (int i = 0; i < out.dim(); i++)
@@ -208,13 +210,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
   } else {
     sizes = src.sizes().vec();
     sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
-    out = at::empty(sizes, src.options());
+    out = torch::empty(sizes, src.options());
   }
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), indptr.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), indptr.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -382,10 +384,12 @@ __global__ void segment_coo_arg_broadcast_kernel(
   }
 }
-std::tuple<at::Tensor, at::optional<at::Tensor>>
-segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
+std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
+segment_coo_cuda(torch::Tensor src, torch::Tensor index, torch::Tensor out,
                  std::string reduce) {
+  cudaSetDevice(src.get_device());
   AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
   // Broadcasting `index` via `expand`.
@@ -403,10 +407,10 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     if (i != reduce_dim)
       AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
-  at::optional<at::Tensor> arg_out = at::nullopt;
+  torch::optional<torch::Tensor> arg_out = torch::nullopt;
   int64_t *arg_out_data = nullptr;
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
-    arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out = torch::full_like(out, src.size(reduce_dim), index.options());
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
@@ -467,7 +471,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
   if (reduce2REDUCE.at(reduce) == MEAN) {
     auto sizes = index.sizes().vec();
     sizes[reduce_dim] = out.size(reduce_dim);
-    auto count = at::zeros(sizes, out.options());
+    auto count = torch::zeros(sizes, out.options());
     AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
       auto count_data = count.DATA_PTR<scalar_t>();
...
@@ -5,15 +5,16 @@ from setuptools import setup, find_packages
 from sys import argv
 import torch
+from torch.utils.cpp_extension import BuildExtension
 from torch.utils.cpp_extension import CppExtension, CUDAExtension, CUDA_HOME
-# Windows users: Edit both of these to contain your VS include path, i.e.
+# Windows users: Edit both of these to contain your VS include path, i.e.:
 # cxx_extra_compile_args = ['-I{VISUAL_STUDIO_DIR}\\include']
 # nvcc_extra_compile_args = [..., '-I{VISUAL_STUDIO_DIR}\\include']
 cxx_extra_compile_args = []
 nvcc_extra_compile_args = ['-arch=sm_35', '--expt-relaxed-constexpr']
-# Windows users: Edit both of these to contain your VS library path, i.e.
+# Windows users: Edit both of these to contain your VS library path, i.e.:
 # cxx_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
 # nvcc_extra_link_args = ['/LIBPATH:{VISUAL_STUDIO_DIR}\\lib\\{x86|x64}']
 cxx_extra_link_args = []
@@ -26,7 +27,9 @@ TORCH_MINOR = int(torch.__version__.split('.')[1])
 if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2):
     cxx_extra_compile_args += ['-DVERSION_GE_1_3']
     nvcc_extra_compile_args += ['-DVERSION_GE_1_3']
-cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
+cmdclass = {
+    'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
+}
 ext_modules = []
 exts = [e.split(osp.sep)[-1][:-4] for e in glob(osp.join('cpu', '*.cpp'))]
...
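Building with BuildExtension.with_options(no_python_abi_suffix=True) is what makes the torch.ops loading below work: the compiled libraries keep fixed names such as scatter_cpu.so instead of carrying a Python-ABI suffix, so they can be loaded by a constant path at import time. A hedged sketch of how one of these extensions could be declared (the concrete ext_modules entries are outside this excerpt; the module name and source path here are assumptions based on the cpu/*.cpp glob and the paths used in __init__.py):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='torch_scatter',
    ext_modules=[
        # Illustrative: the module path must match what torch.ops.load_library expects.
        CppExtension('torch_scatter.scatter_cpu', ['cpu/scatter.cpp']),
    ],
    cmdclass={
        # Without the ABI suffix the output is 'scatter_cpu.so' rather than
        # e.g. 'scatter_cpu.cpython-37m-x86_64-linux-gnu.so'.
        'build_ext': BuildExtension.with_options(no_python_abi_suffix=True)
    },
)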
+import torch
+torch.ops.load_library('torch_scatter/scatter_cpu.so')
+torch.ops.load_library('torch_scatter/segment_cpu.so')
+torch.ops.load_library('torch_scatter/gather_cpu.so')
+try:
+    torch.ops.load_library('torch_scatter/scatter_cuda.so')
+    torch.ops.load_library('torch_scatter/segment_cuda.so')
+    torch.ops.load_library('torch_scatter/gather_cuda.so')
+except OSError as e:
+    if torch.cuda.is_available():
+        raise e
 from .add import scatter_add
 from .sub import scatter_sub
 from .mul import scatter_mul
...
-from torch.autograd import Function
-from torch_scatter.utils.ext import get_func
+import torch
 from torch_scatter.utils.gen import gen
-class ScatterDiv(Function):
+class ScatterDiv(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
-        func = get_func('scatter_div', src)
-        func(src, index, out, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_div(src, index, out, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_div(src, index, out, dim)
         ctx.mark_dirty(out)
         ctx.save_for_backward(out, src, index)
...
 import torch
-from torch_scatter import segment_cpu, gather_cpu
-if torch.cuda.is_available():
-    from torch_scatter import gather_cuda, segment_cuda
-def gat(is_cuda):
-    return gather_cuda if is_cuda else gather_cpu
-def seg(is_cuda):
-    return segment_cuda if is_cuda else segment_cpu
 class GatherCOO(torch.autograd.Function):
     @staticmethod
@@ -22,7 +9,10 @@ class GatherCOO(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(index)
-        return gat(src.is_cuda).gather_coo(src, index, out)
+        if src.is_cuda:
+            return torch.ops.torch_scatter_cuda.gather_coo(src, index, out)
+        else:
+            return torch.ops.torch_scatter_cpu.gather_coo(src, index, out)
     @staticmethod
     def backward(ctx, grad_out):
@@ -30,8 +20,12 @@ class GatherCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = seg(grad_out.is_cuda).segment_coo(
-                grad_out, index, grad_out.new_zeros(src_size), 'sum')
+            if grad_out.is_cuda:
+                grad_src, _ = torch.ops.torch_scatter_cuda.segment_coo(
+                    grad_out, index, grad_out.new_zeros(src_size), 'sum')
+            else:
+                grad_src, _ = torch.ops.torch_scatter_cpu.segment_coo(
+                    grad_out, index, grad_out.new_zeros(src_size), 'sum')
         return grad_src, None, None
@@ -44,7 +38,10 @@ class GatherCSR(torch.autograd.Function):
         ctx.src_size = list(src.size())
         ctx.save_for_backward(indptr)
-        return gat(src.is_cuda).gather_csr(src, indptr, out)
+        if src.is_cuda:
+            return torch.ops.torch_scatter_cuda.gather_csr(src, indptr, out)
+        else:
+            return torch.ops.torch_scatter_cpu.gather_csr(src, indptr, out)
     @staticmethod
     def backward(ctx, grad_out):
@@ -52,8 +49,12 @@ class GatherCSR(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src, _ = seg(grad_out.is_cuda).segment_csr(
-                grad_out, indptr, grad_out.new_empty(src_size), 'sum')
+            if grad_out.is_cuda:
+                grad_src, _ = torch.ops.torch_scatter_cuda.segment_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')
+            else:
+                grad_src, _ = torch.ops.torch_scatter_cpu.segment_csr(
+                    grad_out, indptr, grad_out.new_empty(src_size), 'sum')
         return grad_src, None, None
...
 import torch
-from torch.autograd import Function
-from torch_scatter.utils.ext import get_func
 from torch_scatter.utils.gen import gen
-class ScatterMax(Function):
+class ScatterMax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
         arg = index.new_full(out.size(), -1)
-        func = get_func('scatter_max', src)
-        func(src, index, out, arg, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_max(src, index, out, arg, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_max(src, index, out, arg, dim)
         ctx.mark_dirty(out)
         ctx.dim = dim
...
 import torch
-from torch.autograd import Function
-from torch_scatter.utils.ext import get_func
 from torch_scatter.utils.gen import gen
-class ScatterMin(Function):
+class ScatterMin(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
         arg = index.new_full(out.size(), -1)
-        func = get_func('scatter_min', src)
-        func(src, index, out, arg, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_min(src, index, out, arg, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_min(src, index, out, arg, dim)
         ctx.mark_dirty(out)
         ctx.dim = dim
...
-from torch.autograd import Function
-from torch_scatter.utils.ext import get_func
+import torch
 from torch_scatter.utils.gen import gen
-class ScatterMul(Function):
+class ScatterMul(torch.autograd.Function):
     @staticmethod
     def forward(ctx, out, src, index, dim):
-        func = get_func('scatter_mul', src)
-        func(src, index, out, dim)
+        if src.is_cuda:
+            torch.ops.torch_scatter_cuda.scatter_mul(src, index, out, dim)
+        else:
+            torch.ops.torch_scatter_cpu.scatter_mul(src, index, out, dim)
         ctx.mark_dirty(out)
         ctx.save_for_backward(out, src, index)
...