Commit 8a1f0741 authored by rusty1s

update

parent eb7da512
.travis.yml
@@ -3,7 +3,6 @@ sudo: enabled
 env:
   global:
     - TORCH_VERSION=1.4.0
-    - CIBW_BUILD=cp36-*
   # jobs:
   #   - FORCE_CUDA=0 TORCH=${TORCH_VERSION}+cpu
   #   - FORCE_CUDA=1 CUDA_SHORT=9.2 CUDA=9.2.148-1 UBUNTU_VERSION=ubuntu1604 CUBLAS=cuda-cublas-dev-9-2 TORCH=${TORCH_VERSION}+cu92
@@ -110,7 +109,7 @@ install:
 script:
   - flake8 .
   - python3 setup.py test || python setup.py install
-  - python3 setup.py sdist bdist_wheel
+  - python3 setup.py bdist_wheel || python3 setup.py bdist_wheel
   - ls dist
 notifications:
   email: false
#include "scatter_cpu.h" #include "scatter_cpu.h"
// #include "index_info.h" #include "index_info.h"
// #include "reducer.h" #include "reducer.h"
// #include "utils.h" #include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>> std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out, torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) { torch::optional<int64_t> dim_size, std::string reduce) {
return std::make_tuple(src, optional_out); CHECK_CPU(src);
// CHECK_CPU(src); CHECK_CPU(index);
// CHECK_CPU(index); if (optional_out.has_value())
// if (optional_out.has_value()) CHECK_CPU(optional_out.value());
// CHECK_CPU(optional_out.value());
// CHECK_INPUT(src.dim() == index.dim()); CHECK_INPUT(src.dim() == index.dim());
// for (auto i = 0; i < index.dim() - 1; i++) for (auto i = 0; i < index.dim() - 1; i++)
// CHECK_INPUT(src.size(i) >= index.size(i)); CHECK_INPUT(src.size(i) >= index.size(i));
// src = src.contiguous(); src = src.contiguous();
// torch::Tensor out; torch::Tensor out;
// if (optional_out.has_value()) { if (optional_out.has_value()) {
// out = optional_out.value().contiguous(); out = optional_out.value().contiguous();
// for (auto i = 0; i < out.dim(); i++) for (auto i = 0; i < out.dim(); i++)
// if (i != dim) if (i != dim)
// CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(src.size(i) == out.size(i));
// } else { } else {
// auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
// if (dim_size.has_value()) if (dim_size.has_value())
// sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
// else if (index.numel() == 0) else if (index.numel() == 0)
// sizes[dim] = 0; sizes[dim] = 0;
// else else
// sizes[dim] = 1 + *index.max().data_ptr<int64_t>(); sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
// out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
// } }
// torch::optional<torch::Tensor> arg_out = torch::nullopt; torch::optional<torch::Tensor> arg_out = torch::nullopt;
// int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
// if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
// arg_out = torch::full_like(out, src.size(dim), index.options()); arg_out = torch::full_like(out, src.size(dim), index.options());
// arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
// } }
// if (index.numel() == 0) if (index.numel() == 0)
// return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
// auto B = 1; auto B = 1;
// for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
// B *= src.size(i); B *= src.size(i);
// auto E = src.size(dim); auto E = src.size(dim);
// auto K = src.numel() / (B * E); auto K = src.numel() / (B * E);
// auto N = out.size(dim); auto N = out.size(dim);
// auto index_info = getTensorInfo<int64_t>(index); auto index_info = getTensorInfo<int64_t>(index);
// AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] { AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
// auto src_data = src.data_ptr<scalar_t>(); auto src_data = src.data_ptr<scalar_t>();
// auto out_data = out.data_ptr<scalar_t>(); auto out_data = out.data_ptr<scalar_t>();
// int64_t i, idx; int64_t i, idx;
// AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
// if (!optional_out.has_value()) if (!optional_out.has_value())
// out.fill_(Reducer<scalar_t>::init(REDUCE)); out.fill_(Reducer<scalar_t>::init(REDUCE));
// for (auto b = 0; b < B; b++) { for (auto b = 0; b < B; b++) {
// for (auto e = 0; e < E; e++) { for (auto e = 0; e < E; e++) {
// for (auto k = 0; k < K; k++) { for (auto k = 0; k < K; k++) {
// i = b * E * K + e * K + k; i = b * E * K + e * K + k;
// idx = index_info.data[IndexToOffset<int64_t>::get(i, idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
// index_info)]; Reducer<scalar_t>::update( Reducer<scalar_t>::update(
// REDUCE, out_data + b * N * K + idx * K + k, src_data[i], REDUCE, out_data + b * N * K + idx * K + k, src_data[i],
// arg_out_data + b * N * K + idx * K + k, e); arg_out_data + b * N * K + idx * K + k, e);
// } }
// } }
// } }
// if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
// out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), out.masked_fill_(out == Reducer<scalar_t>::init(REDUCE), (scalar_t)0);
// (scalar_t)0); });
// }); });
// });
// return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
} }
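The triple loop above works on a flattened (B, E, K) view of src: B is the product of the sizes before dim, E the size of dim itself, and K the product of the sizes after it, so b * E * K + e * K + k is a plain row-major offset, and each element is reduced into out at the same (b, k) position but at row index[e]. Below is a minimal standalone sketch of that index arithmetic for the "sum" reduction, using plain std::vector storage and made-up shapes and values; it is an illustration, not part of the commit.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // src viewed as [B, E, K] = [1, 4, 2]; out viewed as [B, N, K] = [1, 3, 2].
  const int64_t B = 1, E = 4, K = 2, N = 3;
  std::vector<float> src = {1, 2, 3, 4, 5, 6, 7, 8};
  std::vector<int64_t> index = {0, 2, 2, 1}; // one output slot per position along dim
  std::vector<float> out(B * N * K, 0.0f);   // the "sum" init value is 0

  for (int64_t b = 0; b < B; b++)
    for (int64_t e = 0; e < E; e++)
      for (int64_t k = 0; k < K; k++) {
        const int64_t i = b * E * K + e * K + k; // flat offset into src
        const int64_t idx = index[e];            // index is broadcast along K
        out[b * N * K + idx * K + k] += src[i];  // the "sum" update
      }

  for (float v : out)
    std::printf("%g ", v); // prints: 1 2 7 8 8 10
  std::printf("\n");
  return 0;
}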
csrc/scatter.cpp
 #include <Python.h>
 #include <torch/script.h>

-// #include "cpu/scatter_cpu.h"
-// #include "utils.h"
+#include "cpu/scatter_cpu.h"
+#include "utils.h"

-// #ifdef WITH_CUDA
-// #include <cuda.h>
-// #include "cuda/scatter_cuda.h"
-// #endif
+#ifdef WITH_CUDA
+#include "cuda/scatter_cuda.h"
+#endif

 #ifdef _WIN32
 PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
 #endif
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
if (src.dim() == 1)
for (auto i = 0; i < dim; i++)
src = src.unsqueeze(0);
for (auto i = src.dim(); i < other.dim(); i++)
src = src.unsqueeze(-1);
src = src.expand(other.sizes().vec());
return src;
}
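broadcast first left-pads a one-dimensional index with singleton dimensions up to dim, then right-pads it to other.dim() and expands, so a per-position index can address every element of the slices it routes. A hypothetical call, with shapes made up for illustration:

auto src = torch::rand({2, 3, 4});
auto index = torch::tensor({0, 2, 1}, torch::kLong);
// (3,) -> (1, 3) -> (1, 3, 1) -> expanded to (2, 3, 4):
auto full = broadcast(index, src, /*dim=*/1);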
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce) {
-  return std::make_tuple(src, optional_out);
-  // if (src.device().is_cuda()) {
-  // #ifdef WITH_CUDA
-  //   return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
-  // #else
-  //   AT_ERROR("Not compiled with CUDA support");
-  // #endif
-  // } else {
-  //   return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
-  // }
+  if (src.device().is_cuda()) {
+#ifdef WITH_CUDA
+    return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
+#else
+    AT_ERROR("Not compiled with CUDA support");
+#endif
+  } else {
+    return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
+  }
 }
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
ctx->save_for_backward({index});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out, dim, index, false);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
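Because out[j] for scatter-sum is simply the sum of all src[e] with index[e] == j, each source element's gradient is the gradient of the single output slot it fed, so the backward pass is exactly a gather along the same broadcast index. For example (values made up), with dim = 0 and index = (0, 2, 2, 1): out[2] = src[1] + src[2], and backward returns grad_in = (grad_out[0], grad_out[2], grad_out[2], grad_out[1]). Note that src_shape is recovered here but not needed for sum, since gather already produces a gradient of the source shape.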
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
auto old_index = index;
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
auto out = std::get<0>(result);
auto ones = torch::ones(old_index.sizes(), src.options());
result = scatter_fw(ones, old_index,
old_index.dim() <= dim ? old_index.dim() - 1 : dim,
torch::nullopt, out.size(dim), "sum");
auto count = std::get<0>(result);
count.clamp_(1);
count = broadcast(count, out, dim);
out.div_(count);
ctx->save_for_backward({index, count});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto count = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
count = torch::gather(count, dim, index, false);
auto grad_in = torch::gather(grad_out, dim, index, false);
grad_in.div_(count);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
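ScatterMean composes two scatter-sums: one over src and one over a tensor of ones, which yields the occupancy count of every output slot; clamp_(1) guards the division for slots that received nothing. With index = (0, 2, 2, 1) as above, count = (1, 1, 2) and out[2] = (src[1] + src[2]) / 2. The backward pass gathers grad_out like ScatterSum and divides by the gathered count, since each contributing element carries weight 1 / count in the mean.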
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
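For min (and max below), only the winning source element per output slot may receive gradient; its position along dim is recorded in arg_out. Slots that received no input keep the fill value src.size(dim), which is one past the valid range of src, so backward allocates grad_in with one extra row along dim (src_shape[dim] += 1), scatters grad_out at the recorded positions, and then narrows that padding row away: gradient aimed at empty slots lands in the padding and is discarded.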
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
static variable_list forward(AutogradContext *ctx, Variable src,
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
index = broadcast(index, src, dim);
auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
auto out = std::get<0>(result);
auto arg_out = std::get<1>(result).value();
ctx->save_for_backward({index, arg_out});
ctx->mark_non_differentiable({arg_out});
if (optional_out.has_value())
ctx->mark_dirty({optional_out.value()});
return {out, arg_out};
}
static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();
auto index = saved[0];
auto arg_out = saved[1];
auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
src_shape[dim] += 1;
auto grad_in = torch::zeros(src_shape, grad_out.options());
grad_in.scatter_(dim, arg_out, grad_out);
grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
return {grad_in, Variable(), Variable(), Variable(), Variable()};
}
};
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}
torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
}
std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size) {
auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
return std::make_tuple(result[0], result[1]);
 }

-static auto registry =
-    torch::RegisterOperators().op("torch_scatter::scatter_fw", &scatter_fw);
+static auto registry = torch::RegisterOperators()
+                           .op("torch_scatter::scatter_sum", &scatter_sum)
+                           .op("torch_scatter::scatter_mean", &scatter_mean)
+                           .op("torch_scatter::scatter_min", &scatter_min)
+                           .op("torch_scatter::scatter_max", &scatter_max);
csrc/segment_coo.cpp
+#include <Python.h>
 #include <torch/script.h>

 #include "cpu/segment_coo_cpu.h"
@@ -7,6 +8,10 @@
 #include "cuda/segment_coo_cuda.h"
 #endif

+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
+#endif
+
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_coo_fw(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
...
csrc/segment_csr.cpp
+#include <Python.h>
 #include <torch/script.h>

 #include "cpu/segment_csr_cpu.h"
@@ -7,6 +8,10 @@
 #include "cuda/segment_csr_cuda.h"
 #endif

+#ifdef _WIN32
+PyMODINIT_FUNC PyInit__scatter(void) { return NULL; }
+#endif
+
 std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
 segment_csr_fw(torch::Tensor src, torch::Tensor indptr,
                torch::optional<torch::Tensor> optional_out,
...
setup.py
@@ -57,7 +57,7 @@ tests_require = ['pytest', 'pytest-cov']
 setup(
     name='torch_scatter',
-    version='2.0.2',
+    version='2.0.3',
     author='Matthias Fey',
     author_email='matthias.fey@tu-dortmund.de',
     url='https://github.com/rusty1s/pytorch_scatter',
...
torch_scatter/__init__.py
@@ -9,7 +9,7 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
 from .composite import (scatter_std, scatter_logsumexp, scatter_softmax,
                         scatter_log_softmax)

-__version__ = '2.0.2'
+__version__ = '2.0.3'

 __all__ = [
     'scatter_sum',
...