Commit f98ff7e8 authored by rusty1s's avatar rusty1s
Browse files

possible assertion fix

parent 4f6fe911
...@@ -61,20 +61,21 @@ __global__ void gather_csr_broadcast_kernel( ...@@ -61,20 +61,21 @@ __global__ void gather_csr_broadcast_kernel(
at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr, at::Tensor gather_csr_cuda(at::Tensor src, at::Tensor indptr,
at::optional<at::Tensor> out_opt) { at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= indptr.dim()); AT_ASSERTM(src.dim() >= indptr.dim(), "Input mismatch");
for (int i = 0; i < indptr.dim() - 1; i++) for (int i = 0; i < indptr.dim() - 1; i++)
AT_ASSERTM(src.size(i) == indptr.size(i)); AT_ASSERTM(src.size(i) == indptr.size(i), "Input mismatch");
src = src.contiguous(); src = src.contiguous();
auto gather_dim = indptr.dim() - 1; auto gather_dim = indptr.dim() - 1;
AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1); AT_ASSERTM(src.size(gather_dim) == indptr.size(gather_dim) - 1,
"Input mismatch");
at::Tensor out; at::Tensor out;
if (out_opt.has_value()) { if (out_opt.has_value()) {
out = out_opt.value().contiguous(); out = out_opt.value().contiguous();
for (int i = 0; i < out.dim(); i++) for (int i = 0; i < out.dim(); i++)
if (i != gather_dim) if (i != gather_dim)
AT_ASSERTM(src.size(i) == out.size(i)); AT_ASSERTM(src.size(i) == out.size(i), "Input mismatch");
} else { } else {
auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>(); auto d_gather_size = indptr.flatten()[-1].DATA_PTR<int64_t>();
auto h_gather_size = (int64_t *)malloc(sizeof(int64_t)); auto h_gather_size = (int64_t *)malloc(sizeof(int64_t));
...@@ -154,9 +155,9 @@ __global__ void gather_coo_broadcast_kernel( ...@@ -154,9 +155,9 @@ __global__ void gather_coo_broadcast_kernel(
at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index, at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
at::optional<at::Tensor> out_opt) { at::optional<at::Tensor> out_opt) {
AT_ASSERTM(src.dim() >= index.dim()); AT_ASSERTM(src.dim() >= index.dim(), "Input mismatch");
for (int i = 0; i < index.dim() - 1; i++) for (int i = 0; i < index.dim() - 1; i++)
AT_ASSERTM(src.size(i) == index.size(i)); AT_ASSERTM(src.size(i) == index.size(i), "Input mismatch");
src = src.contiguous(); src = src.contiguous();
auto gather_dim = index.dim() - 1; auto gather_dim = index.dim() - 1;
...@@ -165,9 +166,9 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index, ...@@ -165,9 +166,9 @@ at::Tensor gather_coo_cuda(at::Tensor src, at::Tensor index,
if (out_opt.has_value()) { if (out_opt.has_value()) {
out = out_opt.value().contiguous(); out = out_opt.value().contiguous();
for (int i = 0; i < index.dim(); i++) for (int i = 0; i < index.dim(); i++)
AT_ASSERTM(out.size(i) == index.size(i)); AT_ASSERTM(out.size(i) == index.size(i), "Input mismatch");
for (int i = index.dim() + 1; i < src.dim(); i++) for (int i = index.dim() + 1; i < src.dim(); i++)
AT_ASSERTM(out.size(i) == src.size(i)); AT_ASSERTM(out.size(i) == src.size(i), "Input mismatch");
} else { } else {
auto sizes = src.sizes().vec(); auto sizes = src.sizes().vec();
sizes[gather_dim] = index.size(gather_dim); sizes[gather_dim] = index.size(gather_dim);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment