Commit cebec48f authored by rusty1s's avatar rusty1s
Browse files

possible assertion fix

parent f82bfbac
...@@ -182,9 +182,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>> ...@@ -182,9 +182,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr, segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce) { at::optional<at::Tensor> out_opt, std::string reduce) {
AT_ASSERTM(src.dim() >= indptr.dim()); AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++) for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i)); AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous(); src = src.contiguous();
auto reduce_dim = indptr.dim() - 1; auto reduce_dim = indptr.dim() - 1;
...@@ -194,8 +194,9 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr, ...@@ -194,8 +194,9 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
out = out_opt.value().contiguous(); out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim) if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i)); AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1); AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, "Input
mismatch");
} else { } else {
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1; sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
...@@ -341,9 +342,9 @@ __global__ void segment_coo_broadcast_kernel( ...@@ -341,9 +342,9 @@ __global__ void segment_coo_broadcast_kernel(
std::tuple<at::Tensor, at::optional<at::Tensor>> std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) { std::string reduce) {
AT_ASSERTM(src.dim() >= index.dim()); AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim(); i++) for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(src.size(i) == index.size(i)); AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous(); src = src.contiguous();
out = out.contiguous(); out = out.contiguous();
...@@ -351,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -351,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim) if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i)); AT_ASSERTM(src.size(i) == out.size(i), , "Input mismatch");
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment