Commit 52639ff6 authored by sangwzh's avatar sangwzh
Browse files

update dtype in index_select_csc_impl

parent 8f11ff9b
...@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices( ...@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
_CopyIndicesAlignedKernel, grid, block, 0, _CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned_), sliced_indptr, static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr, output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
reinterpret_cast<indices_t*>(cuda::getTensorDevicePointer<indptr_t>(indices)), reinterpret_cast<indices_t*>(cuda::getTensorDevicePointer<indices_t>(indices)),
coo_aligned_rows.data_ptr<coo_rows_t>(), coo_aligned_rows.data_ptr<coo_rows_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm); reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
})); }));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment