test_segment.py 915 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
from itertools import product

import pytest
import torch
rusty1s's avatar
rusty1s committed
5
from torch_scatter import segment_coo, segment_csr
rusty1s's avatar
rusty1s committed
6
7
8
9
10
11
12

from .utils import tensor

# Parametrization grid: test_forward is run once for every
# (dtype, device) pair produced by itertools.product over these lists.
dtypes = [torch.float32]
devices = [torch.device('cuda')]


rusty1s's avatar
rusty1s committed
13
@pytest.mark.skipif(not torch.cuda.is_available(), reason='CUDA not available')
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_forward(dtype, device):
    """Smoke-test the ``'any'`` reduction of ``segment_csr`` and ``segment_coo``.

    Both segment ops are applied to the same 6x2 source tensor and their
    outputs are printed. No assertions are made yet.

    TODO: assert on the reduced values and add a backward-pass check.
    """
    src = tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], dtype,
                 device)

    # Assuming indptr holds CSR offsets, the segments are rows [0, 2),
    # [2, 5), [5, 5) (empty) and [5, 6).
    indptr = tensor([0, 2, 5, 5, 6], torch.long, device)
    csr_out = segment_csr(src, indptr, reduce='any')
    print('CSR', csr_out)

    # COO form: one sorted segment id per source row; id 2 never occurs.
    index = tensor([0, 0, 1, 1, 1, 3], torch.long, device)
    coo_out = segment_coo(src, index, reduce='any')
    print('COO', coo_out)