Commit 80a7dc52 authored by rusty1s

all tests on CPU+GPU

parent 5db00866
@@ -7,7 +7,6 @@ from scipy.io import loadmat
 from torch_scatter import gather_coo, gather_csr
-from scatter_segment import iters, sizes
 from scatter_segment import short_rows, long_rows, download, bold
@@ -125,6 +124,9 @@ if __name__ == '__main__':
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
+    iters = 1 if args.device == 'cpu' else 20
+    sizes = [1, 16, 32, 64, 128, 256, 512]
+    sizes = sizes[:3] if args.device == 'cpu' else sizes
     for _ in range(10):  # Warmup.
         torch.randn(100, 100, device=args.device).sum()
...
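Note on the benchmark change above: the CPU kernels are far slower, so the script now runs a single timing iteration and only the three smallest sizes there, while the warm-up loop absorbs one-time CUDA context and kernel-load costs. A minimal sketch of this timing pattern (illustrative only, not code from the benchmark):

import time
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
iters = 1 if device == 'cpu' else 20

for _ in range(10):  # Warm-up: initializes the CUDA context, loads kernels.
    torch.randn(100, 100, device=device).sum()

if device == 'cuda':
    torch.cuda.synchronize()  # GPU ops are async; finish pending work first.
start = time.perf_counter()
for _ in range(iters):
    torch.randn(100, 100, device=device).sum()
if device == 'cuda':
    torch.cuda.synchronize()  # Sync again before reading the clock.
print((time.perf_counter() - start) / iters)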
 #include <torch/extension.h>
+#include "compat.h"
+#include "index_info.h"

 #define CHECK_CPU(x) AT_ASSERTM(!x.type().is_cuda(), #x " must be CPU tensor")

 at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
@@ -8,8 +11,59 @@ at::Tensor gather_csr(at::Tensor src, at::Tensor indptr,
   CHECK_CPU(indptr);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
+  AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
+  for (int i = 0; i < indptr.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = indptr.dim() - 1;
+  AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
+             "Input mismatch");
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < out.dim(); i++)
+      if (i != gather_dim)
+        AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = *indptr.flatten()[-1].DATA_PTR<int64_t>();
+    out = at::empty(sizes, src.options());
+  }
+
+  auto N = src.size(gather_dim) * (indptr.numel() / indptr.size(-1));
+  auto K = src.numel() / N;
+  auto E = out.size(gather_dim);
+
+  auto indptr_info = getTensorInfo<int64_t>(indptr);
+  auto stride = indptr_info.strides[indptr_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_csr", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t vals[K];
+    int64_t row_start, row_end;
+    for (int n = 0; n < N; n++) {
+      int offset = IndexPtrToOffset<int64_t>::get(n, indptr_info);
+      row_start = indptr_info.data[offset];
+      row_end = indptr_info.data[offset + stride];
+
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[n * K + k];
+      }
+
+      offset = (n / (indptr.size(-1) - 1)) * E * K;
+      for (int64_t e = row_start; e < row_end; e++) {
+        for (int k = 0; k < K; k++) {
+          out_data[offset + e * K + k] = vals[k];
+        }
+      }
+    }
+  });
+
+  return out;
 }

 at::Tensor gather_coo(at::Tensor src, at::Tensor index,
@@ -18,8 +72,69 @@ at::Tensor gather_coo(at::Tensor src, at::Tensor index,
   CHECK_CPU(index);
   if (out_opt.has_value())
     CHECK_CPU(out_opt.value());
-  AT_ASSERTM(false, "Not yet implemented");
-  return src;
+
+  AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
+  for (int i = 0; i < index.dim() - 1; i++)
+    AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
+
+  src = src.contiguous();
+  auto gather_dim = index.dim() - 1;
+
+  at::Tensor out;
+  if (out_opt.has_value()) {
+    out = out_opt.value().contiguous();
+    for (int i = 0; i < index.dim(); i++)
+      AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
+    for (int i = index.dim() + 1; i < src.dim(); i++)
+      AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
+  } else {
+    auto sizes = src.sizes().vec();
+    sizes[gather_dim] = index.size(gather_dim);
+    out = at::empty(sizes, src.options());
+  }
+
+  auto E_1 = index.numel() / out.size(gather_dim);
+  auto E_2 = index.size(gather_dim);
+  auto K = out.numel() / index.numel();
+  auto N = src.size(gather_dim);
+
+  auto index_info = getTensorInfo<int64_t>(index);
+  auto stride = index_info.strides[index_info.dims - 1];
+  AT_DISPATCH_ALL_TYPES(src.scalar_type(), "gather_coo", [&] {
+    auto src_data = src.DATA_PTR<scalar_t>();
+    auto out_data = out.DATA_PTR<scalar_t>();
+
+    scalar_t vals[K];
+    int64_t idx, next_idx;
+    for (int e_1 = 0; e_1 < E_1; e_1++) {
+      int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
+      idx = index_info.data[offset];
+
+      for (int k = 0; k < K; k++) {
+        vals[k] = src_data[e_1 * N * K + idx * K + k];
+      }
+
+      for (int e_2 = 0; e_2 < E_2; e_2++) {
+        for (int k = 0; k < K; k++) {
+          out_data[e_1 * E_2 * K + e_2 * K + k] = vals[k];
+        }
+
+        if (e_2 < E_2 - 1) {
+          next_idx = index_info.data[offset + (e_2 + 1) * stride];
+          assert(idx <= next_idx);
+
+          if (idx != next_idx) {
+            idx = next_idx;
+            for (int k = 0; k < K; k++) {
+              vals[k] = src_data[e_1 * N * K + idx * K + k];
+            }
+          }
+        }
+      }
+    }
+  });
+
+  return out;
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
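The two kernels added above are the CPU counterparts of segment_csr/segment_coo: gather_csr expands the value at row i into output positions indptr[i] through indptr[i+1], and gather_coo reads out[e] = src[index[e]] for a sorted index, caching the current row's values until the index changes. A small illustration via the Python API (assuming the extension is built and importable):

import torch
from torch_scatter import gather_coo, gather_csr

src = torch.tensor([1, 2, 3, 4])
indptr = torch.tensor([0, 2, 2, 5, 6])  # CSR pointers: row i spans indptr[i]:indptr[i+1]
print(gather_csr(src, indptr))          # tensor([1, 1, 3, 3, 3, 4]); row 1 is empty

index = torch.tensor([0, 0, 2, 2, 2, 3])  # the same segments in sorted COO form
print(gather_coo(src, index))             # tensor([1, 1, 3, 3, 3, 4])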
@@ -184,7 +184,6 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }

-  auto E = index.numel();
   auto E_1 = index.numel() / src.size(reduce_dim);
   auto E_2 = src.size(reduce_dim);
   auto K = src.numel() / index.numel();
@@ -202,12 +201,12 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
     for (int e_1 = 0; e_1 < E_1; e_1++) {
       int offset = IndexToOffset<int64_t>::get(e_1 * E_2, index_info);
       idx = index_info.data[offset];
-      row_start = 0;

       for (int k = 0; k < K; k++) {
         vals[k] = out_data[e_1 * N * K + k];
       }
+      row_start = 0;

       for (int e_2 = 0; e_2 < E_2; e_2++) {
         for (int k = 0; k < K; k++) {
@@ -224,6 +223,7 @@ segment_coo(at::Tensor src, at::Tensor index, at::Tensor out,
           }
         } else {
           next_idx = index_info.data[offset + (e_2 + 1) * stride];
+          assert(idx <= next_idx);

           if (idx != next_idx) {
             for (int k = 0; k < K; k++) {
...
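The assert added to segment_coo (matching the one in gather_coo) makes the kernel's contract explicit: the index must be sorted along the last dimension, because the loop only flushes its accumulator when the index value changes. Expected usage, sketched with today's public Python API (reduce='sum' there; the commit-era code calls it 'add'):

import torch
from torch_scatter import segment_coo

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([0, 0, 1, 1, 1, 3])  # must be sorted for segment_coo
print(segment_coo(src, index, reduce='sum'))  # tensor([ 3., 12.,  0.,  6.])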
@@ -5,10 +5,7 @@ import torch
 from torch.autograd import gradcheck

 from torch_scatter import gather_coo, gather_csr
-from .utils import tensor
+from .utils import tensor, dtypes, devices

-dtypes = [torch.float]
-devices = [torch.device('cuda')]
-
 tests = [
     {
@@ -50,7 +47,6 @@ tests = [
 ]

-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_forward(test, dtype, device):
     src = tensor(test['src'], dtype, device)
@@ -65,7 +61,6 @@ def test_forward(test, dtype, device):
     assert torch.all(out == expected)

-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,device', product(tests, devices))
 def test_backward(test, device):
     src = tensor(test['src'], torch.double, device)
@@ -77,9 +72,8 @@ def test_backward(test, device):
     assert gradcheck(gather_csr, (src, indptr, None)) is True

-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
-def test_segment_out(test, dtype, device):
+def test_gather_out(test, dtype, device):
     src = tensor(test['src'], dtype, device)
     index = tensor(test['index'], torch.long, device)
     indptr = tensor(test['indptr'], torch.long, device)
@@ -98,7 +92,6 @@ def test_segment_out(test, dtype, device):
     assert torch.all(out == expected)

-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
 def test_non_contiguous_segment(test, dtype, device):
     src = tensor(test['src'], dtype, device)
...
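Dropping the per-file skipif decorators and devices lists is what puts the commit message into effect: device parametrization now lives in one shared module, so every test runs on CPU and, where available, on GPU. A plausible sketch of what test/utils.py provides after this change (the file itself is not part of this diff, and the dtype list shown here is a guess):

import torch

dtypes = [torch.float, torch.double]  # hypothetical; the real list is not shown in this commit
devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device('cuda')]

def tensor(x, dtype, device):
    # Build a test tensor on the requested device, passing None through.
    return None if x is None else torch.tensor(x, dtype=dtype, device=device)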
@@ -5,13 +5,11 @@ import torch
 from torch.autograd import gradcheck

 from torch_scatter import segment_coo, segment_csr
-from .utils import tensor, dtypes
+from .utils import tensor, dtypes, devices

 reductions = ['add', 'mean', 'min', 'max']
 grad_reductions = ['add', 'mean']
-devices = [torch.device('cpu')]
-
 tests = [
     {
         'src': [1, 2, 3, 4, 5, 6],
@@ -105,7 +103,6 @@ def test_forward(test, reduce, dtype, device):
     assert torch.all(out == expected)

-@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('test,reduce,device',
                          product(tests, grad_reductions, devices))
 def test_backward(test, reduce, device):
...
@@ -56,12 +56,20 @@ class SegmentCOO(torch.autograd.Function):
         grad_src = None
         if ctx.needs_input_grad[0]:
             if ctx.reduce == 'add':
-                grad_src = gat(grad_out).gather_coo(
+                grad_src = gat(grad_out.is_cuda).gather_coo(
                     grad_out, index, grad_out.new_empty(src_size))
             elif ctx.reduce == 'mean':
-                grad_src = gat(grad_out).gather_coo(
+                grad_src = gat(grad_out.is_cuda).gather_coo(
                     grad_out, index, grad_out.new_empty(src_size))
-                count = arg_out
+                count = arg_out  # Gets pre-computed on GPU but not on CPU.
+                if count is None:
+                    size = list(index.size())
+                    size[-1] = grad_out.size(index.dim() - 1)
+                    count = segment_cpu.segment_coo(
+                        torch.ones_like(index, dtype=grad_out.dtype), index,
+                        grad_out.new_zeros(size), 'add')[0].clamp_(min=1)
                 count = gat(grad_out.is_cuda).gather_coo(
                     count, index, count.new_empty(src_size[:index.dim()]))
                 for _ in range(grad_out.dim() - index.dim()):
...
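The new branch above fills in the CPU path of the 'mean' backward: the GPU kernel already returns per-segment counts as arg_out, while on CPU they are recomputed by segment-summing a tensor of ones (clamped to one so empty segments do not divide by zero) and gathering them back to the source layout. In terms of today's public Python API, the gradient it computes is roughly:

import torch
from torch_scatter import gather_coo, segment_coo

index = torch.tensor([0, 0, 1, 1, 1, 3])   # sorted segment ids
grad_out = torch.tensor([1., 2., 3., 4.])  # upstream gradient, one entry per segment

# Per-segment element counts; empty segment 2 is clamped to 1.
ones = torch.ones_like(index, dtype=grad_out.dtype)
count = segment_coo(ones, index, dim_size=grad_out.size(0), reduce='sum').clamp_(min=1)

# d(mean)/d(src[e]) = grad_out[index[e]] / |segment(index[e])|
grad_src = gather_coo(grad_out / count, index)
print(grad_src)  # tensor([0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 4.0000])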