test_zero_tensors.py 1.2 KB
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
from itertools import product

import pytest
4
import torch
Matthias Fey's avatar
Matthias Fey committed
5
6
7
from torch_scatter import (gather_coo, gather_csr, scatter, segment_coo,
                           segment_csr)
from torch_scatter.testing import devices, grad_dtypes, reductions, tensor
rusty1s's avatar
rusty1s committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24


@pytest.mark.parametrize('reduce,dtype,device',
                         product(reductions, grad_dtypes, devices))
def test_zero_elements(reduce, dtype, device):
    """Check that every scatter/segment/gather op accepts empty inputs.

    ``x`` has shape ``(0, 0, 0, 16)`` (zero elements) and both ``index`` and
    ``indptr`` are empty. Each of ``scatter``, ``segment_coo``,
    ``gather_coo``, ``segment_csr`` and ``gather_csr`` must run a backward
    pass without error and return a tensor of shape ``(0, 0, 0, 16)``.
    """
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    def check(out):
        # Backpropagate an (empty) random gradient first, then pin the shape,
        # so both the forward and the backward kernels see empty tensors.
        out.backward(torch.randn_like(out))
        assert out.size() == (0, 0, 0, 16)

    # COO-style ops take the empty ``index``; ``dim_size=0`` fixes the
    # (empty) size of the reduced dimension explicitly.
    check(scatter(x, index, dim=0, dim_size=0, reduce=reduce))
    check(segment_coo(x, index, dim_size=0, reduce=reduce))
    check(gather_coo(x, index))

    # CSR-style ops take the empty ``indptr`` instead.
    check(segment_csr(x, indptr, reduce=reduce))
    check(gather_csr(x, indptr))