"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "365313edd2658e5d048d97fad04c7729deb9815b"
Unverified Commit 88964a82 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

[Bugfix] Fix cusparseCreateCsr format for cuda12 (#6121)

parent 1e16e4ca
......@@ -57,17 +57,17 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
&matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(A_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
const_cast<DType*>(A_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,
dtype));
CUSPARSE_CALL(cusparseCreateCsr(
&matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(B_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
const_cast<DType*>(B_weights), idtype, idtype, CUSPARSE_INDEX_BASE_ZERO,
dtype));
CUSPARSE_CALL(cusparseCreateCsr(
&matC, A.num_rows, B.num_cols, 0, nullptr, nullptr, nullptr, idtype,
idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
&matC, A.num_rows, B.num_cols, 0, dC_csrOffsets_data, nullptr, nullptr,
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
// SpGEMM Computation
cusparseSpGEMMDescr_t spgemmDesc;
cusparseSpGEMMAlg_t alg = CUSPARSE_SPGEMM_DEFAULT;
......@@ -77,15 +77,12 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// ask bufferSize1 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1,
NULL));
matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
// inspect the matrices A and B to understand the memory requirement
cusparseStatus_t e =
cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle, transA,
transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc,
&workspace_size1, workspace1);
cusparseStatus_t e = cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1);
// CUSPARSE_SPGEMM_DEFAULT does not support num_prods > 2^31 - 1; in that
// case the workEstimation call reports an insufficient-resources error
if (e == CUSPARSE_STATUS_INSUFFICIENT_RESOURCES) {
......@@ -93,17 +90,13 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
alg = CUSPARSE_SPGEMM_ALG2;
device->FreeWorkspace(ctx, workspace1);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
NULL));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
workspace1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
} else {
CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR in SpGEMM: " << e;
}
......@@ -117,27 +110,23 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// user-defined medium problem size (below will use DEFAULT)
int64_t MEDIUM_NUM_PRODUCTS = 400000000; // 400*1000*1000;
// user-defined large problem size (above will use ALG3)
int64_t LARGE_NUM_PRODUCTS = 800000000; // 800*1000*1000;
int64_t LARGE_NUM_PRODUCTS = 800000000; // 800*1000*1000;
// switch to ALG2/ALG3 for medium & large problem size
if (alg == CUSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {
// use ALG3 for very large problem
alg = num_prods > LARGE_NUM_PRODUCTS ? CUSPARSE_SPGEMM_ALG3 :
CUSPARSE_SPGEMM_ALG2;
alg = num_prods > LARGE_NUM_PRODUCTS ? CUSPARSE_SPGEMM_ALG3
: CUSPARSE_SPGEMM_ALG2;
device->FreeWorkspace(ctx, workspace1);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
NULL));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1, NULL));
workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
workspace1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size1, workspace1));
} else if (alg == CUSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {
// no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3
alg = CUSPARSE_SPGEMM_ALG3;
......@@ -147,41 +136,34 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
// reduce chunk_fraction if crash due to mem., but it trades off speed
float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 0.15 : 0.05;
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, chunk_fraction,
&workspace_size3,
NULL, NULL));
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3, NULL,
NULL));
void* workspace3 = (device->AllocWorkspace(ctx, workspace_size3));
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, chunk_fraction,
&workspace_size3,
workspace3, &workspace_size2));
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, chunk_fraction, &workspace_size3,
workspace3, &workspace_size2));
device->FreeWorkspace(ctx, workspace3);
} else {
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size2,
NULL));
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size2, NULL));
}
// ask bufferSize2 bytes for external memory
void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
// compute the intermediate product of A * B
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc, &workspace_size2,
workspace2));
matC, dtype, alg, spgemmDesc, &workspace_size2, workspace2));
// get matrix C non-zero entries C_nnz1
int64_t C_num_rows1, C_num_cols1, C_nnz1;
CUSPARSE_CALL(
cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
NDArray dC_weights = NDArray::Empty(
{C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
NDArray dC_weights =
NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
IdType* dC_columns_data = dC_columns.Ptr<IdType>();
DType* dC_weights_data = dC_weights.Ptr<DType>();
// update matC with the new pointers
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment