Commit 4f6fe911 authored by rusty1s's avatar rusty1s
Browse files

fixed string bug

parent cebec48f
...@@ -195,8 +195,8 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr, ...@@ -195,8 +195,8 @@ segment_csr_cuda(at::Tensor src, at::Tensor indptr,
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim) if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch"); AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1, "Input AT_ASSERTM(out.size(reduce_dim) == indptr.size(reduce_dim) - 1,
mismatch"); "Input mismatch");
} else { } else {
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
sizes[reduce_dim] = indptr.size(reduce_dim) - 1; sizes[reduce_dim] = indptr.size(reduce_dim) - 1;
...@@ -352,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out, ...@@ -352,7 +352,7 @@ segment_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor out,
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != reduce_dim) if (i != reduce_dim)
AT_ASSERTM(src.size(i) == out.size(i), , "Input mismatch"); AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
at::optional<at::Tensor> arg_out = at::nullopt; at::optional<at::Tensor> arg_out = at::nullopt;
int64_t *arg_out_data = nullptr; int64_t *arg_out_data = nullptr;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment