Unverified commit 9d90faf0, authored by nv-dlasalle, committed by GitHub
Browse files

Use ALG2 for SpMM in cuSparse (#2550)

parent 2c6d0716
...@@ -122,8 +122,6 @@ void CusparseCsrmm2( ...@@ -122,8 +122,6 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle))); CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
} }
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream)); CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
// all one data array // all one data array
DType* valptr = nullptr; DType* valptr = nullptr;
if (!A_data) { if (!A_data) {
...@@ -142,25 +140,25 @@ void CusparseCsrmm2( ...@@ -142,25 +140,25 @@ void CusparseCsrmm2(
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, cuda_dtype)); CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB, CUSPARSE_CALL(cusparseCreateDnMat(&matB,
n, k, n, k, n, n,
const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_COL)); const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC, CUSPARSE_CALL(cusparseCreateDnMat(&matC,
m, n, m, m, n, n,
trans_out, cuda_dtype, CUSPARSE_ORDER_COL)); C_data, cuda_dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE; auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_TRANSPOSE; auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size; size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize( CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, &alpha, matA, matB, &beta, matC,
cuda_dtype, CUSPARSE_CSRMM_ALG1, cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size)); &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM( CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, &alpha, matA, matB, &beta, matC,
cuda_dtype, CUSPARSE_CSRMM_ALG1, cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace)); workspace));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
...@@ -168,6 +166,9 @@ void CusparseCsrmm2( ...@@ -168,6 +166,9 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseDestroyDnMat(matB)); CUSPARSE_CALL(cusparseDestroyDnMat(matB));
CUSPARSE_CALL(cusparseDestroyDnMat(matC)); CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else #else
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
cusparseMatDescr_t descr; cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr)); CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
...@@ -182,9 +183,6 @@ void CusparseCsrmm2( ...@@ -182,9 +183,6 @@ void CusparseCsrmm2(
static_cast<int32_t*>(csr.indices->data), static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m)); B_data, n, &beta, trans_out, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr)); CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
// transpose the output matrix // transpose the output matrix
if (!thr_entry->cublas_handle) if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle))); CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
...@@ -198,6 +196,9 @@ void CusparseCsrmm2( ...@@ -198,6 +196,9 @@ void CusparseCsrmm2(
&beta, nullptr, n, &beta, nullptr, n,
C_data, n)); C_data, n));
device->FreeWorkspace(ctx, trans_out); device->FreeWorkspace(ctx, trans_out);
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
} }
} // namespace cusparse } // namespace cusparse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment