Commit bb87ec65 authored by rusty1s

coo cpu implementation

parent 0c887ffc
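This commit adds CPU kernels for segment reductions over a sorted COO index (segment_coo and its transpose gather_coo), plus autograd functions, operator registrations, TorchScript-friendly Python wrappers, and CPU-only tests; the CUDA side is stubbed out for now. As a rough mental model only (not the committed kernel), segment_sum_coo behaves like this pure-Python sketch for the 1-D case, assuming index is sorted:

import torch

def segment_sum_coo_reference(src, index, dim_size=None):
    # Reference semantics only: accumulate src entries that share an index
    # value. The committed kernel handles batched leading dimensions and
    # trailing feature dimensions in a single pass over the sorted index.
    n = dim_size if dim_size is not None else int(index.max()) + 1
    out = torch.zeros(n, dtype=src.dtype)
    for value, i in zip(src.tolist(), index.tolist()):
        out[i] += value
    return out

print(segment_sum_coo_reference(torch.tensor([1., 2., 3., 4.]),
                                torch.tensor([0, 0, 1, 2])))
# tensor([3., 3., 4.])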
@@ -166,7 +166,7 @@ torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
        if (e < E - 1) {
          next_idx = index_info.data[offset + (e + 1) * stride];
-         CHECK_INPUT(idx < E && idx <= next_idx);
+         CHECK_INPUT(idx <= next_idx);
          if (idx != next_idx) {
            idx = next_idx;
......
#include "segment_coo_cpu.h"
#include "index_info.h"
#include "reducer.h"
#include "utils.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
CHECK_CPU(src);
CHECK_CPU(index);
if (optional_out.has_value())
CHECK_CPU(optional_out.value());
CHECK_INPUT(src.dim() >= index.dim());
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++)
sizes[i] = src.size(i);
index = index.expand(sizes);
auto dim = index.dim() - 1;
src = src.contiguous();
torch::Tensor out;
if (optional_out.has_value()) {
out = optional_out.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i));
} else {
sizes = src.sizes().vec();
if (dim_size.has_value())
sizes[dim] = dim_size.value();
else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options());
}
torch::optional<torch::Tensor> arg_out = torch::nullopt;
int64_t *arg_out_data = nullptr;
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>();
}
torch::optional<torch::Tensor> count = torch::nullopt;
if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
count = torch::zeros(sizes, out.options());
}
auto B = index.numel() / src.size(dim);
auto E = src.size(dim);
auto K = src.numel() / index.numel();
auto N = out.size(dim);
auto index_info = getTensorInfo<int64_t>(index);
auto stride = index_info.strides[index_info.dims - 1];
std::vector<int64_t> args(K);
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo", [&] {
auto src_data = src.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
scalar_t *count_data = nullptr;
std::vector<scalar_t> vals(K);
int64_t idx, next_idx, row_start;
AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN)
count_data = count.value().data_ptr<scalar_t>();
for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
idx = index_info.data[offset];
for (auto k = 0; k < K; k++)
vals[k] = out_data[b * N * K + k];
row_start = 0;
for (auto e = 0; e < E; e++) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::update(
&vals[k], src_data[b * E * K + e * K + k], &args[k], e);
if (e == E - 1) {
for (auto k = 0; k < K; k++)
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
} else {
next_idx = index_info.data[offset + (e + 1) * stride];
assert(idx <= next_idx);
if (idx != next_idx) {
for (auto k = 0; k < K; k++) {
Reducer<scalar_t, REDUCE>::write(
out_data + b * N * K + idx * K + k, vals[k],
arg_out_data + b * N * K + idx * K + k, args[k],
e + 1 - row_start);
vals[k] = out_data[b * N * K + next_idx * K + k];
}
if (REDUCE == MEAN)
count_data[b * N + idx] = (scalar_t)(e + 1 - row_start);
row_start = e + 1;
}
idx = next_idx;
}
}
}
if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX))
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN)
arg_out = count;
});
});
return std::make_tuple(out, arg_out);
}
torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
                             torch::optional<torch::Tensor> optional_out) {
  CHECK_CPU(src);
  CHECK_CPU(index);
  if (optional_out.has_value())
    CHECK_CPU(optional_out.value());

  CHECK_INPUT(src.dim() >= index.dim());
  for (auto i = 0; i < index.dim() - 1; i++)
    CHECK_INPUT(src.size(i) == index.size(i));

  auto dim = index.dim() - 1;

  src = src.contiguous();

  torch::Tensor out;
  if (optional_out.has_value()) {
    out = optional_out.value().contiguous();
    for (auto i = 0; i < src.dim(); i++)
      if (i != dim)
        CHECK_INPUT(src.size(i) == out.size(i));
  } else {
    auto sizes = src.sizes().vec();
    sizes[dim] = index.size(dim);
    out = torch::empty(sizes, src.options());
  }

  // B outer rows, E gathered elements per row, K trailing features per
  // element, and N source slots along `dim`.
  auto B = index.numel() / out.size(dim);
  auto E = index.size(dim);
  auto K = out.numel() / index.numel();
  auto N = src.size(dim);

  auto index_info = getTensorInfo<int64_t>(index);
  auto stride = index_info.strides[index_info.dims - 1];
  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
    auto src_data = src.data_ptr<scalar_t>();
    auto out_data = out.data_ptr<scalar_t>();

    std::vector<scalar_t> vals(K);
    int64_t idx, next_idx;
    for (auto b = 0; b < B; b++) {
      auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
      idx = index_info.data[offset];

      for (auto k = 0; k < K; k++)
        vals[k] = src_data[b * N * K + idx * K + k];

      for (auto e = 0; e < E; e++) {
        // Copy the cached row; reload only when the (sorted) index changes.
        for (auto k = 0; k < K; k++)
          out_data[b * E * K + e * K + k] = vals[k];

        if (e < E - 1) {
          next_idx = index_info.data[offset + (e + 1) * stride];
          CHECK_INPUT(idx <= next_idx);

          if (idx != next_idx) {
            idx = next_idx;
            for (auto k = 0; k < K; k++)
              vals[k] = src_data[b * N * K + idx * K + k];
          }
        }
      }
    }
  });

  return out;
}
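Both kernels above flatten their inputs into B outer rows, E indexed elements per row, K trailing features per element, and N slots along the indexed dimension. A sketch of how those sizes fall out of the tensor shapes for segment_coo_cpu (the example shapes are mine, not from the commit):

import torch

src = torch.randn(2, 4, 3)                          # reduce over dim = 1
index = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]])  # sorted along dim 1

B = index.numel() // src.size(1)  # 8 // 4 = 2 outer rows
E = src.size(1)                   # 4 indexed elements per row
K = src.numel() // index.numel()  # 24 // 8 = 3 trailing features per element
N = int(index.max()) + 1          # 3 output slots when dim_size is omitted
print(B, E, K, N)                 # 2 4 3 3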
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cpu(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size, std::string reduce);

torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
                             torch::optional<torch::Tensor> optional_out);
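gather_coo_cpu above implements the transposed access pattern: every output position copies the source row its index entry points at, and the cached row is reloaded only when the sorted index changes. For the 1-D case this matches a plain gather along the last dimension, as in this sketch (values are made up):

import torch

src = torch.tensor([10., 20., 30.])
index = torch.tensor([0, 0, 1, 1, 2])

out = src.gather(0, index)
print(out)  # tensor([10., 10., 20., 20., 30.])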
@@ -67,12 +67,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
        for (auto k = 0; k < K; k++)
          vals[k] = Reducer<scalar_t, REDUCE>::init();
-       for (auto e = row_start; e < row_end; e++) {
-         CHECK_INPUT(e < E);
+       for (auto e = row_start; e < row_end; e++)
          for (auto k = 0; k < K; k++)
            Reducer<scalar_t, REDUCE>::update(
                &vals[k], src_data[offset + e * K + k], &args[k], e);
-       }
        for (auto k = 0; k < K; k++)
          Reducer<scalar_t, REDUCE>::write(out_data + n * K + k, vals[k],
......
#include "segment_coo_cuda.h"
std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out,
torch::optional<int64_t> dim_size, std::string reduce) {
return std::make_tuple(src, optional_out);
}
torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
torch::optional<torch::Tensor> optional_out) {
return src;
}
#pragma once

#include <torch/extension.h>

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_cuda(torch::Tensor src, torch::Tensor index,
                 torch::optional<torch::Tensor> optional_out,
                 torch::optional<int64_t> dim_size, std::string reduce);

torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> optional_out);
#include <torch/script.h>

#include "cpu/segment_coo_cpu.h"

#ifdef WITH_CUDA
#include "cuda/segment_coo_cuda.h"
#endif

std::tuple<torch::Tensor, torch::optional<torch::Tensor>>
segment_coo_fw(torch::Tensor src, torch::Tensor index,
               torch::optional<torch::Tensor> optional_out,
               torch::optional<int64_t> dim_size, std::string reduce) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return segment_coo_cuda(src, index, optional_out, dim_size, reduce);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return segment_coo_cpu(src, index, optional_out, dim_size, reduce);
  }
}

torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index,
                            torch::optional<torch::Tensor> optional_out) {
  if (src.device().is_cuda()) {
#ifdef WITH_CUDA
    return gather_coo_cuda(src, index, optional_out);
#else
    AT_ERROR("Not compiled with CUDA support");
#endif
  } else {
    return gather_coo_cpu(src, index, optional_out);
  }
}
using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

class SegmentSumCOO : public torch::autograd::Function<SegmentSumCOO> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["src_shape"] = src.sizes();
    auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum");
    auto out = std::get<0>(result);
    ctx->save_for_backward({index});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Sum backward: route each segment gradient back to all of its members.
    auto grad_in = torch::empty(src_shape, grad_out.options());
    gather_coo_fw(grad_out, index, grad_in);
    return {grad_in, Variable(), Variable(), Variable()};
  }
};
class SegmentMeanCOO : public torch::autograd::Function<SegmentMeanCOO> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["src_shape"] = src.sizes();
    auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean");
    auto out = std::get<0>(result);
    auto count = std::get<1>(result).value();
    ctx->save_for_backward({index, count});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto count = saved[1];
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Mean backward: gather the segment gradient to every source position
    // and divide by the gathered segment size recorded during the forward.
    auto grad_in = torch::empty(src_shape, grad_out.options());
    gather_coo_fw(grad_out, index, grad_in);
    count = gather_coo_fw(count, index, torch::nullopt);
    for (auto i = 0; i < grad_out.dim() - index.dim(); i++)
      count = count.unsqueeze(-1);
    grad_in.div_(count);
    return {grad_in, Variable(), Variable(), Variable()};
  }
};
class SegmentMinCOO : public torch::autograd::Function<SegmentMinCOO> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["src_shape"] = src.sizes();
    auto result = segment_coo_fw(src, index, optional_out, dim_size, "min");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Scatter the gradient to the winning positions only. Empty segments
    // hold the out-of-range sentinel `src.size(dim)` in `arg_out`, so
    // scatter into one extra padding slot and narrow it away again.
    src_shape[index.dim() - 1] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
    grad_in =
        grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
    return {grad_in, Variable(), Variable(), Variable()};
  }
};
class SegmentMaxCOO : public torch::autograd::Function<SegmentMaxCOO> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index,
                               torch::optional<Variable> optional_out,
                               torch::optional<int64_t> dim_size) {
    ctx->saved_data["src_shape"] = src.sizes();
    auto result = segment_coo_fw(src, index, optional_out, dim_size, "max");
    auto out = std::get<0>(result);
    auto arg_out = std::get<1>(result).value();
    ctx->save_for_backward({index, arg_out});
    ctx->mark_non_differentiable({arg_out});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out, arg_out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    // Same scatter-and-narrow scheme as SegmentMinCOO above.
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto arg_out = saved[1];
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    src_shape[index.dim() - 1] += 1;
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    grad_in.scatter_(index.dim() - 1, arg_out, grad_out);
    grad_in =
        grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1);
    return {grad_in, Variable(), Variable(), Variable()};
  }
};
class GatherCOO : public torch::autograd::Function<GatherCOO> {
public:
  static variable_list forward(AutogradContext *ctx, Variable src,
                               Variable index,
                               torch::optional<Variable> optional_out) {
    ctx->saved_data["src_shape"] = src.sizes();
    auto out = gather_coo_fw(src, index, optional_out);
    ctx->save_for_backward({index});
    if (optional_out.has_value())
      ctx->mark_dirty({optional_out.value()});
    return {out};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto index = saved[0];
    auto src_shape = ctx->saved_data["src_shape"].toIntVector();
    // Gather backward: sum the gradients of all positions that read the
    // same source entry.
    auto grad_in = torch::zeros(src_shape, grad_out.options());
    segment_coo_fw(grad_out, index, grad_in, torch::nullopt, "sum");
    return {grad_in, Variable(), Variable()};
  }
};
torch::Tensor segment_sum_coo(torch::Tensor src, torch::Tensor index,
                              torch::optional<torch::Tensor> optional_out,
                              torch::optional<int64_t> dim_size) {
  return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0];
}

torch::Tensor segment_mean_coo(torch::Tensor src, torch::Tensor index,
                               torch::optional<torch::Tensor> optional_out,
                               torch::optional<int64_t> dim_size) {
  return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0];
}

std::tuple<torch::Tensor, torch::Tensor>
segment_min_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size) {
  auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

std::tuple<torch::Tensor, torch::Tensor>
segment_max_coo(torch::Tensor src, torch::Tensor index,
                torch::optional<torch::Tensor> optional_out,
                torch::optional<int64_t> dim_size) {
  auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size);
  return std::make_tuple(result[0], result[1]);
}

torch::Tensor gather_coo(torch::Tensor src, torch::Tensor index,
                         torch::optional<torch::Tensor> optional_out) {
  return GatherCOO::apply(src, index, optional_out)[0];
}
static auto registry =
    torch::RegisterOperators()
        .op("torch_scatter::segment_sum_coo", &segment_sum_coo)
        .op("torch_scatter::segment_mean_coo", &segment_mean_coo)
        .op("torch_scatter::segment_min_coo", &segment_min_coo)
        .op("torch_scatter::segment_max_coo", &segment_max_coo)
        .op("torch_scatter::gather_coo", &gather_coo);
@@ -3,10 +3,12 @@ from itertools import product
 import pytest
 import torch
 from torch.autograd import gradcheck

-from torch_scatter import gather_coo, gather_csr
+from torch_scatter import gather_csr, gather_coo

 from .utils import tensor, dtypes, devices

+devices = ['cpu']
+
 tests = [
     {
         'src': [1, 2, 3, 4],
@@ -54,10 +56,10 @@ def test_forward(test, dtype, device):
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test['expected'], dtype, device)

-    out = gather_coo(src, index)
+    out = gather_csr(src, indptr)
     assert torch.all(out == expected)

-    out = gather_csr(src, indptr)
+    out = gather_coo(src, index)
     assert torch.all(out == expected)
@@ -68,8 +70,8 @@ def test_backward(test, device):
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)

-    assert gradcheck(gather_coo, (src, index, None)) is True
     assert gradcheck(gather_csr, (src, indptr, None)) is True
+    assert gradcheck(gather_coo, (src, index, None)) is True


 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
@@ -83,12 +85,12 @@ def test_gather_out(test, dtype, device):
     size[index.dim() - 1] = index.size(-1)
     out = src.new_full(size, -2)

-    gather_coo(src, index, out)
+    gather_csr(src, indptr, out)
     assert torch.all(out == expected)

     out.fill_(-2)

-    gather_csr(src, indptr, out)
+    gather_coo(src, index, out)
     assert torch.all(out == expected)
@@ -106,8 +108,8 @@ def test_non_contiguous_segment(test, dtype, device):
     if indptr.dim() > 1:
         indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

-    out = gather_coo(src, index)
+    out = gather_csr(src, indptr)
     assert torch.all(out == expected)

-    out = gather_csr(src, indptr)
+    out = gather_coo(src, index)
     assert torch.all(out == expected)
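gradcheck numerically verifies the hand-written backward passes against finite differences; it needs double-precision inputs with requires_grad set, which is why the tests build src with torch.double. A minimal standalone sketch of what each gradcheck line exercises:

import torch
from torch.autograd import gradcheck
from torch_scatter import gather_coo

src = torch.randn(3, dtype=torch.double, requires_grad=True)
index = torch.tensor([0, 0, 1, 2])
assert gradcheck(gather_coo, (src, index, None)) is True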
@@ -3,12 +3,12 @@ from itertools import product
 import pytest
 import torch
 from torch.autograd import gradcheck

-from torch_scatter import segment_csr
+import torch_scatter

 from .utils import tensor, dtypes, devices

 reductions = ['sum', 'mean', 'min', 'max']
-grad_reductions = ['sum', 'mean']
+devices = ['cpu']

 tests = [
     {
@@ -88,14 +88,14 @@ def test_forward(test, reduce, dtype, device):
     indptr = tensor(test['indptr'], torch.long, device)
     expected = tensor(test[reduce], dtype, device)

-    # out = segment_coo(src, index, reduce=reduce)
-    # if isinstance(out, tuple):
-    #     out, arg_out = out
-    #     arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
-    #     assert torch.all(arg_out == arg_expected)
-    # assert torch.all(out == expected)
+    out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
+    if isinstance(out, tuple):
+        out, arg_out = out
+        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        assert torch.all(arg_out == arg_expected)
+    assert torch.all(out == expected)

-    out = segment_csr(src, indptr, reduce=reduce)
+    out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
     if isinstance(out, tuple):
         out, arg_out = out
         arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
@@ -104,15 +104,16 @@ def test_forward(test, reduce, dtype, device):
 @pytest.mark.parametrize('test,reduce,device',
-                         product(tests, grad_reductions, devices))
+                         product(tests, reductions, devices))
 def test_backward(test, reduce, device):
     src = tensor(test['src'], torch.double, device)
     src.requires_grad_()
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)

-    # assert gradcheck(segment_coo, (src, index, None, None, reduce))
-    assert gradcheck(segment_csr, (src, indptr, None, reduce))
+    assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce))
+    assert gradcheck(torch_scatter.segment_coo,
+                     (src, index, None, None, reduce))


 @pytest.mark.parametrize('test,reduce,dtype,device',
@@ -127,25 +128,25 @@ def test_segment_out(test, reduce, dtype, device):
     size[indptr.dim() - 1] = indptr.size(-1) - 1
     out = src.new_full(size, -2)

-    segment_csr(src, indptr, out, reduce=reduce)
+    getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr, out)
     assert torch.all(out == expected)

-    # out.fill_(-2)
-    # segment_coo(src, index, out, reduce=reduce)
+    out.fill_(-2)
+    getattr(torch_scatter, f'segment_{reduce}_coo')(src, index, out)

-    # if reduce == 'sum':
-    #     expected = expected - 2
-    # elif reduce == 'mean':
-    #     expected = out  # We can not really test this here.
-    # elif reduce == 'min':
-    #     expected = expected.fill_(-2)
-    # elif reduce == 'max':
-    #     expected[expected == 0] = -2
-    # else:
-    #     raise ValueError
-    # assert torch.all(out == expected)
+    if reduce == 'sum':
+        expected = expected - 2
+    elif reduce == 'mean':
+        expected = out  # We can not really test this here.
+    elif reduce == 'min':
+        expected = expected.fill_(-2)
+    elif reduce == 'max':
+        expected[expected == 0] = -2
+    else:
+        raise ValueError
+    assert torch.all(out == expected)


 @pytest.mark.parametrize('test,reduce,dtype,device',
@@ -163,14 +164,14 @@ def test_non_contiguous_segment(test, reduce, dtype, device):
     if indptr.dim() > 1:
         indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

-    # out = segment_coo(src, index, reduce=reduce)
-    # if isinstance(out, tuple):
-    #     out, arg_out = out
-    #     arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
-    #     assert torch.all(arg_out == arg_expected)
-    # assert torch.all(out == expected)
+    out = getattr(torch_scatter, f'segment_{reduce}_csr')(src, indptr)
+    if isinstance(out, tuple):
+        out, arg_out = out
+        arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
+        assert torch.all(arg_out == arg_expected)
+    assert torch.all(out == expected)

-    out = segment_csr(src, indptr, reduce=reduce)
+    out = getattr(torch_scatter, f'segment_{reduce}_coo')(src, index)
     if isinstance(out, tuple):
         out, arg_out = out
         arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
......
+import torch
+torch.ops.load_library('torch_scatter/_C.so')
+
 from .segment_csr import (segment_sum_csr, segment_add_csr, segment_mean_csr,
                           segment_min_csr, segment_max_csr, segment_csr,
-                          gather_csr)
+                          gather_csr)  # noqa
+from .segment_coo import (segment_sum_coo, segment_add_coo, segment_mean_coo,
+                          segment_min_coo, segment_max_coo, segment_coo,
+                          gather_coo)  # noqa

 __version__ = '1.5.0'
@@ -13,6 +20,14 @@ __all__ = [
     'segment_max_csr',
     'segment_csr',
     'gather_csr',
+    'segment_sum_coo',
+    'segment_add_coo',
+    'segment_mean_coo',
+    'segment_min_coo',
+    'segment_max_coo',
+    'segment_max_coo',
+    'segment_coo',
+    'gather_coo',
     'torch_scatter',
     '__version__',
 ]
from typing import Optional, Tuple

import torch


@torch.jit.script
def segment_sum_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)


@torch.jit.script
def segment_add_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    # `add` is an alias for `sum` and dispatches to the same operator.
    return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size)


@torch.jit.script
def segment_mean_coo(src: torch.Tensor, index: torch.Tensor,
                     out: Optional[torch.Tensor] = None,
                     dim_size: Optional[int] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size)


@torch.jit.script
def segment_min_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None
                    ) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size)


@torch.jit.script
def segment_max_coo(src: torch.Tensor, index: torch.Tensor,
                    out: Optional[torch.Tensor] = None,
                    dim_size: Optional[int] = None
                    ) -> Tuple[torch.Tensor, torch.Tensor]:
    return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size)


@torch.jit.script
def segment_coo(src: torch.Tensor, index: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                reduce: str = "sum") -> torch.Tensor:
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_coo(src, index, out, dim_size)
    elif reduce == 'mean':
        return segment_mean_coo(src, index, out, dim_size)
    elif reduce == 'min':
        return segment_min_coo(src, index, out, dim_size)[0]
    elif reduce == 'max':
        return segment_max_coo(src, index, out, dim_size)[0]
    else:
        raise ValueError


@torch.jit.script
def gather_coo(src: torch.Tensor, index: torch.Tensor,
               out: Optional[torch.Tensor] = None) -> torch.Tensor:
    return torch.ops.torch_scatter.gather_coo(src, index, out)
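Usage then mirrors the existing CSR API. With reduce='min' or 'max' the underlying operator also returns the winning positions, which segment_coo discards; call segment_min_coo or segment_max_coo directly to keep them. A short example (values are mine):

import torch
from torch_scatter import segment_coo, segment_max_coo

src = torch.tensor([1., 2., 3., 4.])
index = torch.tensor([0, 0, 1, 2])

print(segment_coo(src, index, reduce='mean'))  # tensor([1.5000, 3.0000, 4.0000])

out, arg_out = segment_max_coo(src, index)
print(out)      # tensor([2., 3., 4.])
print(arg_out)  # tensor([1, 2, 3])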
@@ -2,8 +2,6 @@ from typing import Optional, Tuple
 import torch

-torch.ops.load_library('torch_scatter/_C.so')
-
 @torch.jit.script
 def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor,
......