test_segment.py 6.93 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch.autograd import gradcheck
rusty1s's avatar
rusty1s committed
6
import torch_scatter
rusty1s's avatar
rusty1s committed
7

rusty1s's avatar
rusty1s committed
8
from .utils import reductions, tensor, dtypes, devices
rusty1s's avatar
rusty1s committed
9
10
11
12
13
14

# Shared fixtures for every test in this module.
#
# Each entry describes one segmentation problem:
#   * 'src'    - the values to reduce.
#   * 'index'  - one segment id per element (passed to the `*_coo` functions).
#   * 'indptr' - the same segmentation written as boundary offsets (passed to
#                the `*_csr` functions); segment `i` covers
#                `src[indptr[i]:indptr[i + 1]]` along the reduced dimension.
#   * 'sum', 'add', 'mean', 'min', 'max' - expected output per reduction
#     ('add' mirrors 'sum').
#   * 'arg_min', 'arg_max' - expected positions of the selected elements.
#
# Empty segments are expected to yield 0, and their arg entry equals the
# length of the reduced dimension (i.e. an out-of-range position).
tests = [
    # 1-D `src`, with one empty segment (id 2).
    {
        'src': [1, 2, 3, 4, 5, 6],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'sum': [3, 12, 0, 6],
        'add': [3, 12, 0, 6],
        'mean': [1.5, 4, 0, 6],
        'min': [1, 3, 0, 6],
        'arg_min': [0, 2, 6, 5],
        'max': [2, 5, 0, 6],
        'arg_max': [1, 4, 6, 5],
    },
    # 2-D `src` reduced along its first dimension with a 1-D index.
    {
        'src': [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]],
        'index': [0, 0, 1, 1, 1, 3],
        'indptr': [0, 2, 5, 5, 6],
        'sum': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'add': [[4, 6], [21, 24], [0, 0], [11, 12]],
        'mean': [[2, 3], [7, 8], [0, 0], [11, 12]],
        'min': [[1, 2], [5, 6], [0, 0], [11, 12]],
        'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]],
        'max': [[3, 4], [9, 10], [0, 0], [11, 12]],
        'arg_max': [[1, 1], [4, 4], [6, 6], [5, 5]],
    },
    # 2-D `src` with a 2-D index: each row has its own segmentation.
    {
        'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]],
        'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]],
        'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]],
        'sum': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'add': [[4, 21, 0, 11], [12, 18, 12, 0]],
        'mean': [[2, 7, 0, 11], [4, 9, 12, 0]],
        'min': [[1, 5, 0, 11], [2, 8, 12, 0]],
        'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]],
        'max': [[3, 9, 0, 11], [6, 10, 12, 0]],
        'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]],
    },
    # 3-D `src` with a 2-D index: per-batch segmentation of the middle
    # dimension, trailing feature dimension kept as is.
    {
        'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]],
        'index': [[0, 0, 1], [0, 2, 2]],
        'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]],
        'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]],
        'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]],
        'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]],
        'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]],
        'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]],
        'arg_max': [[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]],
    },
    # 2-D `src` where every row collapses into a single segment.
    {
        'src': [[1, 3], [2, 4]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'sum': [[4], [6]],
        'add': [[4], [6]],
        'mean': [[2], [3]],
        'min': [[1], [2]],
        'arg_min': [[0], [0]],
        'max': [[3], [4]],
        'arg_max': [[1], [1]],
    },
    # 3-D `src` where every batch collapses into a single segment.
    {
        'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]],
        'index': [[0, 0], [0, 0]],
        'indptr': [[0, 2], [0, 2]],
        'sum': [[[4, 4]], [[6, 6]]],
        'add': [[[4, 4]], [[6, 6]]],
        'mean': [[[2, 2]], [[3, 3]]],
        'min': [[[1, 1]], [[2, 2]]],
        'arg_min': [[[0, 0]], [[0, 0]]],
        'max': [[[3, 3]], [[4, 4]]],
        'arg_max': [[[1, 1]], [[1, 1]]],
    },
]


@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
    """Check the forward results of the CSR and COO segment reductions.

    For every fixture / reduction / dtype / device combination, the
    reduction-specific functions `segment_<reduce>_csr` and
    `segment_<reduce>_coo` are called both eagerly and through
    `torch.jit.script`.  The eager result must equal the fixture's expected
    values and the scripted result must equal the eager one.

    Args:
        test: One entry of the module-level `tests` fixtures.
        reduce: Name of the reduction; also the key of the expected values
            in `test`.
        dtype: dtype used for `src` and the expected output.
        device: Device on which all tensors are created.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    # The CSR variant consumes boundary offsets, the COO variant consumes
    # per-element segment ids; everything else about the check is identical.
    for layout, segments in (('csr', indptr), ('coo', index)):
        fn = getattr(torch_scatter, 'segment_' + reduce + '_' + layout)
        jit = torch.jit.script(fn)
        out1 = fn(src, segments)
        out2 = jit(src, segments)
        if isinstance(out1, tuple):
            # Some reductions return `(values, arg_positions)`; in that case
            # the positions are verified as well.
            out1, arg_out1 = out1
            out2, arg_out2 = out2
            arg_expected = tensor(test['arg_' + reduce], torch.long, device)
            assert torch.all(arg_out1 == arg_expected)
            assert arg_out1.tolist() == arg_out2.tolist()
        assert torch.all(out1 == expected)
        assert out1.tolist() == out2.tolist()
rusty1s's avatar
rusty1s committed
119

rusty1s's avatar
rusty1s committed
120

rusty1s's avatar
rusty1s committed
121
@pytest.mark.parametrize('test,reduce,device',
                         product(tests, reductions, devices))
def test_backward(test, reduce, device):
    """Numerically verify the gradients of `segment_csr` and `segment_coo`.

    Args:
        test: One entry of the module-level `tests` fixtures.
        reduce: Name of the reduction, forwarded to the generic segment
            functions.
        device: Device on which all tensors are created.
    """
    # `gradcheck` compares analytical gradients against finite differences,
    # so the differentiated input is created in double precision.
    src = tensor(test['src'], torch.double, device)
    src.requires_grad_()
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)

    # NOTE(review): the `None` placeholders fill optional positional
    # parameters that precede `reduce` (presumably `out`, and `dim_size` for
    # the COO variant) - confirm against the torch_scatter signatures.
    assert gradcheck(torch_scatter.segment_csr, (src, indptr, None, reduce))
    assert gradcheck(torch_scatter.segment_coo,
                     (src, index, None, None, reduce))
rusty1s's avatar
rusty1s committed
132
133


rusty1s's avatar
rusty1s committed
134
135
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_out(test, reduce, dtype, device):
    """Check the segment reductions when writing into a pre-filled `out`.

    `out` is filled with -2 before each call.  The CSR call is expected to
    produce exactly the fixture's values.  The COO call is expected to
    combine its result with the existing content of `out`, so the expected
    values are adjusted per reduction before comparing.

    Args:
        test: One entry of the module-level `tests` fixtures.
        reduce: Name of the reduction; also the key of the expected values
            in `test`.
        dtype: dtype used for `src`, `out` and the expected output.
        device: Device on which all tensors are created.

    Raises:
        ValueError: If `reduce` is not one of the reductions handled below.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    out = torch.full_like(expected, -2)

    getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out)
    assert torch.all(out == expected)

    # Reset the buffer so the COO call starts from the same -2 baseline.
    out.fill_(-2)

    getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out)

    if reduce in ('sum', 'add'):
        # The -2 baseline is added on top of every segment sum.
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        # Every fixture value is positive, so the -2 baseline always wins.
        expected.fill_(-2)
    elif reduce == 'max':
        # Non-empty segments beat the baseline; empty ones (0 in the
        # fixtures) keep it.
        expected[expected == 0] = -2
    else:
        raise ValueError('Unsupported reduction: ' + str(reduce))

    assert torch.all(out == expected)
rusty1s's avatar
rusty1s committed
163
164
165
166


@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous(test, reduce, dtype, device):
    """Check the segment reductions on non-contiguous input tensors.

    Args:
        test: One entry of the module-level `tests` fixtures.
        reduce: Name of the reduction; also the key of the expected values
            in `test`.
        dtype: dtype used for `src` and the expected output.
        device: Device on which all tensors are created.
    """
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    # Materialising the transposed tensor and transposing back keeps the
    # shape and values but changes the underlying memory layout.
    if src.dim() > 1:
        src = src.transpose(0, 1).contiguous().transpose(0, 1)
    if index.dim() > 1:
        index = index.transpose(0, 1).contiguous().transpose(0, 1)
    if indptr.dim() > 1:
        indptr = indptr.transpose(0, 1).contiguous().transpose(0, 1)

    # CSR consumes boundary offsets, COO consumes per-element segment ids.
    for layout, segments in (('csr', indptr), ('coo', index)):
        fn = getattr(torch_scatter, 'segment_' + reduce + '_' + layout)
        out = fn(src, segments)
        if isinstance(out, tuple):
            # `(values, arg_positions)` result: verify the positions too.
            out, arg_out = out
            arg_expected = tensor(test['arg_' + reduce], torch.long, device)
            assert torch.all(arg_out == arg_expected)
        assert torch.all(out == expected)