# test_backward.py: gradient checks for the torch_scatter scatter_* operations.
from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter

from .utils import devices

# Suffixes of the reductions under test; each resolves to
# ``torch_scatter.scatter_<suffix>`` in ``test_backward``.
funcs = ['add', 'sub', 'mul', 'div', 'mean']
# Target slot for each of the five source elements. Slots 0 and 1 each
# receive two elements (slot 2 receives one), so the reductions really
# combine values and their gradients are non-trivial.
indices = [2, 0, 1, 1, 0]


@pytest.mark.parametrize('func,device', product(funcs, devices))
def test_backward(func, device):
    """Gradient-check ``torch_scatter.scatter_<func>`` on ``device``.

    ``torch.autograd.gradcheck`` compares the op's analytical gradient with a
    finite-difference estimate. It returns ``True`` on success and raises on
    a mismatch (its default ``raise_exception=True``).

    Args:
        func: suffix of the scatter op under test, drawn from ``funcs``.
        device: device to run on, drawn from ``devices``.
    """
    index = torch.tensor(indices, dtype=torch.long, device=device)
    # Double precision: gradcheck's finite differences are too noisy in
    # single precision for the tolerances used below.
    src = torch.rand(index.size(), dtype=torch.double, device=device)
    src.requires_grad_()

    op = getattr(torch_scatter, 'scatter_{}'.format(func))
    # gradcheck calls op(*data); only ``src`` requires grad, so the integer
    # ``index`` tensor is passed through without being perturbed.
    data = (src, index)
    assert gradcheck(op, data, eps=1e-6, atol=1e-4) is True