"git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "67e6957b1e3934bf542afd81b061ade8460ae6f2"
Commit d3aabdf3 authored by rusty1s's avatar rusty1s
Browse files

fix negative dim in scatter_mean

parent ff3be8e3
...@@ -71,6 +71,7 @@ public: ...@@ -71,6 +71,7 @@ public:
Variable index, int64_t dim, Variable index, int64_t dim,
torch::optional<Variable> optional_out, torch::optional<Variable> optional_out,
torch::optional<int64_t> dim_size) { torch::optional<int64_t> dim_size) {
dim = dim < 0 ? src.dim() + dim : dim;
ctx->saved_data["dim"] = dim; ctx->saved_data["dim"] = dim;
ctx->saved_data["src_shape"] = src.sizes(); ctx->saved_data["src_shape"] = src.sizes();
......
from itertools import product
import pytest import pytest
import torch import torch
from torch_scatter import scatter_add from torch_scatter import scatter
from .utils import devices from .utils import reductions, devices
@pytest.mark.parametrize('device', devices) @pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(device): def test_broadcasting(reduce, device):
B, C, H, W = (4, 3, 8, 8) B, C, H, W = (4, 3, 8, 8)
src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W)
src = torch.randn((B, C, H, W), device=device) src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long) index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H) out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W) assert out.size() == (B, C, H, W)
src = torch.randn((B, C, H, W), device=device) src = torch.randn((B, C, H, W), device=device)
index = torch.randint(0, H, (H, )).to(device, torch.long) index = torch.randint(0, H, (H, )).to(device, torch.long)
out = scatter_add(src, index, dim=2, dim_size=H) out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
assert out.size() == (B, C, H, W) assert out.size() == (B, C, H, W)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment