Commit 29e28847 authored by rusty1s's avatar rusty1s
Browse files

grad impl

parent a80a8ba7
......@@ -2,33 +2,7 @@ from nose.tools import assert_equal
import torch
from torch.autograd import Variable
from torch_scatter._ext import ffi
class ScatterAdd(torch.autograd.Function):
def __init__(self, dim):
super(ScatterAdd, self).__init__()
self.dim = dim
def forward(self, output, index, input):
assert not self.needs_input_grad[1], 'Can\'t differentiate the index'
self.mark_dirty(output)
self.save_for_backward(index)
ffi.scatter_add_Float(output, index, input, self.dim)
return output
def backward(self, grad):
index, = self.saved_variables
grad_output = grad_input = None
if self.needs_input_grad[0]:
grad_output = grad
if self.needs_input_grad[2]:
grad_input = grad.gather(self.dim, index.data)
return grad_output, None, grad_input
from torch_scatter import scatter_add_, scatter_add
def test_scatter_add():
......@@ -39,22 +13,29 @@ def test_scatter_add():
output = input.new(2, 6).fill_(0)
expected_output = [[0, 0, 4, 3, 3, 0], [2, 4, 4, 0, 0, 0]]
ffi.scatter_add_Float(output, index, input, 1)
scatter_add_(output, index, input, dim=1)
assert_equal(output.tolist(), expected_output)
output = scatter_add(index, input, dim=1)
assert_equal(output.tolist(), expected_output)
output = Variable(output)
output = Variable(output).fill_(0)
index = Variable(index)
input = Variable(input, requires_grad=True)
scatter_add_(output, index, input, dim=1)
c = output.sum()
c.backward()
# # a = input * 2
# # b = output * 2
# a = input * 2
# b = output * 2
a = input * 2
b = output * 2
ScatterAdd(1)(b, index, a)
# b.scatter_add_(1, index, a)
# ScatterAdd(1)(b, index, a)
# # b.scatter_add_(1, index, a)
c = b.sum()
c.backward()
# c = b.sum()
# c.backward()
print(input.grad)
print(output.grad)
# print(input.grad)
# print(output.grad)
import torch
from torch.autograd import Variable
def test_grad():
input = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
input = torch.FloatTensor(input)
index = torch.LongTensor(index)
output = input.new(2, 6).fill_(0)
output = Variable(output)
index = Variable(index)
input = Variable(input, requires_grad=True)
output.scatter_add_(1, index, input)
c = output.mean()
c.backward()
# print(index.grad)
from .functions import (scatter_add_, scatter_add, scatter_sub, scatter_sub_,
scatter_mul, scatter_mul_, scatter_div, scatter_div_)
__all__ = [
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div'
]
......@@ -36,9 +36,3 @@ def scatter_div_(output, index, input, dim=0):
def scatter_div(index, input, dim=0, max_index=None, fill_value=1):
output = gen_output(index, input, dim, max_index, fill_value)
return scatter_div_(output, index, input, dim)
__all__ = [
'scatter_add_', 'scatter_add', 'scatter_sub_', 'scatter_sub',
'scatter_mul_', 'scatter_mul', 'scatter_div_', 'scatter_div'
]
......@@ -5,7 +5,7 @@ from .._ext import ffi
def _scatter(name, output, index, input, dim):
typename = type.__name__.replace('Tensor', '')
typename = type(input).__name__.replace('Tensor', '')
func = getattr(ffi, 'scatter_{}_{}'.format(name, typename))
func(output, index, input, dim)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment