Commit d3aabdf3 authored by rusty1s's avatar rusty1s
Browse files

fix negative dim in scatter_mean

parent ff3be8e3
......@@ -71,6 +71,7 @@ public:
Variable index, int64_t dim,
torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes();
......
from itertools import product
import pytest
import torch
from torch_scatter import scatter_add
from torch_scatter import scatter
from .utils import devices
from .utils import reductions, devices
@pytest.mark.parametrize('device', devices)
def test_broadcasting(device):
@pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(reduce, device):
B, C, H, W = (4, 3, 8, 8)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H)
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment