test_segment.py 1.19 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch_scatter import segment_coo, segment_csr
rusty1s's avatar
linting  
rusty1s committed
6
from torch_scatter import scatter_add, scatter_mean, scatter_min  # noqa
rusty1s's avatar
rusty1s committed
7
8
9
10
11
12
13

from .utils import tensor

# Parametrization grid: each test below runs once per (dtype, device) pair.
dtypes = [torch.float32]  # torch.float32 is the same dtype object as torch.float
devices = [torch.device('cuda')]


rusty1s's avatar
rusty1s committed
14
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """Check that ``segment_csr`` and ``segment_coo`` compute the same means.

    ``indptr`` (CSR pointers) and ``index`` (COO segment ids) describe the
    same segmentation of ``src``:

    * elements 0-1 go to segment 0,
    * elements 2-4 go to segment 1,
    * nothing goes to segment 2,
    * element 5 goes to segment 3.

    Both mean reductions must therefore yield the same per-segment result.
    """
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)

    # Segment 2 is empty; it is expected to reduce to 0 rather than NaN.
    expected = tensor([1.5, 4, 0, 6], dtype, device)

    out = segment_csr(src, indptr, reduce='mean')
    assert torch.allclose(out, expected)

    out = segment_coo(src, index, reduce='mean')
    assert torch.allclose(out, expected)