"vscode:/vscode.git/clone" did not exist on "4a1f511685387868c72b8860135bfaf5c8deec79"
Commit bed12976 authored by rusty1s

backward passes

parent 04fe0806
@@ -3,6 +3,7 @@ from itertools import product
 import pytest
 import torch
 from torch_scatter import segment_coo, segment_csr
+from torch_scatter import scatter_add, scatter_mean, scatter_max, scatter_min
 from .utils import tensor
@@ -13,19 +14,30 @@ devices = [torch.device('cuda')]
 @pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
 @pytest.mark.parametrize('dtype,device', product(dtypes, devices))
 def test_forward(dtype, device):
-    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
-                 device)
-    # src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
-    # src.requires_grad_()
+    # src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
+    #              device)
+    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
+    src.requires_grad_()
     indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
-    out = segment_csr(src, indptr, reduce='any')
-    print('CSR', out)
     index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out = tensor([0, 0, 0, 0], dtype, device)
-    out.scatter_(0, index, src)
+    out = scatter_min(src, index, dim=0)[0]
+    grad_out = torch.randn_like(out)
+    print(grad_out)
+    out.backward(grad_out)
+    print(src.grad)
+    src.grad = None
+    out = segment_csr(src, indptr, reduce='min')[0]
+    out.backward(grad_out)
+    print(src.grad)
+    # out = out[0] if isinstance(out, tuple) else out
+    # out.backward(torch.randn_like(out))
-    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
-    out = segment_coo(src, index, reduce='any')
-    print('COO', out)
+    # out = segment_coo(src, index, reduce='any')
+    # print('COO', out)
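The rewritten test checks the new backward pass by hand: it pushes the same random grad_out through scatter_min and through segment_csr(..., reduce='min') and prints both src.grad tensors for visual comparison (both ops return a (values, argmin) tuple, hence the [0]). A more systematic check could use torch.autograd.gradcheck; a minimal sketch, assuming only the public segment_csr API shown here and double precision, which gradcheck requires:

import torch
from torch.autograd import gradcheck
from torch_scatter import segment_csr

src = torch.randn(6, dtype=torch.double, device='cuda', requires_grad=True)
indptr = torch.tensor([0, 2, 5, 5, 6], device='cuda')

# reduce='min' returns (values, argmin); only the values are differentiable.
gradcheck(lambda x: segment_csr(x, indptr, reduce='min')[0], (src,))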
 import torch
 if torch.cuda.is_available():
-    from torch_scatter import segment_cuda
+    from torch_scatter import segment_cuda, gather_cuda
 class SegmentCSR(torch.autograd.Function):
@@ -12,25 +12,43 @@ class SegmentCSR(torch.autograd.Function):
         if out is not None:
             ctx.mark_dirty(out)
         ctx.reduce = reduce
-        ctx.save_for_backward(src, indptr)
+        ctx.src_size = list(src.size())
         out, arg_out = segment_cuda.segment_csr(src, indptr, out, reduce)
+        ctx.save_for_backward(indptr, arg_out)
         return out if arg_out is None else (out, arg_out)

     @staticmethod
     def backward(ctx, grad_out, *args):
-        src, indptr = ctx.saved_tensors
+        (indptr, arg_out), src_size = ctx.saved_tensors, ctx.src_size
         grad_src = None
         if ctx.needs_input_grad[0]:
-            grad_src = src
+            if ctx.reduce == 'any' or ctx.reduce == 'add':
+                grad_src = gather_cuda.gather_csr(grad_out, indptr,
+                                                  grad_out.new_empty(src_size))
+            elif ctx.reduce == 'mean':
+                grad_src = gather_cuda.gather_csr(grad_out, indptr,
+                                                  grad_out.new_empty(src_size))
+                indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1)
+                indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1)
+                count = (indptr2 - indptr1).to(grad_src.dtype)
+                count = gather_cuda.gather_csr(
+                    count, indptr, count.new_empty(src_size[:indptr.dim()]))
+                grad_src.div_(count)
+            elif ctx.reduce == 'min' or ctx.reduce == 'max':
+                src_size[indptr.dim() - 1] += 1
+                grad_src = grad_out.new_zeros(src_size).scatter_(
+                    indptr.dim() - 1, arg_out, grad_out)
+                grad_src = grad_src.narrow(indptr.dim() - 1, 0,
+                                           src_size[indptr.dim() - 1] - 1)
         return grad_src, None, None, None
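The branches implement the textbook gradients of each reduction: for 'add' (and the pass-through 'any'), every element of a segment receives that segment's output gradient, so the backward is exactly a CSR gather; for 'mean' the gathered gradient is divided by the segment length, which itself has to be gathered back to per-element shape; for 'min'/'max' only the arg-winning element receives gradient, which is why forward now saves arg_out instead of src, and why the scatter target gets one extra slot along the indexed dimension to absorb the arguments of empty segments (segment 2 in the test above) before being narrowed away. A plain-PyTorch sketch of the same rules for the 1-D case; repeat_interleave stands in for gather_cuda.gather_csr, and the function name is illustrative, not part of the package:

import torch

def segment_csr_grad_1d(grad_out, indptr, arg_out, reduce, src_len):
    count = indptr[1:] - indptr[:-1]                 # segment lengths
    if reduce in ('any', 'add'):
        # Every element in segment j receives grad_out[j].
        return grad_out.repeat_interleave(count)
    if reduce == 'mean':
        # Same gather, scaled by 1 / |segment|.
        return (grad_out / count.clamp(min=1)).repeat_interleave(count)
    if reduce in ('min', 'max'):
        # Only the winning position gets gradient; empty segments point
        # their argument at a trailing dummy slot that is sliced off,
        # mirroring the src_size[...] += 1 / narrow trick above.
        grad_src = grad_out.new_zeros(src_len + 1)
        grad_src.scatter_(0, arg_out, grad_out)
        return grad_src[:src_len]
    raise ValueError(reduce)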
 def segment_coo(src, index, out=None, dim_size=None, reduce='add'):
     assert reduce in ['any', 'add', 'mean', 'min', 'max']
-    if out is None:  # TODO: MOVE TO CPP
+    if out is None:
         dim_size = index.max().item() + 1 if dim_size is None else dim_size
         size = list(src.size())
         size[index.dim() - 1] = dim_size
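For the truncated segment_coo fragment: when no out tensor is passed, the result shape is inferred from index, with dim_size defaulting to index.max() + 1 (trailing empty segments therefore need an explicit dim_size). A quick illustration of that shape computation in plain PyTorch, mirroring the three lines above:

import torch

src = torch.randn(6, 8)                     # reduction runs over dim 0
index = torch.tensor([0, 0, 1, 1, 1, 3])

dim_size = index.max().item() + 1           # -> 4 (segment 2 stays empty)
size = list(src.size())
size[index.dim() - 1] = dim_size            # -> [4, 8]
out = src.new_zeros(size)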