Unverified commit 715b3b16, authored by Zihao Ye and committed by GitHub
Browse files

upd (#2117)

parent 5cff2f1c
...@@ -115,6 +115,7 @@ void CusparseCsrmm2( ...@@ -115,6 +115,7 @@ void CusparseCsrmm2(
DType* valptr = static_cast<DType*>(device->AllocWorkspace(rtcfg.ctx, nnz * sizeof(DType))); DType* valptr = static_cast<DType*>(device->AllocWorkspace(rtcfg.ctx, nnz * sizeof(DType)));
utils::Fill<kDLGPU>(rtcfg.ctx, valptr, nnz, static_cast<DType>(1.)); utils::Fill<kDLGPU>(rtcfg.ctx, valptr, nnz, static_cast<DType>(1.));
#if CUDART_VERSION >= 11000 #if CUDART_VERSION >= 11000
auto ctx = rtcfg.ctx;
cusparseSpMatDescr_t matA; cusparseSpMatDescr_t matA;
cusparseDnMatDescr_t matB, matC; cusparseDnMatDescr_t matB, matC;
constexpr auto cuda_dtype = std::is_same<DType, float>::value ? CUDA_R_32F: CUDA_R_64F; constexpr auto cuda_dtype = std::is_same<DType, float>::value ? CUDA_R_32F: CUDA_R_64F;
...@@ -122,7 +123,7 @@ void CusparseCsrmm2( ...@@ -122,7 +123,7 @@ void CusparseCsrmm2(
m, k, nnz, m, k, nnz,
static_cast<int32_t*>(csr.indptr->data), static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data), static_cast<int32_t*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data), const_cast<DType*>(valptr),
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, cuda_dtype)); CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB, CUSPARSE_CALL(cusparseCreateDnMat(&matB,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment