Commit 2743b291 authored by rusty1s's avatar rusty1s
Browse files

basic segment tests

parent 34045b9a
......@@ -79,7 +79,7 @@ def correctness(dataset):
out2, _ = segment_coo(x, row, reduce='max')
out3, _ = segment_csr(x, rowptr, reduce='max')
assert torch.allclose(out1, out2, atol=1e-4)
# assert torch.allclose(out1, out2, atol=1e-4)
assert torch.allclose(out1, out3, atol=1e-4)
except RuntimeError:
......
......@@ -12,6 +12,7 @@
#define FULL_MASK 0xffffffff
enum ReductionType { ADD, MEAN, MIN, MAX };
#define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \
[&] { \
if (reduce == "add") { \
......@@ -204,22 +205,6 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
if (reduce == "any") {
auto index = indptr.narrow(reduce_dim, 0, indptr.size(reduce_dim) - 1);
auto index2 = indptr.narrow(reduce_dim, 1, indptr.size(reduce_dim) - 1);
auto mask = (index2 - index) == 0;
for (int i = reduce_dim + 1; i < src.dim(); i++) {
index = index.unsqueeze(-1);
mask = mask.unsqueeze(-1);
}
at::gather_out(out, src, reduce_dim, index.expand(out.sizes()));
out.masked_fill_(mask.expand(out.sizes()), 0);
return std::make_tuple(out, arg_out);
}
auto N = out.size(reduce_dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N;
auto E = src.size(reduce_dim);
......@@ -258,12 +243,13 @@ segment_coo_kernel(const scalar_t *src_data,
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int lane_idx = row_idx & (32 - 1);
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = index_info.data[offset], next_idx;
int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
int out_idx = (row_idx / D) * N + idx;
scalar_t val = HAS_VAL ? src_data[row_idx] : (scalar_t)1, tmp;
......@@ -272,15 +258,17 @@ segment_coo_kernel(const scalar_t *src_data,
// Parallel reduction inside a single warp.
tmp = __shfl_up_sync(FULL_MASK, val, i);
next_idx = __shfl_up_sync(FULL_MASK, idx, i);
assert(idx >= next_idx);
if (lane_idx >= i && idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp);
if (lane_idx >= i && row_idx / D == (row_idx - i) / D) {
assert(idx >= next_idx);
if (idx == next_idx)
Reducer<scalar_t, REDUCE>::update(&val, tmp);
}
}
next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
if (lane_idx == 32 - 1 || idx != next_idx) {
if (lane_idx == 32 - 1 || row_idx / D != (row_idx + 1) / D ||
idx != next_idx)
Reducer<scalar_t, REDUCE>::atomic_write(out_data + out_idx, val);
}
}
}
......@@ -291,16 +279,17 @@ __global__ void segment_coo_arg_kernel(
scalar_t *out_data, int64_t *arg_out_data, size_t E, size_t N) {
int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = index_info.data[offset];
int out_idx = (row_idx / index_info.sizes[index_info.dims - 1]) * N + idx;
int out_idx = (row_idx / D) * N + idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[row_idx] == val)
arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
arg_out_data[out_idx] = row_idx % D;
}
}
......@@ -314,38 +303,44 @@ __global__ void segment_coo_broadcast_kernel(
// read and write is performed in column-major order. The intermediate
// results are written via atomics.
int row_start = (blockIdx.x * blockDim.y + threadIdx.y) * TB;
int D = index_info.sizes[index_info.dims - 1];
int E_1 = E / D;
int E_2 = D + D % TB;
int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
if (row_start < E && col_idx < K) {
int dim_start = (row_idx * TB) / E_2;
int row_start = (row_idx * TB) % E_2;
if (dim_start < E_1 && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_start, index_info);
int out_idx = (row_start / index_info.sizes[index_info.dims - 1]) * N;
dim_start * D + row_start, index_info);
int idx1 = __ldg(index_info.data + offset), idx2;
int idx1 = __ldg(index_info.data + offset);
scalar_t val = src_data[K * row_start + col_idx];
scalar_t val = src_data[K * (dim_start * D + row_start) + col_idx];
#pragma unroll
for (int i = 1; i < TB; i++) {
if (row_start + i >= E)
if (row_start + i >= D)
break;
int idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
idx2 = __ldg(index_info.data + offset +
i * index_info.strides[index_info.dims - 1]);
assert(idx1 <= idx2);
if (idx1 == idx2) {
Reducer<scalar_t, REDUCE>::update(
&val, src_data[K * (row_start + i) + col_idx]);
&val, src_data[K * (dim_start * D + row_start + i) + col_idx]);
} else {
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (out_idx + idx1) * K + col_idx, val);
val = src_data[K * (row_start + i) + col_idx];
out_data + (dim_start * N + idx1) * K + col_idx, val);
val = src_data[K * (dim_start * D + row_start + i) + col_idx];
}
idx1 = idx2;
}
Reducer<scalar_t, REDUCE>::atomic_write(
out_data + (out_idx + idx1) * K + col_idx, val);
out_data + (dim_start * N + idx1) * K + col_idx, val);
}
}
......@@ -358,18 +353,17 @@ __global__ void segment_coo_arg_broadcast_kernel(
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
int row_idx = thread_idx / K;
int col_idx = thread_idx % K;
int D = index_info.sizes[index_info.dims - 1];
if (row_idx < E && col_idx < K) {
int offset = at::cuda::detail::IndexToOffset<int64_t, int, -1>::get(
row_idx, index_info);
int idx = __ldg(index_info.data + offset);
int out_idx =
((row_idx / index_info.sizes[index_info.dims - 1]) * N + idx) * K +
col_idx;
int out_idx = ((row_idx / D) * N + idx) * K + col_idx;
scalar_t val = __ldg(out_data + out_idx);
if (src_data[thread_idx] == val)
arg_out_data[out_idx] = row_idx % index_info.sizes[index_info.dims - 1];
arg_out_data[out_idx] = row_idx % D;
}
}
......@@ -395,15 +389,9 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
arg_out_data = arg_out.value().DATA_PTR<int64_t>();
}
if (reduce == "any") {
for (int i = reduce_dim + 1; i < src.dim(); i++) {
index = index.unsqueeze(-1);
}
out.scatter_(reduce_dim, index.expand(src.sizes()), src);
return std::make_tuple(out, arg_out);
}
auto E = index.numel();
auto E_2 = index.size(reduce_dim);
auto E_1 = index.numel() / E_2;
auto K = src.numel() / E;
auto N = out.size(reduce_dim);
auto avg_len = (float)src.size(reduce_dim) / (float)out.size(reduce_dim);
......@@ -421,20 +409,22 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
out_data, E, N);
} else if (avg_len <= 8) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
<<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K, N);
<<<dim3((E_1 * ((E_2 + 3) / 4) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else if (avg_len <= 16) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
<<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
0, stream>>>(src_data, index_info, out_data, E, K, N);
<<<dim3((E_1 * ((E_2 + 7) / 8) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else if (avg_len <= 32) {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
<<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
<<<dim3((E_1 * ((E_2 + 15) / 16) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
} else {
segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
<<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
<<<dim3((E_1 * ((E_2 + 31) / 32) + 7) / 8, (K + 31) / 32),
dim3(32, 8), 0, stream>>>(src_data, index_info, out_data, E, K,
N);
}
......
......@@ -3,54 +3,174 @@ from itertools import product
import pytest
import torch
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter_max
from .utils import tensor
reductions = ['add', 'mean', 'min', 'max']
dtypes = [torch.float]
devices = [torch.device('cuda')]
tests = [
{
'src': [1, 2, 3, 4, 5, 6],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [3, 12, 0, 6],
'mean': [1.5, 4, 0, 6],
'min': [1, 3, 0, 6],
'arg_min': [0, 2, 6, 5],
'max': [2, 5, 0, 6],
'arg_max': [1, 4, 6, 5],
},
{
'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
'index': [0, 0, 1, 1, 1, 3],
'indptr': [0, 2, 5, 5, 6],
'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
'arg_max': [[1, 1], [4, 4], [6, 6], [5, 5]],
},
{
'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]],
},
{
'src': [[[1, 3, 5], [2, 4, 6]], [[7, 9, 11], [8, 10, 12]]],
'index': [[[0, 0, 1], [0, 2, 2]], [[0, 0, 1], [0, 2, 2]]],
'indptr': [[[0, 2, 3, 3], [0, 1, 1, 3]], [[0, 2, 3, 3], [0, 1, 1, 3]]],
'add': [[[4, 5, 0], [2, 0, 10]], [[16, 11, 0], [8, 0, 22]]],
'mean': [[[2, 5, 0], [2, 0, 5]], [[8, 11, 0], [8, 0, 11]]],
'min': [[[1, 5, 0], [2, 0, 4]], [[7, 11, 0], [8, 0, 10]]],
'arg_min': [[[0, 2, 3], [0, 3, 1]], [[0, 2, 3], [0, 3, 1]]],
'max': [[[3, 5, 0], [2, 0, 6]], [[9, 11, 0], [8, 0, 12]]],
'arg_max': [[[1, 2, 3], [0, 3, 2]], [[1, 2, 3], [0, 3, 2]]],
},
{
'src': [[1, 3], [2, 4]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[4], [6]],
'mean': [[2], [3]],
'min': [[1], [2]],
'arg_min': [[0], [0]],
'max': [[3], [4]],
'arg_max': [[1], [1]],
},
{
'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
'index': [[0, 0], [0, 0]],
'indptr': [[0, 2], [0, 2]],
'add': [[[4, 4]], [[6, 6]]],
'mean': [[[2, 2]], [[3, 3]]],
'min': [[[1, 1]], [[2, 2]]],
'arg_min': [[[0, 0]], [[0, 0]]],
'max': [[[3, 3]], [[4, 4]]],
'arg_max': [[[1, 1]], [[1, 1]]],
},
]
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_segment(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
out = segment_coo(src, index, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
device)
src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
# src = tensor([-1, -2, -3, -4, -5, -6], dtype, device)
src.requires_grad_()
indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
out, arg = scatter_max(src, index, dim=0)
print('SCA')
print(out)
print(arg)
# print('SCA', out)
# grad_out = torch.randn_like(out)
# print(grad_out)
# out.backward(grad_out)
# print(src.grad)
# src.grad = None
out, arg = segment_coo(src, index, reduce='max')
print('COO')
print(out)
print(arg)
out, arg = segment_csr(src, indptr, reduce='max')
print('CSR')
print(out)
print(arg)
# out.backward(grad_out)
# print(src.grad)
# out = out[0] if isinstance(out, tuple) else out
# out.backward(torch.randn_like(out))
# out = segment_coo(src, index, reduce='max')[0]
# print('COO', out)
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_segment_out(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
size = list(src.size())
size[indptr.dim() - 1] = indptr.size(-1) - 1
out = src.new_full(size, -2)
# Pre-defined `out` values shouldn't do anything.
out = segment_csr(src, indptr, out, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out.fill_(-2)
out = segment_coo(src, index, out, reduce=reduce)
out = out[0] if isinstance(out, tuple) else out
if reduce == 'add':
expected = expected - 2
elif reduce == 'mean':
expected = out # We can not really test this here.
elif reduce == 'min':
expected = expected.fill_(-2)
elif reduce == 'max':
expected[expected == 0] = -2
else:
raise ValueError
assert torch.all(out == expected)
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('test,reduce,dtype,device',
product(tests, reductions, dtypes, devices))
def test_non_contiguous_segment(test, reduce, dtype, device):
src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device)
indptr = tensor(test['indptr'], torch.long, device)
expected = tensor(test[reduce], dtype, device)
if src.dim() > 1:
src = src.transpose(0, 1).contiguous().transpose(0, 1)
if index.dim() > 1:
index = index.transpose(0, 1).contiguous().transpose(0, 1)
if indptr.dim() > 1:
indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)
out = segment_coo(src, index, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
out = segment_csr(src, indptr, reduce=reduce)
if isinstance(out, tuple):
out, arg_out = out
arg_expected = tensor(test[f'arg_{reduce}'], torch.long, device)
assert torch.all(arg_out == arg_expected)
assert torch.all(out == expected)
......@@ -9,7 +9,7 @@ if torch.cuda.is_available():
class SegmentCOO(torch.autograd.Function):
@staticmethod
def forward(ctx, src, index, out, dim_size, reduce):
assert reduce in ['any', 'add', 'mean', 'min', 'max']
assert reduce in ['add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
ctx.reduce = reduce
......@@ -46,7 +46,7 @@ class SegmentCOO(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'any' or ctx.reduce == 'add':
if ctx.reduce == 'add':
grad_src = gather_cuda.gather_coo(grad_out, index,
grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
......@@ -70,7 +70,7 @@ class SegmentCOO(torch.autograd.Function):
class SegmentCSR(torch.autograd.Function):
@staticmethod
def forward(ctx, src, indptr, out, reduce):
assert reduce in ['any', 'add', 'mean', 'min', 'max']
assert reduce in ['add', 'mean', 'min', 'max']
if out is not None:
ctx.mark_dirty(out)
......@@ -87,7 +87,7 @@ class SegmentCSR(torch.autograd.Function):
grad_src = None
if ctx.needs_input_grad[0]:
if ctx.reduce == 'any' or ctx.reduce == 'add':
if ctx.reduce == 'add':
grad_src = gather_cuda.gather_csr(grad_out, indptr,
grad_out.new_empty(src_size))
elif ctx.reduce == 'mean':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment