"src/diffusers/pipeline_flax_utils.py" did not exist on "a75846379afc0557352be7df76780c3ac4aa113e"
Unverified Commit 9d90faf0 authored by nv-dlasalle's avatar nv-dlasalle Committed by GitHub
Browse files

Use ALG2 for SpMM in cuSparse (#2550)

parent 2c6d0716
......@@ -122,8 +122,6 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseCreate(&(thr_entry->cusparse_handle)));
}
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, thr_entry->stream));
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
// all one data array
DType* valptr = nullptr;
if (!A_data) {
......@@ -142,25 +140,25 @@ void CusparseCsrmm2(
CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB,
n, k, n,
const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_COL));
k, n, n,
const_cast<DType*>(B_data), cuda_dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC,
m, n, m,
trans_out, cuda_dtype, CUSPARSE_ORDER_COL));
m, n, n,
C_data, cuda_dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
cuda_dtype, CUSPARSE_CSRMM_ALG1,
cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
cuda_dtype, CUSPARSE_CSRMM_ALG1,
cuda_dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
device->FreeWorkspace(ctx, workspace);
......@@ -168,6 +166,9 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseDestroyDnMat(matB));
CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
......@@ -182,9 +183,6 @@ void CusparseCsrmm2(
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
// transpose the output matrix
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
......@@ -198,6 +196,9 @@ void CusparseCsrmm2(
&beta, nullptr, n,
C_data, n));
device->FreeWorkspace(ctx, trans_out);
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
}
} // namespace cusparse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment