# Tests for the torch_scatter ``scatter_*`` reductions: forward results,
# gradients, writing into a preallocated output, and non-contiguous inputs.
from itertools import product

import pytest
import torch
from torch.autograd import gradcheck
import torch_scatter

# Shared test helpers: reduction names, a tensor factory, dtypes and devices.
from .utils import reductions, tensor, dtypes, devices
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

tests = [
    {
        'src': [1, 3, 2, 4, 5, 6],
        'index': [0, 1, 0, 1, 1, 3],
        'dim': 0,
        'sum': [3, 12, 0, 6],
        'add': [3, 12, 0, 6],
        'mean': [1.5, 4, 0, 6],
        'min': [1, 3, 0, 6],
        'arg_min': [0, 1, 6, 5],
        'max': [2, 5, 0, 6],
        'arg_max': [2, 4, 6, 5],
    },
    {
        'src': [[1, 2], [5, 6], [3, 4], [7, 8], [9, 10], [11, 12]],
        'index': [0, 1, 0, 1, 1, 3],
        'dim': 0,
        'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
        'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
        'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]],
        'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
        'arg_max': [[2, 2], [4, 4], [6, 6], [5, 5]],
    },
    {
        'src': [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]],
        'index': [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]],
        'dim': 1,
        'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
        'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
        'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]],
        'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
        'arg_max': [[2, 4, 6, 5], [3, 4, 5, 6]],
    },
    {
        'src': [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]],
        'index': [[0, 1, 0], [2, 0, 2]],
        'dim': 1,
        'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
        'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
        'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]],
        'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
        'arg_max': [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]],
    },
    {
        'src': [[1, 3], [2, 4]],
        'index': [[0, 0], [0, 0]],
        'dim': 1,
        'sum': [[4], [6]],
        'add': [[4], [6]],
        'mean': [[2], [3]],
        'min': [[1], [2]],
        'arg_min': [[0], [0]],
        'max': [[3], [4]],
        'arg_max': [[1], [1]],
    },
    {
        'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
        'index': [[0, 0], [0, 0]],
        'dim': 1,
        'sum': [[[4, 4]], [[6, 6]]],
        'add': [[[4, 4]], [[6, 6]]],
        'mean': [[[2, 2]], [[3, 3]]],
        'min': [[[1, 1]], [[2, 2]]],
        'arg_min': [[[0, 0]], [[0, 0]]],
        'max': [[[3, 3]], [[4, 4]]],
        'arg_max': [[[1, 1]], [[1, 1]]],
    },
]


@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
    """Check ``scatter_<reduce>`` against the expected values in ``test``.

    Reductions that return an ``(out, arg_out)`` tuple also have their
    argument indices compared against ``test['arg_<reduce>']``.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)

    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
        # ``==`` broadcasts, so compare shapes explicitly first.
        assert arg_out.size() == arg_expected.size()
        assert torch.all(arg_out == arg_expected)
    assert out.size() == expected.size()
    assert torch.all(out == expected)


@pytest.mark.parametrize('test,reduce,device',
                         product(tests, reductions, devices))
def test_backward(test, reduce, device):
    """Numerically verify the gradients of ``torch_scatter.scatter``.

    ``src`` is built in double precision, as ``gradcheck``'s finite
    differences require, and is marked as requiring gradients.
    """
    index = tensor(test['index'], torch.long, device)
    src = tensor(test['src'], torch.double, device).requires_grad_()

    # Positional arguments: src, index, dim, out, dim_size, reduce.
    scatter_args = (src, index, test['dim'], None, None, reduce)
    assert gradcheck(torch_scatter.scatter, scatter_args)


@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_out(test, reduce, dtype, device):
    """Check ``scatter_<reduce>`` writing into a preallocated ``out`` tensor.

    ``out`` is pre-filled with -2 before the call. For sum/add/min/max, the
    expected result is the reduction of that -2 seed with the scattered
    ``src`` elements. For mean, the result is not checked.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = torch.full_like(expected, -2)

    getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim, out)

    if reduce in ('sum', 'add'):
        # Each slot starts at -2 and accumulates its group sum on top of it.
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        # Every ``src`` value in ``tests`` is positive, so the -2 seed wins
        # in every slot.
        expected = expected.fill_(-2)
    elif reduce == 'max':
        # Filled slots keep their (positive) group maximum; slots that
        # receive no element (0 in ``expected``) keep the -2 seed.
        expected[expected == 0] = -2
    else:
        raise ValueError('Unsupported reduction: {}'.format(reduce))

    assert torch.all(out == expected)


@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous(test, reduce, dtype, device):
    """Run the ``test_forward`` checks with non-contiguous inputs.

    Multi-dimensional ``src``/``index`` are transposed, materialized and
    transposed back. Values are unchanged, but the strides of dims 0 and 1 are
    swapped, which makes the tensor non-contiguous whenever both dims have
    size > 1. 1-D inputs are left as-is.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)

    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)

    if isinstance(out, tuple):
        out, arg_out = out
        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
        # ``==`` broadcasts, so compare shapes explicitly first.
        assert arg_out.size() == arg_expected.size()
        assert torch.all(arg_out == arg_expected)
    assert out.size() == expected.size()
    assert torch.all(out == expected)