test_broadcasting.py 617 Bytes
Newer Older
rusty1s's avatar
rusty1s committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import pytest
import torch
from torch_scatter import scatter_add

from .utils import devices


@pytest.mark.parametrize('device', devices)
def test_broadcasting(device):
    """Check that ``scatter_add`` broadcasts ``index`` against ``src``.

    Two index layouts are exercised, both scattering along ``dim=2`` with
    ``dim_size=H``:

    * an index with a singleton channel axis, shape ``(B, 1, H, W)``, which
      must be expanded across the ``C`` channels of ``src``;
    * a one-dimensional index of length ``H``, which must be applied along
      ``dim=2`` at every batch, channel and column position.

    Besides the output shape, the values are compared against PyTorch's own
    ``Tensor.scatter_add_`` called with the index expanded explicitly to the
    shape of ``src``. A broadcast that yields the right shape but lays the
    index out along the wrong axes is therefore caught as well.
    """
    B, C, H, W = (4, 3, 8, 8)

    # Case 1: index broadcast over the channel axis.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)
    # dim_size == H, so the output has the same shape as src.
    expected = torch.zeros_like(src).scatter_add_(
        2, index.expand(B, C, H, W), src)
    # Small atol: CUDA scatter_add_ may accumulate in a different order.
    assert torch.allclose(out, expected, atol=1e-6)

    # Case 2: a 1-D index, laid along dim 2 and repeated over all other axes.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter_add(src, index, dim=2, dim_size=H)
    assert out.size() == (B, C, H, W)
    expected = torch.zeros_like(src).scatter_add_(
        2, index.view(1, 1, H, 1).expand(B, C, H, W), src)
    assert torch.allclose(out, expected, atol=1e-6)