Commit 0807f87f authored by rusty1s
Browse files

all tests + segment_coo fixes

parent a9f9266b
...@@ -79,7 +79,7 @@ def correctness(dataset): ...@@ -79,7 +79,7 @@ def correctness(dataset):
out2, _ = segment_coo(x, row, reduce='max') out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max') out3, _ = segment_csr(x, rowptr, reduce='max')
# assert torch.allclose(out1, out2, atol=1e-4) assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4) assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError: except RuntimeError:
......
...@@ -80,9 +80,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer { ...@@ -80,9 +80,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
} }
static inline __device__ void atomic_write(scalar_t *address, scalar_t val) { static inline __device__ void atomic_write(scalar_t *address, scalar_t val) {
if (REDUCE == ADD) { if (REDUCE == ADD || REDUCE == MEAN) {
atomAdd(address, val);
} else if (REDUCE == MEAN) {
atomAdd(address, val); atomAdd(address, val);
} else if (REDUCE == MIN && val < *address) { } else if (REDUCE == MIN && val < *address) {
atomMin(address, val); atomMin(address, val);
...@@ -108,15 +106,16 @@ segment_csr_kernel(const scalar_t *src_data, ...@@ -108,15 +106,16 @@ segment_csr_kernel(const scalar_t *src_data,
if (row_idx < N) { if (row_idx < N) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info); int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset); int64_t row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset + int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]); indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init(); scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg, arg_tmp; int64_t arg, arg_tmp;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E;
for (int src_idx = row_start + lane_idx; src_idx < row_end; src_idx += TB) { for (int64_t src_idx = row_start + lane_idx; src_idx < row_end;
src_idx += TB) {
Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg, Reducer<scalar_t, REDUCE>::update(&val, src_data[offset + src_idx], &arg,
src_idx); src_idx);
} }
...@@ -154,15 +153,15 @@ __global__ void segment_csr_broadcast_kernel( ...@@ -154,15 +153,15 @@ __global__ void segment_csr_broadcast_kernel(
if (thread_idx < N * K) { if (thread_idx < N * K) {
int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info); int offset = IndexPtrToOffset<int64_t>::get(row_idx, indptr_info);
int row_start = __ldg(indptr_info.data + offset); int64_t row_start = __ldg(indptr_info.data + offset);
int row_end = __ldg(indptr_info.data + offset + int64_t row_end = __ldg(indptr_info.data + offset +
indptr_info.strides[indptr_info.dims - 1]); indptr_info.strides[indptr_info.dims - 1]);
scalar_t val = Reducer<scalar_t, REDUCE>::init(); scalar_t val = Reducer<scalar_t, REDUCE>::init();
int64_t arg; int64_t arg;
offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K;
for (int src_idx = row_start; src_idx < row_end; src_idx++) { for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) {
Reducer<scalar_t, REDUCE>::update( Reducer<scalar_t, REDUCE>::update(
&val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx); &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx);
} }
...@@ -253,7 +252,7 @@ segment_coo_kernel(const scalar_t *src_data, ...@@ -253,7 +252,7 @@ segment_coo_kernel(const scalar_t *src_data,
if (row_idx < E) { if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info); row_idx, index_info);
int idx = index_info.data[offset], next_idx; int64_t idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / D) * N + idx; int out_idx = (row_idx / D) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp; scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
...@@ -289,7 +288,7 @@ __global__ void segment_coo_arg_kernel( ...@@ -289,7 +288,7 @@ __global__ void segment_coo_arg_kernel(
if (row_idx < E) { if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info); row_idx, index_info);
int idx = index_info.data[offset]; int64_t idx = index_info.data[offset];
int out_idx = (row_idx / D) * N + idx; int out_idx = (row_idx / D) * N + idx;
scalar_t val = __ldg(out_data + out_idx); scalar_t val = __ldg(out_data + out_idx);
...@@ -310,7 +309,7 @@ __global__ void segment_coo_broadcast_kernel( ...@@ -310,7 +309,7 @@ __global__ void segment_coo_broadcast_kernel(
int D = index_info.sizes[index_info.dims - 1]; int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D; int E_1 = E / D;
int E_2 = D + D % TB; int E_2 = D + TB - (D % TB);
int row_idx = blockIdx.x * blockDim.y + threadIdx.y; int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x; int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
...@@ -319,6 +318,7 @@ __global__ void segment_coo_broadcast_kernel( ...@@ -319,6 +318,7 @@ __global__ void segment_coo_broadcast_kernel(
int row_start = (row_idx * TB) % E_2; int row_start = (row_idx * TB) % E_2;
if (dim_start < E_1 && col_idx < K) { if (dim_start < E_1 && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get( int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
dim_start * D + row_start, index_info); dim_start * D + row_start, index_info);
int idx1 = __ldg(index_info.data + offset), idx2; int idx1 = __ldg(index_info.data + offset), idx2;
...@@ -341,6 +341,7 @@ __global__ void segment_coo_broadcast_kernel( ...@@ -341,6 +341,7 @@ __global__ void segment_coo_broadcast_kernel(
out_data + (dim_start * N + idx1) * K + col_idx, val); out_data + (dim_start * N + idx1) * K + col_idx, val);
val = src_data[K * (dim_start * D + row_start + i) + col_idx]; val = src_data[K * (dim_start * D + row_start + i) + col_idx];
} }
idx1 = idx2; idx1 = idx2;
} }
...@@ -405,7 +406,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -405,7 +406,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
auto E_1 = index.numel() / E_2; auto E_1 = index.numel() / E_2;
auto K = src.numel() / E; auto K = src.numel() / E;
auto N = out.size(reduce_dim); auto N = out.size(reduce_dim);
auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim); auto avg_len = (float)E_2 / (float)N;
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index); auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
auto stream = at::cuda::getCurrentCUDAStream(); auto stream = at::cuda::getCurrentCUDAStream();
......
...@@ -2,6 +2,7 @@ from itertools import product ...@@ -2,6 +2,7 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck
from torch_scatter import gather_coo, gather_csr from torch_scatter import gather_coo, gather_csr
from .utils import tensor from .utils import tensor
...@@ -9,38 +10,111 @@ from .utils import tensor ...@@ -9,38 +10,111 @@ from .utils import tensor
dtypes = [torch.float] dtypes = [torch.float]
devices = [torch.device('cuda')] devices = [torch.device('cuda')]
tests = [
{
'src': [1, 2, 3, 4],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'expected': [1, 1, 2, 2, 2, 4],
},
{
'src': [[1, 2], [3, 4], [5, 6], [7, 8]],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'expected': [[1, 2], [1, 2], [3, 4], [3, 4], [3, 4], [7, 8]]
},
{
'src': [[1, 3, 5, 7], [2, 4, 6, 8]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'expected': [[1, 1, 3, 3, 3, 7], [2, 2, 2, 4, 4, 6]],
},
{
'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'expected': [[[1, 2], [1, 2], [3, 4]], [[7, 9], [12, 13], [12, 13]]],
},
{
'src': [[1], [2]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'expected': [[1, 1], [2, 2]],
},
{
'src': [[[1, 1]], [[2, 2]]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'expected': [[[1, 1], [1, 1]], [[2, 2], [2, 2]]],
},
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(dtype, device): def test_forward(test, dtype, device):
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype, device) src = tensor(test['src'], dtype, device)
src = tensor([1, 2, 3, 4], dtype, device) index = tensor(test['index'], torch.long, device)
src.requires_grad_() indptr = tensor(test['indptr'], torch.long, device)
indptr = tensor([0, 2, 5, 5, 6], torch.long, device) expected = tensor(test['expected'], dtype, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out = src.index_select(0, index) out = gather_coo(src, index)
grad_out = torch.randn_like(out) assert torch.all(out == expected)
out.backward(grad_out)
print('EXPECTED')
print(out)
print(src.grad)
src.grad = None
out = gather_csr(src, indptr) out = gather_csr(src, indptr)
out.backward(grad_out) assert torch.all(out == expected)
print('CSR')
print(out)
print(src.grad) @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
# print('CSR', out) @pytest.mark.parametrize('test,device', product(tests, devices))
def test_backward(test, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(gather_coo, (src, index, None)) is True
assert gradcheck(gather_csr, (src, indptr, None)) is True
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_segment_out(test, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
size = list(src.size())
size[index.dim() - 1] = index.size(-1)
out = src.new_full(size, -2)
gather_coo(src, index, out)
assert torch.all(out == expected)
out.fill_(-2)
gather_csr(src, indptr, out)
assert torch.all(out == expected)
# out = gather_coo(src, index)
# print('COO', out)
# print('Expected', out) @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
src.grad = None @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous_segment(test, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test['expected'], dtype, device)
if src.dim() > 1:
src = src.transpose(0, 1).contiguous().transpose(0, 1)
if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1)
if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = gather_coo(src, index) out = gather_coo(src, index)
out.backward(grad_out) assert torch.all(out == expected)
print('COO')
print(out) out = gather_csr(src, indptr)
print(src.grad) assert torch.all(out == expected)
...@@ -2,11 +2,13 @@ from itertools import product ...@@ -2,11 +2,13 @@ from itertools import product
import pytest import pytest
import torch import torch
from torch.autograd import gradcheck
from torch_scatter import segment_coo, segment_csr from torch_scatter import segment_coo, segment_csr
from .utils import tensor from .utils import tensor
reductions = ['add', 'mean', 'min', 'max'] reductions = ['add', 'mean', 'min', 'max']
grad_reductions = ['add', 'mean']
dtypes = [torch.float] dtypes = [torch.float]
devices = [torch.device('cuda')] devices = [torch.device('cuda')]
...@@ -46,15 +48,15 @@ tests = [ ...@@ -46,15 +48,15 @@ tests = [
'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]], 'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]],
}, },
{ {
'src': [[[1, 3, 5], [2, 4, 6]], [[7, 9, 11], [8, 10, 12]]], 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
'index': [[[0, 0, 1], [0, 2, 2]], [[0, 0, 1], [0, 2, 2]]], 'index': [[0, 0, 1], [0, 2, 2]],
'indptr': [[[0, 2, 3, 3], [0, 1, 1, 3]], [[0, 2, 3, 3], [0, 1, 1, 3]]], 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
'add': [[[4, 5, 0], [2, 0, 10]], [[16, 11, 0], [8, 0, 22]]], 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
'mean': [[[2, 5, 0], [2, 0, 5]], [[8, 11, 0], [8, 0, 11]]], 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
'min': [[[1, 5, 0], [2, 0, 4]], [[7, 11, 0], [8, 0, 10]]], 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
'arg_min': [[[0, 2, 3], [0, 3, 1]], [[0, 2, 3], [0, 3, 1]]], 'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
'max': [[[3, 5, 0], [2, 0, 6]], [[9, 11, 0], [8, 0, 12]]], 'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
'arg_max': [[[1, 2, 3], [0, 3, 2]], [[1, 2, 3], [0, 3, 2]]], 'arg_max': [[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]],
}, },
{ {
'src': [[1, 3], [2, 4]], 'src': [[1, 3], [2, 4]],
...@@ -84,7 +86,7 @@ tests = [ ...@@ -84,7 +86,7 @@ tests = [
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device', @pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices)) product(tests, reductions, dtypes, devices))
def test_segment(test, reduce, dtype, device): def test_forward(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device) src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device) index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device) indptr = tensor(test['indptr'], torch.long, device)
...@@ -105,6 +107,19 @@ def test_segment(test, reduce, dtype, device): ...@@ -105,6 +107,19 @@ def test_segment(test, reduce, dtype, device):
assert torch.all(out == expected) assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,device',
product(tests, grad_reductions, devices))
def test_backward(test, reduce, device):
src = tensor(test['src'], torch.double, device)
src.requires_grad_()
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
assert gradcheck(segment_coo, (src, index, None, None, reduce)) is True
assert gradcheck(segment_csr, (src, indptr, None, reduce)) is True
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available') @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device', @pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices)) product(tests, reductions, dtypes, devices))
...@@ -118,18 +133,12 @@ def test_segment_out(test, reduce, dtype, device): ...@@ -118,18 +133,12 @@ def test_segment_out(test, reduce, dtype, device):
size[indptr.dim() - 1] = indptr.size(-1) - 1 size[indptr.dim() - 1] = indptr.size(-1) - 1
out = src.new_full(size, -2) out = src.new_full(size, -2)
# Pre-defined `out` values shouldn't do anything. segment_csr(src, indptr, out, reduce=reduce)
out = segment_csr(src, indptr, out, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected) assert torch.all(out == expected)
out.fill_(-2) out.fill_(-2)
out = segment_coo(src, index, out, reduce=reduce) segment_coo(src, index, out, reduce=reduce)
out = out[0] if isinstance(out, tuple) else out
if reduce == 'add': if reduce == 'add':
expected = expected - 2 expected = expected - 2
......
...@@ -64,6 +64,7 @@ class SegmentCOO(torch.autograd.Function): ...@@ -64,6 +64,7 @@ class SegmentCOO(torch.autograd.Function):
index.dim() - 1, arg_out, grad_out) index.dim() - 1, arg_out, grad_out)
grad_src = grad_src.narrow(index.dim() - 1, 0, grad_src = grad_src.narrow(index.dim() - 1, 0,
src_size[index.dim() - 1] - 1) src_size[index.dim() - 1] - 1)
return grad_src, None, None, None, None return grad_src, None, None, None, None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment