"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a1b5ef5ddc1ee19bf51927b365ee632d352b9890"
Commit 1eef2be1 authored by rusty1s
Browse files

fixed zero numel init bug

parent 66105d4b
...@@ -43,8 +43,11 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -43,8 +43,11 @@ scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (index.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
}
auto B = 1; auto B = 1;
for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
......
...@@ -52,8 +52,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -52,8 +52,11 @@ segment_coo_cpu(torch::Tensor src, torch::Tensor index,
arg_out = torch::zeros(sizes, out.options()); arg_out = torch::zeros(sizes, out.options());
} }
if (index.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
}
auto B = index.numel() / src.size(dim); auto B = index.numel() / src.size(dim);
auto E = src.size(dim); auto E = src.size(dim);
...@@ -158,8 +161,11 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, ...@@ -158,8 +161,11 @@ torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index,
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (index.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out; return out;
}
auto B = index.numel() / out.size(dim); auto B = index.numel() / out.size(dim);
auto E = index.size(dim); auto E = index.size(dim);
......
...@@ -44,8 +44,11 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -44,8 +44,11 @@ segment_csr_cpu(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (src.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N; auto K = out.numel() / N;
...@@ -120,8 +123,11 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, ...@@ -120,8 +123,11 @@ torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr,
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (src.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out; return out;
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N; auto K = src.numel() / N;
......
...@@ -99,8 +99,11 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, ...@@ -99,8 +99,11 @@ scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (index.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
}
auto B = 1; auto B = 1;
for (auto i = 0; i < dim; i++) for (auto i = 0; i < dim; i++)
......
...@@ -135,8 +135,11 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -135,8 +135,11 @@ segment_csr_cuda(torch::Tensor src, torch::Tensor indptr,
arg_out_data = arg_out.value().data_ptr<int64_t>(); arg_out_data = arg_out.value().data_ptr<int64_t>();
} }
if (src.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return std::make_tuple(out, arg_out); return std::make_tuple(out, arg_out);
}
auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = out.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = out.numel() / N; auto K = out.numel() / N;
...@@ -251,8 +254,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, ...@@ -251,8 +254,11 @@ torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr,
out = torch::empty(sizes, src.options()); out = torch::empty(sizes, src.options());
} }
if (src.numel() == 0) if (src.numel() == 0) {
if (!optional_out.has_value())
out.fill_(0);
return out; return out;
}
auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); auto N = src.size(dim) * (indptr.numel() / indptr.size(-1));
auto K = src.numel() / N; auto K = src.numel() / N;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment