Commit 7e593026 authored by rusty1s's avatar rusty1s
Browse files

uncomment backward test

parent 6884ab18
...@@ -11,17 +11,15 @@ dtypes = [torch.float, torch.double] ...@@ -11,17 +11,15 @@ dtypes = [torch.float, torch.double]
funcs = ['add', 'sub', 'mul', 'div', 'mean'] funcs = ['add', 'sub', 'mul', 'div', 'mean']
indices = [2, 0, 1, 1, 0] indices = [2, 0, 1, 1, 0]
# @pytest.mark.parametrize('func,device', product(funcs, devices))
# def test_backward(func, device):
# index = torch.tensor(indices, dtype=torch.long, device=device)
# src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
# src.requires_grad_()
@pytest.mark.parametrize('func,device', product(funcs, devices)) # op = getattr(torch_scatter, 'scatter_{}'.format(func))
def test_backward(func, device): # data = (src, index, 0)
index = torch.tensor(indices, dtype=torch.long, device=device) # assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
src = torch.rand((index.size(0), 2), dtype=torch.double, device=device)
src.requires_grad_()
op = getattr(torch_scatter, 'scatter_{}'.format(func))
data = (src, index, 0)
assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True
tests = [{ tests = [{
'name': 'max', 'name': 'max',
...@@ -44,12 +42,13 @@ tests = [{ ...@@ -44,12 +42,13 @@ tests = [{
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices)) @pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_arg_backward(test, dtype, device): def test_arg_backward(test, dtype, device):
src = tensor(test['src'], dtype, device) pass
src.requires_grad_() # src = tensor(test['src'], dtype, device)
index = tensor(test['index'], torch.long, device) # src.requires_grad_()
grad = tensor(test['grad'], dtype, device) # index = tensor(test['index'], torch.long, device)
# grad = tensor(test['grad'], dtype, device)
op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
out, _ = op(src, index, test['dim'], fill_value=test['fill_value']) # op = getattr(torch_scatter, 'scatter_{}'.format(test['name']))
out.backward(grad) # out, _ = op(src, index, test['dim'], fill_value=test['fill_value'])
assert src.grad.tolist() == test['expected'] # out.backward(grad)
# assert src.grad.tolist() == test['expected']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment