Commit 0799bc08 authored by limm

support v2.1.0

parent 50e05e1e
#pragma once
#ifdef _WIN32
#if defined(torchscatter_EXPORTS)
#define SCATTER_API __declspec(dllexport)
#else
#define SCATTER_API __declspec(dllimport)
#endif
#else
#define SCATTER_API
#endif
#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define SCATTER_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define SCATTER_INLINE_VARIABLE __declspec(selectany)
#else
#define SCATTER_INLINE_VARIABLE __attribute__((weak))
#endif
#endif
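macros.h is new in this release: SCATTER_API handles DLL export/import on Windows and collapses to nothing elsewhere, while SCATTER_INLINE_VARIABLE emulates a C++17 inline variable via __declspec(selectany) or __attribute__((weak)) on older compilers. A minimal sketch of how such macros are typically consumed by a header follows; the header name and function below are hypothetical, for illustration only, not part of this commit:

// example.h -- hypothetical header, for illustration only.
#pragma once
#include <cstdint>
#include "macros.h"

// Exported from the shared library on Windows builds, ordinary symbol elsewhere.
SCATTER_API std::int64_t example_version() noexcept;

// Defined in the header yet safe to include from several translation units:
// either a true C++17 inline variable or a selectany/weak symbol.
SCATTER_INLINE_VARIABLE std::int64_t example_cached_version = example_version();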
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "macros.h"
#include "utils.h"
-#ifdef WITH_HIP
-#include "hip/scatter_hip.h"
+#ifdef WITH_CUDA
+#include "cuda/scatter_cuda.h"
#endif
#ifdef _WIN32
-#ifdef WITH_HIP
+#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif
#endif
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
@@ -31,7 +37,7 @@ scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
-#ifdef WITH_HIP
+#ifdef WITH_CUDA
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
@@ -226,25 +232,28 @@ public:
}
};
-torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size) {
+SCATTER_API torch::Tensor
+scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
-torch::Tensor scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size) {
+SCATTER_API torch::Tensor
+scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}
-torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size) {
+SCATTER_API torch::Tensor
+scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
@@ -252,7 +261,7 @@ scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
return std::make_tuple(result[0], result[1]);
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
#pragma once
-#include <torch/extension.h>
-int64_t cuda_version();
+#include "extensions.h"
+namespace scatter {
SCATTER_API int64_t cuda_version() noexcept;
-torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size);
+namespace detail {
+SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
+} // namespace detail
} // namespace scatter
-torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size);
+SCATTER_API torch::Tensor
+scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size);
+SCATTER_API torch::Tensor
+segment_sum_coo(torch::Tensor src, torch::Tensor index,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size);
+SCATTER_API torch::Tensor
+segment_mean_coo(torch::Tensor src, torch::Tensor index,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size);
-torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out);
+SCATTER_API torch::Tensor
+gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out);
-torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out);
+SCATTER_API torch::Tensor
+segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
-torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out);
+SCATTER_API torch::Tensor
+segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
-torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out);
+SCATTER_API torch::Tensor
+gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out);
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/scatter_cpu.h"
#include "macros.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "hip/scatter_cuda.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__scatter_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__scatter_cpu(void) { return NULL; }
#endif
#endif
#endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0);
for (auto i = src.dim(); i < other.dim(); i++)
src = src.unsqueeze(-1);
src = src.expand(other.sizes().vec());
return src;
}
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMul : public torch::autograd::Function<ScatterMul> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "mul");
auto out = std::get<0>(result);
ctx->save_for_backward({src, index, out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto src = saved[0];
auto index = saved[1];
auto out = saved[2];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
grad_in.masked_fill_(grad_in.isnan(), 0);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
auto old_index = index;
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
auto ones = torch::ones(old_index.sizes(), src.options());
result = scatter_fw(ones, old_index,
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
count.masked_fill_(count < 1, 1);
count = broadcast(count, out, dim);
if (out.is_floating_point())
out.true_divide_(count);
else
out.div_(count, "floor");
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
SCATTER_API torch::Tensor
scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
SCATTER_API torch::Tensor
scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMul::apply(src, index, dim, optional_out, dim_size)[0];
}
SCATTER_API torch::Tensor
scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
static auto registry = torch::RegisterOperators()
.op("torch_scatter::scatter_sum", &scatter_sum)
.op("torch_scatter::scatter_mul", &scatter_mul)
.op("torch_scatter::scatter_mean", &scatter_mean)
.op("torch_scatter::scatter_min", &scatter_min)
.op("torch_scatter::scatter_max", &scatter_max);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#include "macros.h"
#include "utils.h"
-#ifdef WITH_HIP
-#include "hip/segment_coo_hip.h"
+#ifdef WITH_CUDA
+#include "cuda/segment_coo_cuda.h"
#endif
#ifdef _WIN32
-#ifdef WITH_HIP
+#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
-#ifdef WITH_HIP
+#ifdef WITH_CUDA
return segment_coo_cuda(src, index, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
@@ -34,7 +40,7 @@ segment_coo_fw(torch::Tensor src, torch::Tensor index,
torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
-#ifdef WITH_HIP
+#ifdef WITH_CUDA
return gather_coo_cuda(src, index, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
@@ -195,19 +201,21 @@ public:
}
};
-torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size) {
+SCATTER_API torch::Tensor
+segment_sum_coo(torch::Tensor src, torch::Tensor index,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}
-torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out,
-torch::optional<int64_t> dim_size) {
+SCATTER_API torch::Tensor
+segment_mean_coo(torch::Tensor src, torch::Tensor index,
+torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
@@ -215,7 +223,7 @@ segment_min_coo(torch::Tensor src, torch::Tensor index,
return std::make_tuple(result[0], result[1]);
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
@@ -223,8 +231,9 @@ segment_max_coo(torch::Tensor src, torch::Tensor index,
return std::make_tuple(result[0], result[1]);
}
-torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
-torch::optional<torch::Tensor> optional_out) {
+SCATTER_API torch::Tensor
+gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return GatherCOO::apply(src, index, optional_out)[0];
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_coo_cpu.h"
#include "macros.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "hip/segment_coo_cuda.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return segment_coo_cuda(src, index, optional_out, dim_size, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_coo_cpu(src, index, optional_out, dim_size, reduce);
}
}
torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return gather_coo_cuda(src, index, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_coo_cpu(src, index, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCOO : public torch::autograd::Function<SegmentSumCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMeanCOO : public torch::autograd::Function<SegmentMeanCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean");
auto out = std::get<0>(result);
auto count = std::get<1>(result).value();
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_coo_fw(grad_out, index, grad_in);
count = gather_coo_fw(count, index, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
count = count.unsqueeze(-1);
grad_in.true_divide_(count);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMinCOO : public torch::autograd::Function<SegmentMinCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class SegmentMaxCOO : public torch::autograd::Function<SegmentMaxCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_coo_fw(src, index, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[index.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
return {grad_in, Variable(), Variable(), Variable()};
}
};
class GatherCOO : public torch::autograd::Function<GatherCOO> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_coo_fw(src, index, optional_out);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::zeros(src_shape, grad_out.options());
segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
return {grad_in, Variable(), Variable()};
}
};
SCATTER_API torch::Tensor
segment_sum_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}
SCATTER_API torch::Tensor
segment_mean_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
SCATTER_API torch::Tensor
gather_coo(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return GatherCOO::apply(src, index, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_coo", &segment_sum_coo)
.op("torch_scatter::segment_mean_coo", &segment_mean_coo)
.op("torch_scatter::segment_min_coo", &segment_min_coo)
.op("torch_scatter::segment_max_coo", &segment_max_coo)
.op("torch_scatter::gather_coo", &gather_coo);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#include "macros.h"
#include "utils.h"
-#ifdef WITH_HIP
-#include "hip/segment_csr_hip.h"
+#ifdef WITH_CUDA
+#include "cuda/segment_csr_cuda.h"
#endif
#ifdef _WIN32
-#ifdef WITH_HIP
+#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
if (src.device().is_cuda()) {
-#ifdef WITH_HIP
+#ifdef WITH_CUDA
return segment_csr_cuda(src, indptr, optional_out, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
@@ -34,7 +40,7 @@ segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
-#ifdef WITH_HIP
+#ifdef WITH_CUDA
return gather_csr_cuda(src, indptr, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
@@ -192,32 +198,35 @@ public:
}
};
-torch::Tensor segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out) {
+SCATTER_API torch::Tensor
+segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
-torch::Tensor segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out) {
+SCATTER_API torch::Tensor
+segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
-std::tuple<torch::Tensor, torch::Tensor>
+SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
-torch::Tensor gather_csr(torch::Tensor src, torch::Tensor indptr,
-torch::optional<torch::Tensor> optional_out) {
+SCATTER_API torch::Tensor
+gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "cpu/segment_csr_cpu.h"
#include "macros.h"
#include "utils.h"
#ifdef WITH_CUDA
#include "hip/segment_csr_cuda.h"
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; }
#endif
#endif
#endif
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out,
std::string reduce) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return segment_csr_cuda(src, indptr, optional_out, reduce);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return segment_csr_cpu(src, indptr, optional_out, reduce);
}
}
torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return gather_csr_cuda(src, indptr, optional_out);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return gather_csr_cpu(src, indptr, optional_out);
}
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class SegmentSumCSR : public torch::autograd::Function<SegmentSumCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
gather_csr_fw(grad_out, indptr, grad_in);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMeanCSR : public torch::autograd::Function<SegmentMeanCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean"));
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
if (grad_in.numel() > 0) {
gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
auto count = (indptr2 - indptr1).to(grad_in.options());
count = gather_csr_fw(count, indptr, torch::nullopt);
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1);
grad_in.true_divide_(count);
}
return {grad_in, Variable(), Variable()};
}
};
class SegmentMinCSR : public torch::autograd::Function<SegmentMinCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class SegmentMaxCSR : public torch::autograd::Function<SegmentMaxCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto result = segment_csr_fw(src, indptr, optional_out, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({indptr, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto arg_out = saved[1];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[indptr.dim() - 1] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out);
grad_in =
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1);
return {grad_in, Variable(), Variable()};
}
};
class GatherCSR : public torch::autograd::Function<GatherCSR> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable indptr,
torch::optional<Variable> optional_out) {
ctx->saved_data["src_shape"] = src.sizes();
auto out = gather_csr_fw(src, indptr, optional_out);
ctx->save_for_backward({indptr});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options());
segment_csr_fw(grad_out, indptr, grad_in, "sum");
return {grad_in, Variable(), Variable()};
}
};
SCATTER_API torch::Tensor
segment_sum_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentSumCSR::apply(src, indptr, optional_out)[0];
}
SCATTER_API torch::Tensor
segment_mean_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return SegmentMeanCSR::apply(src, indptr, optional_out)[0];
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_min_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMinCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
SCATTER_API std::tuple<torch::Tensor, torch::Tensor>
segment_max_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
auto result = SegmentMaxCSR::apply(src, indptr, optional_out);
return std::make_tuple(result[0], result[1]);
}
SCATTER_API torch::Tensor
gather_csr(torch::Tensor src, torch::Tensor indptr,
torch::optional<torch::Tensor> optional_out) {
return GatherCSR::apply(src, indptr, optional_out)[0];
}
static auto registry =
torch::RegisterOperators()
.op("torch_scatter::segment_sum_csr", &segment_sum_csr)
.op("torch_scatter::segment_mean_csr", &segment_mean_csr)
.op("torch_scatter::segment_min_csr", &segment_min_csr)
.op("torch_scatter::segment_max_csr", &segment_max_csr)
.op("torch_scatter::gather_csr", &gather_csr);
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "scatter.h"
#include "macros.h"
-#ifdef WITH_HIP
-#include <hip/hip_runtime.h>
+#ifdef WITH_CUDA
+#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <cuda.h>
#endif
#endif
#ifdef _WIN32
-#ifdef WITH_HIP
+#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
#endif
-int64_t cuda_version() {
-#ifdef WITH_HIP
-return TORCH_HIP_VERSION;
+namespace scatter {
+SCATTER_API int64_t cuda_version() noexcept {
+#ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return CUDA_VERSION;
#endif
#else
return -1;
#endif
}
} // namespace scatter
-static auto registry =
-torch::RegisterOperators().op("torch_scatter::cuda_version", &cuda_version);
+static auto registry = torch::RegisterOperators().op(
+"torch_scatter::cuda_version", [] { return scatter::cuda_version(); });
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
#ifdef WITH_PYTHON
#include <Python.h>
#endif
#include <torch/script.h>
#include "scatter.h"
#include "macros.h"
#ifdef WITH_CUDA
#ifdef USE_ROCM
#include <hip/hip_version.h>
#else
#include <hip/hip_runtime.h>
#endif
#endif
#ifdef _WIN32
#ifdef WITH_PYTHON
#ifdef WITH_CUDA
PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; }
#else
PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; }
#endif
#endif
#endif
namespace scatter {
SCATTER_API int64_t cuda_version() noexcept {
#ifdef WITH_CUDA
#ifdef USE_ROCM
return HIP_VERSION;
#else
return DTK_VERSION;
#endif
#else
return -1;
#endif
}
} // namespace scatter
static auto registry = torch::RegisterOperators().op(
"torch_scatter::cuda_version", [] { return scatter::cuda_version(); });
SPHINXBUILD := sphinx-build
SPHINXPROJ := pytorch_scatter
SOURCEDIR := source
BUILDDIR := build
.PHONY: help Makefile
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)"
<!DOCTYPE html>
<html>
<head>
<title>Redirect</title>
<meta http-equiv="refresh" content="0; url=https://pytorch-scatter.readthedocs.io" />
</head>
</html>
https://download.pytorch.org/whl/cpu/torch-1.11.0%2Bcpu-cp38-cp38-linux_x86_64.whl
sphinx>=3
sphinx_rtd_theme
\def\indices{{0, 0, 1, 0, 2, 2, 3, 3}}
\def\inputs{{5, 1, 7, 2, 3, 2, 1, 3}}
\def\outputs{{8, 7, 5, 4}}
\def\colors{{"cyan", "orange", "olive", "magenta"}}
\def\numberInputs{7}
\def\numberOutputs{3}
\def\operation{add}
\input{template}
#!/bin/bash
files=(add sub mul div mean max min std)
for name in "${files[@]}"; do
pdflatex "$name"
pdf2svg "$name.pdf" "$name.svg"
done
\def\indices{{0, 0, 1, 0, 2, 2, 3, 3}}
\def\inputs{{5, 1, 7, 2, 3, 2, 1, 3}}
\def\outputs{{"$\frac{1}{10}$", "$\frac{1}{7}$", "$\frac{1}{6}$", "$\frac{1}{3}$"}}
\def\colors{{"cyan", "orange", "olive", "magenta"}}
\def\numberInputs{7}
\def\numberOutputs{3}
\def\operation{div}
\input{template}