Commit f98ff7e8 authored by rusty1s's avatar rusty1s
Browse files

possible assertion fix

parent 4f6fe911
......@@ -61,20 +61,21 @@ __global__ void gather_csr_broadcast_kernel(
at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= indptr.dim());
AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i));
AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1);
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
at::Tensor out;
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++)
if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i));
AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else {
auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
......@@ -154,9 +155,9 @@ __global__ void gather_coo_broadcast_kernel(
at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= index.dim());
AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i));
AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous();
auto gather_dim = index.dim() - 1;
......@@ -165,9 +166,9 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
if (out_opt.has_value()) {
out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i));
AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i));
AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else {
auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment