Unverified Commit 2ff3006c authored by Muhammed Fatih BALIN's avatar Muhammed Fatih BALIN Committed by GitHub
Browse files

[CUDA][Bug] CSR transpose bug in CUDA 12 (#7295)

parent 20e5e266
...@@ -22,6 +22,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -22,6 +22,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
template <> template <>
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
#if CUDART_VERSION < 12000
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed // allocate cusparse handle if needed
...@@ -76,6 +77,9 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { ...@@ -76,6 +77,9 @@ CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
return CSRMatrix( return CSRMatrix(
csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false); csr.num_cols, csr.num_rows, t_indptr, t_indices, t_data, false);
#else
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
#endif
} }
template <> template <>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment