"docs/vscode:/vscode.git/clone" did not exist on "ddefa23bd04dda125e4ddb71b4ce304dc4013d67"
Commit 88dd792e authored by rusty1s

fix zero element tensors

parent bf1f1014
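Summary: every scatter/segment/gather kernel now short-circuits on zero-element inputs. Output size inference falls back to 0 when the index (or indptr) tensor is empty, each forward returns before computing loop or launch bounds (which would otherwise divide by zero), scatter_mean is reimplemented in Python on top of scatter_sum, and a parametrized test exercises forward and backward of all five ops on empty tensors.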
@@ -29,6 +29,8 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
     auto sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else
       sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
     out = torch::empty(sizes, src.options());
@@ -41,6 +43,9 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = 1;
   for (auto i = 0; i < dim; i++)
     B *= src.size(i);
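Note: an empty index now short-circuits both the size inference and the kernel loop. A minimal sketch of the resulting behavior, mirroring the test added at the bottom of this commit (assuming the extension is built from this revision):

    import torch
    from torch_scatter import scatter

    x = torch.randn(0, 16)                    # zero-element source
    index = torch.empty(0, dtype=torch.long)  # empty index
    out = scatter(x, index, dim=0, dim_size=0, reduce="sum")
    assert out.size() == (0, 16)              # empty output, no reduction runs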
@@ -34,6 +34,8 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
     sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else
       sizes[dim] = 1 + *index.max().data_ptr<int64_t>();
     out = torch::empty(sizes, src.options());
@@ -44,15 +46,15 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = torch::full_like(out, src.size(dim), index.options());
     arg_out_data = arg_out.value().data_ptr<int64_t>();
-  }
-
-  torch::optional<torch::Tensor> count = torch::nullopt;
-  if (reduce2REDUCE.at(reduce) == MEAN) {
+  } else if (reduce2REDUCE.at(reduce) == MEAN) {
     auto sizes = index.sizes().vec();
     sizes[dim] = out.size(dim);
-    count = torch::zeros(sizes, out.options());
+    arg_out = torch::zeros(sizes, out.options());
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = index.numel() / src.size(dim);
   auto E = src.size(dim);
   auto K = src.numel() / index.numel();
@@ -72,7 +74,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
         if (!optional_out.has_value())
           out.fill_(Reducer<scalar_t, REDUCE>::init());
         if (REDUCE == MEAN)
-          count_data = count.value().data_ptr<scalar_t>();
+          count_data = arg_out.value().data_ptr<scalar_t>();
 
         for (auto b = 0; b < B; b++) {
           auto offset = IndexToOffset<int64_t>::get(b * E, index_info);
@@ -122,7 +124,7 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
         out.masked_fill_(out == Reducer<scalar_t, REDUCE>::init(), (scalar_t)0);
 
         if (REDUCE == MEAN)
-          arg_out = count;
+          arg_out.value().clamp_(1);
       });
     });
@@ -156,6 +158,9 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
     out = torch::empty(sizes, src.options());
   }
 
+  if (index.numel() == 0)
+    return out;
+
   auto B = index.numel() / out.size(dim);
   auto E = index.size(dim);
   auto K = out.numel() / index.numel();
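Note: for reduce="mean", arg_out now doubles as the per-segment count buffer (previously a separate count tensor), clamped to 1 so that empty segments do not divide by zero. The arithmetic the kernel implements, sketched with plain scatter_add_ (values made up for illustration):

    import torch

    src = torch.tensor([1., 3., 2.])
    index = torch.tensor([0, 0, 1])
    count = torch.zeros(2).scatter_add_(0, index, torch.ones_like(src))
    out = torch.zeros(2).scatter_add_(0, index, src) / count.clamp(min=1)
    # out == tensor([2., 2.])  -- segment sums [4., 2.] over counts [2., 1.]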
@@ -30,10 +30,10 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     for (auto i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
-    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
+    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
   } else {
     sizes = src.sizes().vec();
-    sizes[dim] = indptr.size(dim) - 1;
+    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
     out = torch::empty(sizes, src.options());
   }
@@ -44,6 +44,9 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (src.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(dim);
@@ -98,7 +101,7 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
   indptr = indptr.expand(sizes);
 
   auto dim = indptr.dim() - 1;
-  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
+  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
 
   src = src.contiguous();
@@ -110,10 +113,16 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
     auto sizes = src.sizes().vec();
-    sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
+    if (src.numel() > 0)
+      sizes[dim] = *indptr.flatten()[-1].data_ptr<int64_t>();
+    else
+      sizes[dim] = 0;
     out = torch::empty(sizes, src.options());
   }
 
+  if (src.numel() == 0)
+    return out;
+
   auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = src.numel() / N;
   auto E = out.size(dim);
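Note: gather_csr reads its output length along dim from the last indptr entry, which is only safe to dereference when src is non-empty; otherwise the length is forced to 0. A small sketch of the op's semantics (hypothetical values, assuming this revision):

    import torch
    from torch_scatter import gather_csr

    src = torch.tensor([1., 2., 3.])
    indptr = torch.tensor([0, 2, 3, 3])  # src[i] repeated indptr[i+1]-indptr[i] times
    out = gather_csr(src, indptr)        # tensor([1., 1., 2.]); length = indptr[-1]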
@@ -81,6 +81,8 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     auto sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else {
       auto d_size = index.max().data_ptr<int64_t>();
       auto h_size = (int64_t *)malloc(sizeof(int64_t));
@@ -97,6 +99,9 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto B = 1;
   for (auto i = 0; i < dim; i++)
     B *= src.size(i);
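Note: the early returns are not just a fast path. Without them, the size inference above would reduce an empty tensor, which raises in PyTorch:

    import torch

    try:
        torch.empty(0, dtype=torch.long).max()
    except RuntimeError as err:
        print(err)  # reductions over zero elements have no value to return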
@@ -181,6 +181,8 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
     sizes = src.sizes().vec();
     if (dim_size.has_value())
       sizes[dim] = dim_size.value();
+    else if (index.numel() == 0)
+      sizes[dim] = 0;
     else {
       auto d_size = index.max().data_ptr<int64_t>();
       auto h_size = (int64_t *)malloc(sizeof(int64_t));
@@ -195,8 +197,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
   if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) {
     arg_out = torch::full_like(out, src.size(dim), index.options());
     arg_out_data = arg_out.value().data_ptr<int64_t>();
+  } else if (reduce2REDUCE.at(reduce) == MEAN) {
+    auto sizes = index.sizes().vec();
+    sizes[dim] = out.size(dim);
+    arg_out = torch::zeros(sizes, out.options());
   }
 
+  if (index.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto E = index.numel();
   auto E_2 = index.size(dim);
   auto E_1 = index.numel() / E_2;
@@ -254,17 +263,15 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
         }
 
         if (REDUCE == MEAN) {
-          auto sizes = index.sizes().vec();
-          sizes[dim] = out.size(dim);
-          auto count = torch::zeros(sizes, out.options());
-          auto count_data = count.data_ptr<scalar_t>();
+          auto count_data = arg_out.value().data_ptr<scalar_t>();
           segment_coo_kernel<scalar_t, SUM, false>
               <<<BLOCKS(1, E), THREADS, 0, stream>>>(nullptr, index_info,
                                                      count_data, E, N);
-          arg_out = count;
+          arg_out.value().clamp_(1);
+          auto count = arg_out.value();
           for (int i = dim + 1; i < out.dim(); i++)
             count = count.unsqueeze(-1);
-          out.div_(count.clamp_(1));
+          out.div_(count);
         }
       });
     });
@@ -346,6 +353,9 @@ torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index,
     out = torch::empty(sizes, src.options());
   }
 
+  if (index.numel() == 0)
+    return out;
+
   auto E = index.numel();
   auto K = out.numel() / E;
   auto N = src.size(dim);
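Note: the CUDA path mirrors the CPU refactoring above. For MEAN the count buffer is allocated directly into arg_out before the zero-element early return, so a mean reduction over an empty index still returns a correctly shaped (empty) count tensor rather than nullopt.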
@@ -121,10 +121,10 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
     for (int i = 0; i < out.dim(); i++)
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
-    CHECK_INPUT(out.size(dim) == indptr.size(dim) - 1);
+    CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1);
   } else {
     sizes = src.sizes().vec();
-    sizes[dim] = indptr.size(dim) - 1;
+    sizes[dim] = std::max<int64_t>(indptr.size(dim) - 1, 0);
     out = torch::empty(sizes, src.options());
   }
@@ -135,6 +135,9 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
     arg_out_data = arg_out.value().data_ptr<int64_t>();
   }
 
+  if (src.numel() == 0)
+    return std::make_tuple(out, arg_out);
+
   auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = out.numel() / N;
   auto E = src.size(dim);
@@ -226,7 +229,7 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
   indptr = indptr.expand(sizes);
 
   auto dim = indptr.dim() - 1;
-  CHECK_INPUT(src.size(dim) == indptr.size(dim) - 1);
+  CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1);
 
   src = src.contiguous();
@@ -237,14 +240,20 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
       if (i != dim)
         CHECK_INPUT(src.size(i) == out.size(i));
   } else {
-    auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
-    auto h_size = (int64_t *)malloc(sizeof(int64_t));
-    cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
     auto sizes = src.sizes().vec();
-    sizes[dim] = *h_size;
+    if (src.numel() > 0) {
+      auto d_size = indptr.flatten()[-1].data_ptr<int64_t>();
+      auto h_size = (int64_t *)malloc(sizeof(int64_t));
+      cudaMemcpy(h_size, d_size, sizeof(int64_t), cudaMemcpyDeviceToHost);
+      sizes[dim] = *h_size;
+    } else
+      sizes[dim] = 0;
     out = torch::empty(sizes, src.options());
   }
 
+  if (src.numel() == 0)
+    return out;
+
   auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
   auto K = src.numel() / N;
   auto E = out.size(dim);
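Note: besides handling the empty case, the src.numel() > 0 branch skips the blocking cudaMemcpy device-to-host synchronization and avoids reading indptr.flatten()[-1], which would be out of bounds when indptr itself is empty.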
@@ -82,14 +82,16 @@ public:
     auto indptr = saved[0];
     auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
     auto grad_in = torch::empty(src_shape, grad_out.options());
-    gather_csr_fw(grad_out, indptr, grad_in);
-    auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
-    auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
-    auto count = (indptr2 - indptr1).to(grad_in.options());
-    count = gather_csr_fw(count, indptr, torch::nullopt);
-    for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
-      count = count.unsqueeze(-1);
-    grad_in.div_(count);
+    if (grad_in.numel() > 0) {
+      gather_csr_fw(grad_out, indptr, grad_in);
+      auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1);
+      auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1);
+      auto count = (indptr2 - indptr1).to(grad_in.options());
+      count = gather_csr_fw(count, indptr, torch::nullopt);
+      for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++)
+        count = count.unsqueeze(-1);
+      grad_in.div_(count);
+    }
     return {grad_in, Variable(), Variable()};
   }
 };
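Note: with the guard, the mean backward leaves grad_in as the empty tensor allocated above instead of gathering into it and dividing by a zero count. What now works end to end, in the spirit of the new test (a sketch, assuming this revision):

    import torch
    from torch_scatter import segment_csr

    x = torch.randn(0, 16, requires_grad=True)
    indptr = torch.empty(0, dtype=torch.long)
    out = segment_csr(x, indptr, reduce="mean")
    out.backward(torch.randn_like(out))  # no division by an empty count
    assert x.grad.size() == (0, 16)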
 from itertools import product
 
 import pytest
 import torch
-from torch_scatter import scatter
+from torch_scatter import scatter, segment_coo, gather_coo
+from torch_scatter import segment_csr, gather_csr
+
+from .utils import reductions, tensor, grad_dtypes, devices
 
 
+@pytest.mark.parametrize('reduce,dtype,device',
+                         product(reductions, grad_dtypes, devices))
+def test_zero_elements(reduce, dtype, device):
+    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
+                    requires_grad=True)
+    index = tensor([], torch.long, device)
+    indptr = tensor([], torch.long, device)
+
+    out = scatter(x, index, dim=0, dim_size=0, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
+
+    out = segment_coo(x, index, dim_size=0, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
+
+    out = gather_coo(x, index)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
-def test_zero_elements():
-    x = torch.randn(0, 16)
-    index = torch.tensor([]).view(0, 16)
-    print(x)
-    print(index)
+
+    out = segment_csr(x, indptr, reduce=reduce)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
-    scatter(x, index, dim=0, dim_size=0, reduce="add")
+
+    out = gather_csr(x, indptr)
+    out.backward(torch.randn_like(out))
+    assert out.size() == (0, 0, 0, 16)
@@ -12,12 +12,6 @@ try:
 except OSError:
     warnings.warn('Failed to load `scatter` binaries.')
 
-    def scatter_placeholder(src: torch.Tensor, index: torch.Tensor, dim: int,
-                            out: Optional[torch.Tensor],
-                            dim_size: Optional[int]) -> torch.Tensor:
-        raise ImportError
-        return src
-
     def scatter_with_arg_placeholder(src: torch.Tensor, index: torch.Tensor,
                                      dim: int, out: Optional[torch.Tensor],
                                      dim_size: Optional[int]
@@ -25,7 +19,6 @@ except OSError:
         raise ImportError
         return src, index
 
-    torch.ops.torch_scatter.scatter_mean = scatter_placeholder
     torch.ops.torch_scatter.scatter_min = scatter_with_arg_placeholder
     torch.ops.torch_scatter.scatter_max = scatter_with_arg_placeholder
@@ -37,11 +30,13 @@ def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
     index = broadcast(index, src, dim)
     if out is None:
         size = src.size()
-        if dim_size is None:
-            size[dim] = int(index.max()) + 1
-        else:
+        if dim_size is not None:
             size[dim] = dim_size
-        out = src.new_zeros(size)
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
-    return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
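Note the sizing rule in scatter_sum: an explicit dim_size wins, an empty index yields a zero-length output, and only otherwise is index.max() consulted. For instance:

    import torch
    from torch_scatter import scatter

    src = torch.tensor([5., 7.])
    index = torch.tensor([3, 3])
    scatter(src, index, dim=0, reduce="sum").size()               # torch.Size([4])
    scatter(src, index, dim=0, dim_size=10, reduce="sum").size()  # torch.Size([10])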
@@ -58,7 +53,22 @@ def scatter_add(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
 def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                  out: Optional[torch.Tensor] = None,
                  dim_size: Optional[int] = None) -> torch.Tensor:
-    return torch.ops.torch_scatter.scatter_mean(src, index, dim, out, dim_size)
+    out = scatter_sum(src, index, dim, out, dim_size)
+    dim_size = out.size(dim)
+
+    index_dim = dim
+    if index_dim < 0:
+        index_dim = index_dim + src.dim()
+    if index.dim() <= dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter_sum(ones, index, index_dim, None, dim_size)
+    count.clamp_(1)
+    count = broadcast(count, out, dim)
+    out.div_(count)
+    return out
 
 
 @torch.jit.script
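Note: a quick worked check of the composed scatter_mean (values made up):

    import torch
    from torch_scatter import scatter_mean

    src = torch.tensor([1., 3., 2., 4.])
    index = torch.tensor([0, 0, 1, 1])
    scatter_mean(src, index, dim=0)  # tensor([2., 3.]): sums [4., 6.] / counts [2, 2]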