Unverified Commit fb223d47 authored by Chang Liu's avatar Chang Liu Committed by GitHub
Browse files

Add support for next cusparse release (#4974)

* Add support for next cusparse release

* Fix lint

* Add switch and tune the performance

* Fix lint issue

* Fine tune the heuristics

* Fix lint issue

* Address comments

* Minor fix

* Address comments
parent 97fbd94d
......@@ -6,10 +6,11 @@
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include <limits>
#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"
namespace dgl {
using namespace dgl::runtime;
......@@ -17,10 +18,9 @@ using namespace dgl::runtime;
namespace aten {
namespace cusparse {
#if 0 // disabling CUDA 11.0+ implementation for now because of problems on
// bigger graphs
#if CUDART_VERSION >= 12000
/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 12.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
......@@ -54,14 +54,14 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
constexpr auto dtype = cuda_dtype<DType>::value;
// Create sparse matrix A, B and C in CSR format
CUSPARSE_CALL(cusparseCreateCsr(
&matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<DType>(),
A.indices.Ptr<DType>(),
&matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(A_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(
&matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<DType>(),
B.indices.Ptr<DType>(),
&matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(B_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
......@@ -70,30 +70,110 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
// SpGEMM Computation
cusparseSpGEMMDescr_t spgemmDesc;
cusparseSpGEMMAlg_t alg = CUSPARSE_SPGEMM_DEFAULT;
CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
size_t workspace_size1 = 0, workspace_size2 = 0;
size_t workspace_size1 = 0, workspace_size2 = 0, workspace_size3 = 0;
// ask bufferSize1 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
matC, dtype, alg, spgemmDesc, &workspace_size1,
NULL));
void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
// inspect the matrices A and B to understand the memory requirement
// for the next step
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
cusparseStatus_t e =
cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle, transA,
transB, &alpha, matA, matB, &beta,
matC, dtype, alg, spgemmDesc,
&workspace_size1, workspace1);
// CUSPARSE_SPGEMM_DEFAULT does not support getting num_prods > 2^31 - 1
// and throws an insufficient-memory error within the workEstimation call
if (e == CUSPARSE_STATUS_INSUFFICIENT_RESOURCES) {
// fall back to ALG2 to estimate num_prods
alg = CUSPARSE_SPGEMM_ALG2;
device->FreeWorkspace(ctx, workspace1);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
NULL));
workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
workspace1));
// ask bufferSize2 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
} else {
CHECK(e == CUSPARSE_STATUS_SUCCESS) << "CUSPARSE ERROR in SpGEMM: " << e;
}
// get the number of intermediate products required for SpGEMM compute
// num_prods indicates device memory consumption for SpGEMM if using ALG2/3
int64_t num_prods;
CUSPARSE_CALL(cusparseSpGEMM_getNumProducts(spgemmDesc, &num_prods));
// assume free GPU mem at least ~15G for below heuristics to work
// user-defined medium problem size (below will use DEFAULT)
int64_t MEDIUM_NUM_PRODUCTS = 400000000; // 400*1000*1000;
// user-defined large problem size (above will use ALG3)
int64_t LARGE_NUM_PRODUCTS = 800000000; // 800*1000*1000;
// switch to ALG2/ALG3 for medium & large problem size
if (alg == CUSPARSE_SPGEMM_DEFAULT && num_prods > MEDIUM_NUM_PRODUCTS) {
// use ALG3 for very large problem
alg = num_prods > LARGE_NUM_PRODUCTS ? CUSPARSE_SPGEMM_ALG3 :
CUSPARSE_SPGEMM_ALG2;
device->FreeWorkspace(ctx, workspace1);
// rerun cusparseSpGEMM_workEstimation
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
NULL));
workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size1,
workspace1));
} else if (alg == CUSPARSE_SPGEMM_ALG2 && num_prods > LARGE_NUM_PRODUCTS) {
// no need to rerun cusparseSpGEMM_workEstimation between ALG2 and ALG3
alg = CUSPARSE_SPGEMM_ALG3;
}
if (alg == CUSPARSE_SPGEMM_ALG2 || alg == CUSPARSE_SPGEMM_ALG3) {
// estimate memory for ALG2/ALG3; note chunk_fraction is only used by ALG3
// reduce chunk_fraction if crash due to mem., but it trades off speed
float chunk_fraction = num_prods < 4 * LARGE_NUM_PRODUCTS ? 0.15 : 0.05;
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, chunk_fraction,
&workspace_size3,
NULL, NULL));
void* workspace3 = (device->AllocWorkspace(ctx, workspace_size3));
CUSPARSE_CALL(cusparseSpGEMM_estimateMemory(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, chunk_fraction,
&workspace_size3,
workspace3, &workspace_size2));
device->FreeWorkspace(ctx, workspace3);
} else {
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA,
matB, &beta, matC, dtype, alg,
spgemmDesc, &workspace_size2,
NULL));
}
// ask bufferSize2 bytes for external memory
void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
// compute the intermediate product of A * B
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
matC, dtype, alg, spgemmDesc, &workspace_size2,
workspace2));
// get matrix C non-zero entries C_nnz1
int64_t C_num_rows1, C_num_cols1, C_nnz1;
......@@ -110,7 +190,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
// copy the final products to the matrix C
CUSPARSE_CALL(cusparseSpGEMM_copy(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc));
matC, dtype, alg, spgemmDesc));
device->FreeWorkspace(ctx, workspace1);
device->FreeWorkspace(ctx, workspace2);
......@@ -126,7 +206,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
dC_weights};
}
#else // __CUDACC_VER_MAJOR__ != 11
#else // CUDART_VERSION < 12000
/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
* versions */
......@@ -202,7 +282,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
C_weights};
}
#endif // __CUDACC_VER_MAJOR__ == 11
#endif // CUDART_VERSION >= 12000
} // namespace cusparse
template <int XPU, typename IdType, typename DType>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment