"tests/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "265840a09887f110e722ae8690365b4ceb784ec0"
Commit a9f9266b authored by rusty1s's avatar rusty1s
Browse files

index expand

parent 2743b291
...@@ -178,8 +178,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr, ...@@ -178,8 +178,13 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce) { at::optional<at::Tensor> out_opt, std::string reduce) {
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch"); AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch"); // Broadcasting across `index` via `expand`.
auto sizes = indptr.sizes().vec();
for (int i = 0; i < indptr.dim() - 1; i++) {
sizes[i] = src.size(i);
}
indptr = indptr.expand(sizes);
src = src.contiguous(); src = src.contiguous();
auto reduce_dim = indptr.dim() - 1; auto reduce_dim = indptr.dim() - 1;
...@@ -193,7 +198,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr, ...@@ -193,7 +198,7 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
"Input mismatch"); "Input mismatch");
} else { } else {
auto sizes = src.sizes().vec(); sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1; sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
out = at::empty(sizes, src.options()); out = at::empty(sizes, src.options());
} }
...@@ -370,9 +375,15 @@ __global__ void segment_coo_arg_broadcast_kernel( ...@@ -370,9 +375,15 @@ __global__ void segment_coo_arg_broadcast_kernel(
std::tuple<at::Tensor, at::optional<at::Tensor>> std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) { std::string reduce) {
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch"); AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch"); // Broadcasting across `index` via `expand`.
auto sizes = index.sizes().vec();
for (int i = 0; i < index.dim(); i++) {
sizes[i] = src.size(i);
}
index = index.expand(sizes);
src = src.contiguous(); src = src.contiguous();
out = out.contiguous(); out = out.contiguous();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment