Commit 5e2d0f1f authored by rusty1s

scatter cpu:

parent 64772d75
@@ -45,11 +45,11 @@ All included operations are broadcastable, work on varying data types, and are i
 ## Installation

-Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

 ```
 $ python -c "import torch; print(torch.__version__)"
->>> 1.1.0
+>>> 1.3.0
 $ echo $PATH
 >>> /usr/local/cuda/bin:...
......
#include "scatter_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() == index.dim());
for (auto i = 0; i < index.dim() - 1; i++)
CHECK_INPUT(src.size(i) >= index.size(i));
if (dim < 0)
dim = src.dim() + dim;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (auto i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
auto sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
auto B = 1;
for (auto i = 0; i < dim; i++)
B *= src.size(i);
auto E = src.size(dim);
auto K = src.numel() / (B * E);
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "scatter", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
int64_t i, idx;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
for (auto b = 0; b < B; b++) {
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++) {
i = b * E * K + e * K + k;
idx = index_info.data[IndexToOffset<int64_t>::get(i, index_info)];
Reducer<scalar_t, REDUCE>::update(
out_data + b * N * K + idx * K + k, src_data[i],
arg_out_data + b * N * K + idx * K + k, e);
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
});
});
return std::make_tuple(out, arg_out);
}
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size, std::string reduce);
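For reference, the semantics of the B/E/K loop above for `reduce="sum"` can be reproduced with stock PyTorch ops. This is a hedged sketch for illustration only (the helper name is mine, not part of the extension); it assumes `index` has already been broadcast to the shape of `src`, which is what `csrc/scatter.cpp` below does before dispatching:

```python
import torch

def scatter_sum_reference(src, index, dim, dim_size=None):
    # `index` is assumed to have the same shape as `src`.
    if dim < 0:
        dim = src.dim() + dim
    if dim_size is None:
        dim_size = int(index.max()) + 1
    sizes = list(src.size())
    sizes[dim] = dim_size
    # Empty output buckets stay at the reduction's identity (0 for "sum").
    out = torch.zeros(sizes, dtype=src.dtype)
    return out.scatter_add_(dim, index, src)
```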
@@ -16,7 +16,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   CHECK_INPUT(src.dim() >= index.dim());
   auto sizes = index.sizes().vec();
-  for (int i = 0; i < index.dim(); i++)
+  for (auto i = 0; i < index.dim(); i++)
     sizes[i] = src.size(i);
   index = index.expand(sizes);
@@ -27,7 +27,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   torch::Tensor out;
   if (optional_out.has_value()) {
     out = optional_out.value().contiguous();
-    for (int i = 0; i < out.dim(); i++)
+    for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
......
@@ -27,7 +27,7 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   torch::Tensor out;
   if (optional_out.has_value()) {
     out = optional_out.value().contiguous();
-    for (int i = 0; i < out.dim(); i++)
+    for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
     CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
@@ -126,7 +126,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     std::vector<scalar_t> vals(K);
     int64_t row_start, row_end;
-    for (int n = 0; n < N; n++) {
+    for (auto n = 0; n < N; n++) {
       auto offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
       row_start = indptr_info.data[offset];
       row_end = indptr_info.data[offset + stride];
......
@@ -106,9 +106,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
       atomMul(address, val);
     else if (REDUCE == DIV)
       atomDiv(address, val);
-    else if (REDUCE == MIN && val < *address)
+    else if (REDUCE == MIN)
       atomMin(address, val);
-    else if (REDUCE == MAX && val > *address)
+    else if (REDUCE == MAX)
       atomMax(address, val);
   }
 };
#include "scatter_cuda.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
return std::make_tuple(src, optional_out);
}
#pragma once
#include <torch/extension.h>
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce);
#include <torch/script.h>

#include "cpu/scatter_cpu.h"

#ifdef WITH_CUDA
#include "cuda/scatter_cuda.h"
#endif

// Expands a lower-rank tensor (typically the 1-dimensional `index`) to the
// shape of `other` by inserting singleton dims before and after `dim` and
// expanding, so that it can index `other` element-wise.
torch::Tensor broadcast(torch::Tensor src, torch::Tensor other, int64_t dim) {
  if (dim < 0)
    dim = other.dim() + dim;
  if (src.dim() == 1)
    for (auto i = 0; i < dim; i++)
      src = src.unsqueeze(0);
  for (auto i = src.dim(); i < other.dim(); i++)
    src = src.unsqueeze(-1);
  src = src.expand(other.sizes().vec());
  return src;
}

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
scatter_fw(torch::Tensor src, torch::Tensor index, int64_t dim,
           torch::optional<torch::Tensor> optional_out,
           torch::optional<int64_t> dim_size, std::string reduce) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return scatter_cuda(src, index, dim, optional_out, dim_size, reduce);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return scatter_cpu(src, index, dim, optional_out, dim_size, reduce);
  }
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;
class ScatterSum : public torch::autograd::Function<ScatterSum> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();
    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);
    ctx->save_for_backward({index});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // The gradient of a sum-scatter is a gather: each source element receives
    // the gradient of the output bucket it was scattered into.
    auto grad_in = torch::gather(grad_out, dim, index, false);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};
class ScatterMean : public torch::autograd::Function<ScatterMean> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();

    auto old_index = index;
    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);

    // Count how many entries fall into each bucket and divide the sums by the
    // (clamped) counts to obtain the mean.
    auto ones = torch::ones(old_index.sizes(), src.options());
    result = scatter_fw(ones, old_index,
                        old_index.dim() <= dim ? old_index.dim() - 1 : dim,
                        torch::nullopt, out.size(dim), "sum");
    auto count = std::get<0>(result);
    count.clamp_(1);
    count = broadcast(count, out, dim);
    out.div_(count);

    ctx->save_for_backward({index, count});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto count = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Gather the output gradient back to the source positions and divide by
    // the size of the bucket each element contributed to.
    count = torch::gather(count, dim, index, false);
    auto grad_in = torch::gather(grad_out, dim, index, false);
    grad_in.div_(count);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};
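The mean reduction above is just a sum divided by a clamped per-bucket count. A hedged numeric check using plain PyTorch only (values taken from the first test case further down):

```python
import torch

src = torch.tensor([1., 3., 2., 4., 5., 6.])
index = torch.tensor([0, 1, 0, 1, 1, 3])

# Per-bucket counts, clamped to 1 so empty buckets do not divide by zero.
count = torch.zeros(4).scatter_add_(0, index, torch.ones_like(src)).clamp_(min=1)
mean = torch.zeros(4).scatter_add_(0, index, src) / count
# mean == tensor([1.5, 4.0, 0.0, 6.0])
```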
class ScatterMin : public torch::autograd::Function<ScatterMin> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();
    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "min");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // `arg_out` stores src.size(dim) for empty buckets, so scatter into a
    // tensor with one extra slot along `dim` and drop that slot afterwards.
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};
class ScatterMax : public torch::autograd::Function<ScatterMax> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index, int64_t dim,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["dim"] = dim;
    ctx->saved_data["src_shape"] = src.sizes();
    index = broadcast(index, src, dim);
    auto result = scatter_fw(src, index, dim, optional_out, dim_size, "max");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto dim = ctx->saved_data["dim"].toInt();
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Same padding-slot trick as in ScatterMin::backward above.
    src_shape[dim] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(dim, arg_out, grad_out);
    grad_in = grad_in.narrow(dim, 0, src_shape[dim] - 1);
    return {grad_in, Variable(), Variable(), Variable(), Variable()};
  }
};
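Both backward passes rely on the same padding-slot trick: `arg_out` holds `src.size(dim)` for empty output buckets, so the gradient is scattered into a tensor with one extra slot along `dim` and that slot is dropped again. A hedged illustration with the numbers from the first scatter test case below:

```python
import torch

grad_out = torch.tensor([1., 1., 1., 1.])
arg_out = torch.tensor([2, 4, 6, 5])  # argmax of the first test case; 6 == src.size(0)

# One extra "padding" slot absorbs the gradient of the empty bucket,
# then narrow() cuts it away again.
grad_in = torch.zeros(6 + 1).scatter_(0, arg_out, grad_out).narrow(0, 0, 6)
# grad_in == tensor([0., 0., 1., 0., 1., 1.]) -> gradient only at the argmax positions
```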
torch::Tensor scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim,
                          torch::optional<torch::Tensor> optional_out,
                          torch::optional<int64_t> dim_size) {
  return ScatterSum::apply(src, index, dim, optional_out, dim_size)[0];
}

torch::Tensor scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim,
                           torch::optional<torch::Tensor> optional_out,
                           torch::optional<int64_t> dim_size) {
  return ScatterMean::apply(src, index, dim, optional_out, dim_size)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto result = ScatterMin::apply(src, index, dim, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim,
            torch::optional<torch::Tensor> optional_out,
            torch::optional<int64_t> dim_size) {
  auto result = ScatterMax::apply(src, index, dim, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

static auto registry = torch::RegisterOperators()
                           .op("torch_scatter::scatter_sum", &scatter_sum)
                           .op("torch_scatter::scatter_mean", &scatter_mean)
                           .op("torch_scatter::scatter_min", &scatter_min)
                           .op("torch_scatter::scatter_max", &scatter_max);
@@ -73,7 +73,7 @@ def test_backward(test, device):
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
-def test_gather_out(test, dtype, device):
+def test_out(test, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
@@ -93,7 +93,7 @@ def test_gather_out(test, dtype, device):
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
-def test_non_contiguous_segment(test, dtype, device):
+def test_non_contiguous(test, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
......
from itertools import product

import pytest
import torch
from torch.autograd import gradcheck

import torch_scatter

from .utils import tensor, dtypes, devices

devices = ['cpu']  # The CUDA scatter kernel is still a stub in this commit.

reductions = ['sum', 'add', 'mean', 'min', 'max']
tests = [
    {
        'src': [1, 3, 2, 4, 5, 6],
        'index': [0, 1, 0, 1, 1, 3],
        'dim': 0,
        'sum': [3, 12, 0, 6],
        'add': [3, 12, 0, 6],
        'mean': [1.5, 4, 0, 6],
        'min': [1, 3, 0, 6],
        'arg_min': [0, 1, 6, 5],
        'max': [2, 5, 0, 6],
        'arg_max': [2, 4, 6, 5],
    },
    {
        'src': [[1, 2], [5, 6], [3, 4], [7, 8], [9, 10], [11, 12]],
        'index': [0, 1, 0, 1, 1, 3],
        'dim': 0,
        'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
        'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
        'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
        'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
        'arg_max': [[2, 2], [4, 4], [6, 6], [5, 5]],
    },
    {
        'src': [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]],
        'index': [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]],
        'dim': 1,
        'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
        'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
        'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
        'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
        'arg_max': [[2, 4, 6, 5], [3, 4, 5, 6]],
    },
    {
        'src': [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]],
        'index': [[0, 1, 0], [2, 0, 2]],
        'dim': 1,
        'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
        'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
        'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
        'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
        'arg_max': [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]],
    },
    {
        'src': [[1, 3], [2, 4]],
        'index': [[0, 0], [0, 0]],
        'dim': 1,
        'sum': [[4], [6]],
        'add': [[4], [6]],
        'mean': [[2], [3]],
        'min': [[1], [2]],
        'arg_min': [[0], [0]],
        'max': [[3], [4]],
        'arg_max': [[1], [1]],
    },
    {
        'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
        'index': [[0, 0], [0, 0]],
        'dim': 1,
        'sum': [[[4, 4]], [[6, 6]]],
        'add': [[[4, 4]], [[6, 6]]],
        'mean': [[[2, 2]], [[3, 3]]],
        'min': [[[1, 1]], [[2, 2]]],
        'arg_min': [[[0, 0]], [[0, 0]]],
        'max': [[[3, 3]], [[4, 4]]],
        'arg_max': [[[1, 1]], [[1, 1]]],
    },
]
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
@pytest.mark.parametrize('test,reduce,device',
                         product(tests, reductions, devices))
def test_backward(test, reduce, device):
    src = tensor(test['src'], torch.double, device)
    src.requires_grad_()
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']

    assert gradcheck(torch_scatter.scatter,
                     (src, index, dim, None, None, reduce))
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_out(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = torch.full_like(expected, -2)
    getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim, out)

    if reduce == 'sum' or reduce == 'add':
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        expected = expected.fill_(-2)
    elif reduce == 'max':
        expected[expected == 0] = -2
    else:
        raise ValueError

    assert torch.all(out == expected)
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous(test, reduce, dtype, device):
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)

    out = getattr(torch_scatter, f'scatter_{reduce}')(src, index, dim)
    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
@@ -7,7 +7,7 @@ import torch_scatter
 from .utils import tensor, dtypes, devices
-reductions = ['sum', 'mean', 'min', 'max']
+reductions = ['sum', 'add', 'mean', 'min', 'max']
 tests = [
     {
@@ -15,6 +15,7 @@ tests = [
         'index': [0, 0, 1, 1, 1, 3],
         'indptr': [0, 2, 5, 5, 6],
         'sum': [3, 12, 0, 6],
+        'add': [3, 12, 0, 6],
         'mean': [1.5, 4, 0, 6],
         'min': [1, 3, 0, 6],
         'arg_min': [0, 2, 6, 5],
@@ -26,6 +27,7 @@ tests = [
         'index': [0, 0, 1, 1, 1, 3],
         'indptr': [0, 2, 5, 5, 6],
         'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
+        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
         'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
         'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
         'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
@@ -37,6 +39,7 @@ tests = [
         'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
         'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
         'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
+        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
         'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
         'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
         'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
@@ -48,6 +51,7 @@ tests = [
         'index': [[0, 0, 1], [0, 2, 2]],
         'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
         'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
+        'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
         'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
         'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
         'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
@@ -59,6 +63,7 @@ tests = [
         'index': [[0, 0], [0, 0]],
         'indptr': [[0, 2], [0, 2]],
         'sum': [[4], [6]],
+        'add': [[4], [6]],
         'mean': [[2], [3]],
         'min': [[1], [2]],
         'arg_min': [[0], [0]],
@@ -70,6 +75,7 @@ tests = [
         'index': [[0, 0], [0, 0]],
         'indptr': [[0, 2], [0, 2]],
         'sum': [[[4, 4]], [[6, 6]]],
+        'add': [[[4, 4]], [[6, 6]]],
         'mean': [[[2, 2]], [[3, 3]]],
         'min': [[[1, 1]], [[2, 2]]],
         'arg_min': [[[0, 0]], [[0, 0]]],
@@ -117,15 +123,13 @@ def test_backward(test, reduce, device):
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
-def test_segment_out(test, reduce, dtype, device):
+def test_out(test, reduce, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test[reduce], dtype, device)
-    size = list(src.size())
-    size[indptr.dim() - 1] = indptr.size(-1) - 1
-    out = src.new_full(size, -2)
+    out = torch.full_like(expected, -2)
     getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out)
     assert torch.all(out == expected)
@@ -134,7 +138,7 @@ def test_segment_out(test, reduce, dtype, device):
     getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out)
-    if reduce == 'sum':
+    if reduce == 'sum' or reduce == 'add':
         expected = expected - 2
     elif reduce == 'mean':
         expected = out  # We can not really test this here.
@@ -150,7 +154,7 @@ def test_segment_out(test, reduce, dtype, device):
 @pytest.mark.parametrize('test,reduce,dtype,device',
                          product(tests, reductions, dtypes, devices))
-def test_non_contiguous_segment(test, reduce, dtype, device):
+def test_non_contiguous(test, reduce, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
......
+from .scatter import (scatter_sum, scatter_add, scatter_mean, scatter_min,
+                      scatter_max, scatter)
 from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                           segment_min_csr, segment_max_csr, segment_csr,
                           gather_csr)
@@ -8,12 +10,17 @@ from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
 __version__ = '2.0.0'

 __all__ = [
+    'scatter_sum',
+    'scatter_add',
+    'scatter_mean',
+    'scatter_min',
+    'scatter_max',
+    'scatter',
     'segment_sum_csr',
     'segment_add_csr',
     'segment_mean_csr',
     'segment_min_csr',
     'segment_max_csr',
-    'segment_max_csr',
     'segment_csr',
     'gather_csr',
     'segment_sum_coo',
@@ -21,7 +28,6 @@ __all__ = [
     'segment_mean_coo',
     'segment_min_coo',
     'segment_max_coo',
-    'segment_max_coo',
     'segment_coo',
     'gather_coo',
     'torch_scatter',
......
import os.path as osp
from typing import Optional, Tuple

import torch

torch.ops.load_library(
    osp.join(osp.dirname(osp.abspath(__file__)), '_scatter.so'))


@torch.jit.script
def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    # `add` is an alias for `sum`.
    return torch.ops.torch_scatter.scatter_sum(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_min(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


@torch.jit.script
def scatter_max(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None
                ) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


@torch.jit.script
def scatter(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
            out: Optional[torch.Tensor] = None, dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        raise ValueError
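Finally, a minimal usage sketch of the Python API defined above (assumes the package is built and importable; the numbers match the first test case):

```python
import torch
from torch_scatter import scatter, scatter_max

src = torch.tensor([1., 3., 2., 4., 5., 6.])
index = torch.tensor([0, 1, 0, 1, 1, 3])

out = scatter(src, index, dim=0, reduce="max")  # tensor([2., 5., 0., 6.])
out, argmax = scatter_max(src, index, dim=0)    # argmax == tensor([2, 4, 6, 5])
```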