Commit cebec48f authored by rusty1s's avatar rusty1s
Browse files

possible assertion fix

parent f82bfbac
......@@ -182,9 +182,9 @@ std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt, std::string reduce) {
AT_ASSERTM(src.dim() >= indptr.dim());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i));
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto reduce_dim = indptr.dim() - 1;
......@@ -194,8 +194,9 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i));
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1);
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, "Input
mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
......@@ -341,9 +342,9 @@ __global__ void segment_coo_broadcast_kernel(
std::tuple<at::Tensor, at::optional<at::Tensor>>
segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
std::string reduce) {
AT_ASSERTM(src.dim() >= index.dim());
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(src.size(i) == index.size(i));
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
out = out.contiguous();
......@@ -351,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i));
AT_ASSERTM(src.size(i) == out.size(i), , "Input mismatch");
at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment