"""Tests for the max reductions of ``segment_coo`` / ``segment_csr``."""
from itertools import product

import pytest
import torch
from torch_scatter import segment_coo, segment_csr
from torch_scatter import scatter_max

from .utils import tensor

# Parameter grid for the parametrized tests below: float32 on CUDA only.
dtypes = [torch.float]
devices = [torch.device('cuda')]


@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """Check the ``max`` reduction of COO/CSR segmentation against scatter_max.

    ``index`` (COO form) and ``indptr`` (CSR form) describe the same grouping
    of ``src`` into four segments: ``[1, 2]``, ``[3, 4, 5]``, an empty
    segment, and ``[6]``. Only the non-empty segments are compared, because
    the value/argmax reported for an empty segment is not pinned down here.
    """
    src = tensor([1, 2, 3, 4, 5, 6], dtype, device)
    # Run the forward pass with autograd tracking enabled, as callers
    # computing gradients would.
    src.requires_grad_()
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)

    # Output rows that correspond to non-empty segments (segment 2 is empty).
    nonempty = tensor([0, 1, 3], torch.long, device)
    # Per-segment maxima and the positions in ``src`` where they occur.
    expected_out = [2, 5, 6]
    expected_arg = [1, 4, 5]

    results = [
        ('scatter_max', scatter_max(src, index, dim=0)),
        ('segment_coo', segment_coo(src, index, reduce='max')),
        ('segment_csr', segment_csr(src, indptr, reduce='max')),
    ]
    for name, (out, arg) in results:
        assert out.size(0) == 4, name
        assert out[nonempty].tolist() == expected_out, name
        assert arg[nonempty].tolist() == expected_arg, name