Commit 88dd792e authored by rusty1s's avatar rusty1s
Browse files

fix zero element tensors

parent bf1f1014
...@@ -29,6 +29,8 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -29,6 +29,8 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
if (dim_size.has_value()) if (dim_size.has_value())
sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>(); sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
...@@ -41,6 +43,9 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -41,6 +43,9 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto B = 1; auto B = 1;
for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
B *= src.size(i); B *= src.size(i);
......
...@@ -34,6 +34,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -34,6 +34,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
sizes = src.sizes().vec(); sizes = src.sizes().vec();
if (dim_size.has_value()) if (dim_size.has_value())
sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else else
sizes[dim] = 1 + *index.max().data_ptr<int64_t>(); sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
...@@ -44,15 +46,15 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -44,15 +46,15 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options()); arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} } else if (reduce2REDUCE.at(reduce) == MEAN) {
torch::optional<torch::Tensor> count = torch::nullopt;
if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec(); auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim); sizes[dim] = out.size(dim);
count = torch::zeros(sizes, out.options()); arg_out = torch::zeros(sizes, out.options());
} }
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto B = index.numel() / src.size(dim); auto B = index.numel() / src.size(dim);
auto E = src.size(dim); auto E = src.size(dim);
auto K = src.numel() / index.numel(); auto K = src.numel() / index.numel();
...@@ -72,7 +74,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -72,7 +74,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
if (!optional_out.has_value()) if (!optional_out.has_value())
out.fill_(Reducer<scalar_t, REDUCE>::init()); out.fill_(Reducer<scalar_t, REDUCE>::init());
if (REDUCE == MEAN) if (REDUCE == MEAN)
count_data = count.value().data_ptr<scalar_t>(); count_data = arg_out.value().data_ptr<scalar_t>();
for (auto b = 0; b < B; b++) { for (auto b = 0; b < B; b++) {
auto offset = IndexToOffset<int64_t>::get(b * E, index_info); auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
...@@ -122,7 +124,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -122,7 +124,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0); out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
if (REDUCE == MEAN) if (REDUCE == MEAN)
arg_out = count; arg_out.value().clamp_(1);
}); });
}); });
...@@ -156,6 +158,9 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -156,6 +158,9 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (index.numel() == 0)
return out;
auto B = index.numel() / out.size(dim); auto B = index.numel() / out.size(dim);
auto E = index.size(dim); auto E = index.size(dim);
auto K = out.numel() / index.numel(); auto K = out.numel() / index.numel();
......
...@@ -30,10 +30,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -30,10 +30,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
for (auto i = 0; i < out.dim(); i++) for (auto i = 0; i < out.dim(); i++)
if (i != dim) if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
} else { } else {
sizes = src.sizes().vec(); sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1; sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
...@@ -44,6 +44,9 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -44,6 +44,9 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (src.numel() == 0)
return std::make_tuple(out, arg_out);
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N; auto K = out.numel() / N;
auto E = src.size(dim); auto E = src.size(dim);
...@@ -98,7 +101,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -98,7 +101,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
indptr = indptr.expand(sizes); indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1; auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous(); src = src.contiguous();
...@@ -110,10 +113,16 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -110,10 +113,16 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(src.size(i) == out.size(i));
} else { } else {
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
if (src.numel() > 0)
sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>(); sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
else
sizes[dim] = 0;
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (src.numel() == 0)
return out;
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N; auto K = src.numel() / N;
auto E = out.size(dim); auto E = out.size(dim);
......
...@@ -81,6 +81,8 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -81,6 +81,8 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
if (dim_size.has_value()) if (dim_size.has_value())
sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else { else {
auto d_size = index.max().data_ptr<int64_t>(); auto d_size = index.max().data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t)); auto h_size = (int64_t *)malloc(sizeof(int64_t));
...@@ -97,6 +99,9 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -97,6 +99,9 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto B = 1; auto B = 1;
for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
B *= src.size(i); B *= src.size(i);
......
...@@ -181,6 +181,8 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -181,6 +181,8 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
sizes = src.sizes().vec(); sizes = src.sizes().vec();
if (dim_size.has_value()) if (dim_size.has_value())
sizes[dim] = dim_size.value(); sizes[dim] = dim_size.value();
else if (index.numel() == 0)
sizes[dim] = 0;
else { else {
auto d_size = index.max().data_ptr<int64_t>(); auto d_size = index.max().data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t)); auto h_size = (int64_t *)malloc(sizeof(int64_t));
...@@ -195,8 +197,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -195,8 +197,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
arg_out = torch::full_like(out, src.size(dim), index.options()); arg_out = torch::full_like(out, src.size(dim), index.options());
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} else if (reduce2REDUCE.at(reduce) == MEAN) {
auto sizes = index.sizes().vec();
sizes[dim] = out.size(dim);
arg_out = torch::zeros(sizes, out.options());
} }
if (index.numel() == 0)
return std::make_tuple(out, arg_out);
auto E = index.numel(); auto E = index.numel();
auto E_2 = index.size(dim); auto E_2 = index.size(dim);
auto E_1 = index.numel() / E_2; auto E_1 = index.numel() / E_2;
...@@ -254,17 +263,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -254,17 +263,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
} }
if (REDUCE == MEAN) { if (REDUCE == MEAN) {
auto sizes = index.sizes().vec(); auto count_data = arg_out.value().data_ptr<scalar_t>();
sizes[dim] = out.size(dim);
auto count = torch::zeros(sizes, out.options());
auto count_data = count.data_ptr<scalar_t>();
segment_coo_kernel<scalar_t, SUM, false> segment_coo_kernel<scalar_t, SUM, false>
<<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info, <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
count_data, E, N); count_data, E, N);
arg_out = count; arg_out.value().clamp_(1);
auto count = arg_out.value();
for (int i = dim + 1; i < out.dim(); i++) for (int i = dim + 1; i < out.dim(); i++)
count = count.unsqueeze(-1); count = count.unsqueeze(-1);
out.div_(count.clamp_(1)); out.div_(count);
} }
}); });
}); });
...@@ -346,6 +353,9 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, ...@@ -346,6 +353,9 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (index.numel() == 0)
return out;
auto E = index.numel(); auto E = index.numel();
auto K = out.numel() / E; auto K = out.numel() / E;
auto N = src.size(dim); auto N = src.size(dim);
......
...@@ -121,10 +121,10 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -121,10 +121,10 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != dim) if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(src.size(i) == out.size(i));
CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1); CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
} else { } else {
sizes = src.sizes().vec(); sizes = src.sizes().vec();
sizes[dim] = indptr.size(dim) - 1; sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
...@@ -135,6 +135,9 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -135,6 +135,9 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (src.numel() == 0)
return std::make_tuple(out, arg_out);
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N; auto K = out.numel() / N;
auto E = src.size(dim); auto E = src.size(dim);
...@@ -226,7 +229,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -226,7 +229,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
indptr = indptr.expand(sizes); indptr = indptr.expand(sizes);
auto dim = indptr.dim() - 1; auto dim = indptr.dim() - 1;
CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1); CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
src = src.contiguous(); src = src.contiguous();
...@@ -237,14 +240,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -237,14 +240,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
if (i != dim) if (i != dim)
CHECK_INPUT(src.size(i) == out.size(i)); CHECK_INPUT(src.size(i) == out.size(i));
} else { } else {
auto sizes = src.sizes().vec();
if (src.numel() > 0) {
auto d_size = indptr.flatten()[-1].data_ptr<int64_t>(); auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
auto h_size = (int64_t *)malloc(sizeof(int64_t)); auto h_size = (int64_t *)malloc(sizeof(int64_t));
cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost); cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
auto sizes = src.sizes().vec();
sizes[dim] = *h_size; sizes[dim] = *h_size;
} else
sizes[dim] = 0;
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (src.numel() == 0)
return out;
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N; auto K = src.numel() / N;
auto E = out.size(dim); auto E = out.size(dim);
......
...@@ -82,6 +82,7 @@ public: ...@@ -82,6 +82,7 @@ public:
auto indptr = saved[0]; auto indptr = saved[0];
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::empty(src_shape, grad_out.options()); auto grad_in = torch::empty(src_shape, grad_out.options());
if (grad_in.numel() > 0) {
gather_csr_fw(grad_out, indptr, grad_in); gather_csr_fw(grad_out, indptr, grad_in);
auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1); auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1); auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
...@@ -90,6 +91,7 @@ public: ...@@ -90,6 +91,7 @@ public:
for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++) for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
count = count.unsqueeze(-1); count = count.unsqueeze(-1);
grad_in.div_(count); grad_in.div_(count);
}
return {grad_in, Variable(), Variable()}; return {grad_in, Variable(), Variable()};
} }
}; };
......
from itertools import product
import pytest
import torch import torch
from torch_scatter import scatter from torch_scatter import scatter, segment_coo, gather_coo
from torch_scatter import segment_csr, gather_csr
from .utils import reductions, tensor, grad_dtypes, devices
@pytest.mark.parametrize('reduce,dtype,device',
product(reductions, grad_dtypes, devices))
def test_zero_elements(reduce, dtype, device):
x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
requires_grad=True)
index = tensor([], torch.long, device)
indptr = tensor([], torch.long, device)
out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
out.backward(torch.randn_like(out))
assert out.size() == (0, 0, 0, 16)
out = segment_coo(x, index, dim_size=0, reduce=reduce)
out.backward(torch.randn_like(out))
assert out.size() == (0, 0, 0, 16)
out = gather_coo(x, index)
out.backward(torch.randn_like(out))
assert out.size() == (0, 0, 0, 16)
def test_zero_elements(): out = segment_csr(x, indptr, reduce=reduce)
x = torch.randn(0, 16) out.backward(torch.randn_like(out))
index = torch.tensor([]).view(0, 16) assert out.size() == (0, 0, 0, 16)
print(x)
print(index)
scatter(x, index, dim=0, dim_size=0, reduce="add") out = gather_csr(x, indptr)
out.backward(torch.randn_like(out))
assert out.size() == (0, 0, 0, 16)
...@@ -12,12 +12,6 @@ try: ...@@ -12,12 +12,6 @@ try:
except OSError: except OSError:
warnings.warn('Failed to load `scatter` binaries.') warnings.warn('Failed to load `scatter` binaries.')
def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
out: Optional[torch.Tensor],
dim_size: Optional[int]) -> torch.Tensor:
raise ImportError
return src
def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor, def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
dim: int, out: Optional[torch.Tensor], dim: int, out: Optional[torch.Tensor],
dim_size: Optional[int] dim_size: Optional[int]
...@@ -25,7 +19,6 @@ except OSError: ...@@ -25,7 +19,6 @@ except OSError:
raise ImportError raise ImportError
return src, index return src, index
torch.ops.torch_scatter.scatter_mean = scatter_placeholder
torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
...@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
index = broadcast(index, src, dim) index = broadcast(index, src, dim)
if out is None: if out is None:
size = src.size() size = src.size()
if dim_size is None: if dim_size is not None:
size[dim] = int(index.max()) + 1
else:
size[dim] = dim_size size[dim] = dim_size
out = src.new_zeros(size) elif index.numel() == 0:
size[dim] = 0
else:
size[dim] = int(index.max()) + 1
out = torch.zeros(size, dtype=src.dtype, device=src.device)
return out.scatter_add_(dim, index, src) return out.scatter_add_(dim, index, src)
else: else:
return out.scatter_add_(dim, index, src) return out.scatter_add_(dim, index, src)
...@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1, ...@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1, def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
dim_size: Optional[int] = None) -> torch.Tensor: dim_size: Optional[int] = None) -> torch.Tensor:
return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)
out = scatter_sum(src, index, dim, out, dim_size)
dim_size = out.size(dim)
index_dim = dim
if index_dim < 0:
index_dim = index_dim + src.dim()
if index.dim() <= dim:
index_dim = index.dim() - 1
ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
count = scatter_sum(ones, index, index_dim, None, dim_size)
count.clamp_(1)
count = broadcast(count, out, dim)
out.div_(count)
return out
@torch.jit.script @torch.jit.script
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment