from itertools import product

import pytest
import torch
from torch_scatter import scatter

from .utils import reductions, devices


@pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(reduce, device):
    """Check that ``scatter`` broadcasts ``index`` against ``src``.

    ``src`` has shape (B, C, H, W) and is scattered along ``dim=2`` with
    ``dim_size=H``. Each index shape below has fewer dims than ``src`` or
    size-1 dims, so it has to be broadcast. For every shape the output must
    keep the size of ``src``. It must also equal the result of scattering
    with an index that was expanded by hand to the full shape of ``src``.
    """
    B, C, H, W = (4, 3, 8, 8)

    index_shapes = [
        (B, 1, H, W),  # broadcast over the channel dim
        (B, 1, H, 1),  # broadcast over the channel and trailing dims
        (H, ),  # 1-D index, aligned with the scatter dim
    ]

    for index_shape in index_shapes:
        src = torch.randn((B, C, H, W), device=device)
        index = torch.randint(0, H, index_shape).to(device, torch.long)

        out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
        assert out.size() == (B, C, H, W)

        # Expected broadcast semantics, spelled out explicitly: a 1-D index
        # is placed on the scatter dim (dim 2) before expanding to src's size.
        full_index = index.view(1, 1, H, 1) if index.dim() == 1 else index
        full_index = full_index.expand_as(src)
        expected = scatter(src, full_index, dim=2, dim_size=H, reduce=reduce)
        # allclose rather than equal: atomic reductions on GPU may reorder
        # floating-point operations between the two calls.
        assert torch.allclose(out, expected)