# test_zero_tensors.py
import torch
from torch_scatter import scatter


def test_zero_elements():
    """Scatter over a source tensor with zero rows must succeed.

    ``x`` has no elements along ``dim=0``, so there is nothing to reduce.
    With ``dim_size=0`` the output keeps the trailing dimension and should
    therefore have shape ``(0, 16)``.
    """
    x = torch.randn(0, 16)
    # Scatter index tensors must be int64. ``torch.tensor([])`` would default
    # to float32, so build the empty index with an explicit integer dtype.
    index = torch.empty(0, 16, dtype=torch.long)

    out = scatter(x, index, dim=0, dim_size=0, reduce="add")
    assert out.size() == (0, 16)