"vscode:/vscode.git/clone" did not exist on "d90ff389e056a1c690ab4f433750b5443e858968"
Commit 3c89ebc2 authored by rusty1s

autograd function

parent 6e561c88
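This commit routes `segment_csr` through a `torch.autograd.Function` (`SegmentCSR` in the diff below). As a rough usage sketch only, assuming `segment_csr` is importable from `torch_scatter` the way the updated test uses it, and keeping in mind that the gradient rule is still a stub in this commit:

    # Hypothetical usage sketch; assumes a CUDA build of torch_scatter with
    # the new SegmentCSR autograd function from this commit.
    import torch
    from torch_scatter import segment_csr  # assumed export, mirroring the test

    src = torch.tensor([1., 2., 3., 4., 5., 6.], device='cuda', requires_grad=True)
    indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')

    out = segment_csr(src, indptr, reduce='max')
    out = out[0] if isinstance(out, tuple) else out  # 'min'/'max' also return arg indices
    out.backward(torch.randn_like(out))              # dispatched to SegmentCSR.backward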
@@ -68,8 +68,9 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
   }
 
-  static inline __device__ void atom_write(scalar_t *address, scalar_t val,
-                                           int64_t *arg_address, int64_t arg) {
+  static inline __device__ void atomic_write(scalar_t *address, scalar_t val,
+                                             int64_t *arg_address,
+                                             int64_t arg) {
     if (REDUCE == ADD) {
       atomAdd(address, val);
     } else if (REDUCE == MEAN) {
@@ -81,6 +82,7 @@ template <typename scalar_t, ReductionType REDUCE> struct Reducer {
     }
 
     if (REDUCE == MIN || REDUCE == MAX) {
+      assert(false); // TODO
       __syncthreads();
       if (*address == val) {
         *arg_address = arg;
@@ -280,7 +282,7 @@ segment_coo_kernel(const scalar_t *src_data,
       next_idx = __shfl_down_sync(FULL_MASK, idx, 1);
 
       if (lane_idx == 32 - 1 || idx != next_idx) {
-        Reducer<scalar_t, REDUCE>::atom_write(out_data + idx, val,
-                                              arg_out_data + idx, arg);
+        Reducer<scalar_t, REDUCE>::atomic_write(out_data + idx, val,
+                                                arg_out_data + idx, arg);
       }
     }
@@ -343,8 +345,10 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     AT_ASSERTM(src.size(i) == out.size(i));
 
   at::optional<at::Tensor> arg_out = at::nullopt;
+  int64_t *arg_out_data = nullptr;
   if (reduce == "min" || reduce == "max") {
     arg_out = at::full_like(out, src.size(reduce_dim), index.options());
+    arg_out_data = arg_out.value().DATA_PTR<int64_t>();
   }
 
   auto E = index.numel();
@@ -357,43 +361,41 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
     auto src_data = src.DATA_PTR<scalar_t>();
     auto out_data = out.DATA_PTR<scalar_t>();
 
-    // Select the right kernel based on average row length (purely heuristic)
-    // and whether we need broadcasting capabilties (K > 1):
-
-    if (K == 1 && reduce == "add") {
-      segment_coo_kernel<scalar_t, ADD><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, nullptr, E);
-    } else if (K == 1 && reduce == "mean") {
-      segment_coo_kernel<scalar_t, MEAN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, nullptr, E);
-    } else if (K == 1 && reduce == "min") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_coo_kernel<scalar_t, MIN><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, arg_out_data, E);
-    } else if (K == 1 && reduce == "max") {
-      auto arg_out_data = arg_out.value().DATA_PTR<int64_t>();
-      segment_coo_kernel<scalar_t, MAX><<<BLOCKS(1, E), THREADS, 0, stream>>>(
-          src_data, index_info, out_data, arg_out_data, E);
-    } else if (avg_len <= 8)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 4>
-          <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else if (avg_len <= 16)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 8>
-          <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8), 0,
-             stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else if (avg_len <= 32)
-      segment_coo_broadcast_kernel<scalar_t, ADD, 16>
-          <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
-    else
-      segment_coo_broadcast_kernel<scalar_t, ADD, 32>
-          <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32), dim3(32, 8),
-             0, stream>>>(src_data, index_info, out_data, nullptr, E, K);
+    AT_DISPATCH_REDUCTION_TYPES(reduce, [&] {
+      if (K == 1) {
+        segment_coo_kernel<scalar_t, REDUCE>
+            <<<BLOCKS(1, E), THREADS, 0, stream>>>(src_data, index_info,
+                                                   out_data, arg_out_data, E);
+      } else if (avg_len <= 8) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 4>
+            <<<dim3(((E + (8 * 4) - 1) / (8 * 4)), (K + 31) / 32), dim3(32, 8),
+               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+      } else if (avg_len <= 16) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 8>
+            <<<dim3(((E + (8 * 8) - 1) / (8 * 8)), (K + 31) / 32), dim3(32, 8),
+               0, stream>>>(src_data, index_info, out_data, arg_out_data, E, K);
+      } else if (avg_len <= 32) {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 16>
+            <<<dim3(((E + (8 * 16) - 1) / (8 * 16)), (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
+                                         arg_out_data, E, K);
+      } else {
+        segment_coo_broadcast_kernel<scalar_t, REDUCE, 32>
+            <<<dim3(((E + (8 * 32) - 1) / (8 * 32)), (K + 31) / 32),
+               dim3(32, 8), 0, stream>>>(src_data, index_info, out_data,
+                                         arg_out_data, E, K);
+      }
+    });
   });
 
   if (reduce == "mean") {
-    AT_ASSERTM(false); // TODO: DIVIDE ENTRIES.
+    auto count = at::empty_like(index, out.options());
+    AT_DISPATCH_ALL_TYPES(out.scalar_type(), "count_kernel", [&] {
+      auto count_data = count.DATA_PTR<scalar_t>();
+      AT_ASSERTM(false); // TODO
+    });
+    out = out / count;
+    arg_out = count;
   }
 
   return std::make_tuple(out, arg_out);
...
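For context: the `reduce == "mean"` branch above allocates a `count` tensor, but the counting kernel itself is still a TODO in this commit. The intended semantics, sketched in plain PyTorch purely as an illustration (the function and variable names here are made up, not the kernel's API), is a sum followed by division by per-index counts:

    import torch

    def segment_mean_coo_reference(src, index, dim_size):
        # Illustrative reference only: sum src into the buckets given by index,
        # then divide each bucket sum by how many entries landed in it.
        out = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
        out.index_add_(0, index, src)                     # per-index sums
        count = torch.zeros(dim_size, dtype=src.dtype, device=src.device)
        count.index_add_(0, index, torch.ones_like(src))  # per-index counts
        return out / count.clamp(min=1)                   # guard empty buckets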
@@ -10,105 +10,22 @@ dtypes = [torch.float]
 devices = [torch.device('cuda')]
 
 
+@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward(dtype, device):
-    # src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
-    #              device)
+    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
+                 device)
     src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
 
-    out = segment_csr(src, indptr, reduce='add')
+    out = segment_csr(src, indptr, reduce='max')
+    out = out[0] if isinstance(out, tuple) else out
     print('CSR', out)
+    out.backward(torch.randn_like(out))
 
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
     out = segment_coo(src, index, reduce='add')
     print('COO', out)
-# @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
-# def test_benchmark(dtype, device):
-#     from torch_geometric.datasets import Planetoid, Reddit  # noqa
-#     # data = Planetoid('/tmp/Cora', 'Cora')[0].to(device)
-#     data = Planetoid('/tmp/PubMed', 'PubMed')[0].to(device)
-#     row, col = data.edge_index
-#     print(data.num_edges)
-#     print(row.size(0) / data.num_nodes)
-#     num_repeats = 1
-#     row = row.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-#     col = col.view(-1, 1).repeat(1, num_repeats).view(-1).contiguous()
-#     # Warmup
-#     for _ in range(10):
-#         torch.randn(100, 100, device=device).sum()
-#     x = torch.randn(row.size(0), device=device)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out1 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Row', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Col', time.perf_counter() - t)
-#     rowcount = segment_add(torch.ones_like(row), row)
-#     rowptr = torch.cat([rowcount.new_zeros(1), rowcount.cumsum(0)], dim=0)
-#     torch.cuda.synchronize()
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out3 = segment_add_csr(x, rowptr)
-#     torch.cuda.synchronize()
-#     print('CSR', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out4 = segment_add_coo(x, row, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('COO', time.perf_counter() - t)
-#     assert torch.allclose(out1, out3, atol=1e-2)
-#     assert torch.allclose(out1, out4, atol=1e-2)
-#     x = torch.randn((row.size(0), 64), device=device)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out5 = scatter_add(x, row, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Row + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         scatter_add(x, col, dim=0, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('Scatter Col + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out6 = segment_add_csr(x, rowptr)
-#     torch.cuda.synchronize()
-#     print('CSR + Dim', time.perf_counter() - t)
-#     torch.cuda.synchronize()
-#     t = time.perf_counter()
-#     for _ in range(100):
-#         out7 = segment_add_coo(x, row, dim_size=data.num_nodes)
-#     torch.cuda.synchronize()
-#     print('COO + Dim', time.perf_counter() - t)
-#     assert torch.allclose(out5, out6, atol=1e-2)
-#     assert torch.allclose(out5, out7, atol=1e-2)
@@ -4,6 +4,31 @@ if torch.cuda.is_available():
     from torch_scatter import segment_cuda
 
 
+class SegmentCSR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, src, indptr, out, reduce):
+        assert reduce in ['add', 'mean', 'min', 'max']
+        assert indptr.dtype == torch.long
+
+        if out is not None:
+            ctx.mark_dirty(out)
+
+        ctx.reduce = reduce
+        ctx.save_for_backward(src, indptr)
+
+        out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        return out if arg_out is None else (out, arg_out)
+
+    @staticmethod
+    def backward(ctx, grad_out, *args):
+        src, indptr = ctx.saved_tensors
+
+        grad_src = None
+        if ctx.needs_input_grad[0]:
+            grad_src = src
+
+        return grad_src, None, None, None
+
+
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
     assert reduce in ['add', 'mean', 'min', 'max']
 
     if out is None:
@@ -17,7 +42,10 @@ def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
 
 def segment_csr(src, indptr, out=None, reduce='add'):
-    assert reduce in ['add', 'mean', 'min', 'max']
-    assert indptr.dtype == torch.long
-    out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
-    return out if arg_out is None else (out, arg_out)
+    return SegmentCSR.apply(src, indptr, out, reduce)
+
+    # assert reduce in ['add', 'mean', 'min', 'max']
+    # assert indptr.dtype == torch.long
+    # out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+    # return out if arg_out is None else (out, arg_out)
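Note that `SegmentCSR.backward` above still returns `grad_src = src` as a placeholder. Purely as an illustration of the direction (not what this commit implements), the gradient of an 'add' segment reduction over CSR boundaries broadcasts each segment's output gradient back to that segment's source entries; a hypothetical sketch using the saved `indptr`:

    import torch

    def segment_csr_add_backward_sketch(grad_out, indptr):
        # Illustrative only: for reduce='add', every src entry in segment i
        # receives grad_out[i], so repeat each segment gradient by its length.
        seg_len = indptr[1:] - indptr[:-1]   # number of src entries per segment
        return torch.repeat_interleave(grad_out, seg_len, dim=0)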