/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/csr_mm.cu
 * \brief SpSpMM/SpGEMM C APIs and definitions.
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"

namespace dgl {

using namespace dgl::runtime;

namespace aten {
namespace cusparse {

#if 0  // disabling CUDA 11.0+ implementation for now because of problems on bigger graphs

/*! \brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array,
    const CSRMatrix& B, const NDArray B_weights_array) {
  // We use Spgemm (SpSpMM) to perform the following operation:
  // C = A x B, where A, B and C are sparse matrices in CSR format.
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const DType alpha = 1.0;
  const DType beta = 0.0;
  auto transA = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  auto transB = HIPSPARSE_OPERATION_NON_TRANSPOSE;
  // device
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentCUDAStream();
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  // allocate cusparse handle if needed
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  hipsparseSpMatDescr_t matA, matB, matC;
  IdArray dC_csrOffsets =
      IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
  IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
  constexpr auto idtype = cusparse_idtype<IdType>::value;
  constexpr auto dtype = cuda_dtype<DType>::value;
  // Create sparse matrices A, B and C in CSR format
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matA, A.num_rows, A.num_cols, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      const_cast<DType*>(A_weights),  // hipsparseCreateCsr only accepts non-const pointers
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matB, B.num_rows, B.num_cols, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      const_cast<DType*>(B_weights),  // hipsparseCreateCsr only accepts non-const pointers
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  CUSPARSE_CALL(hipsparseCreateCsr(
      &matC, A.num_rows, B.num_cols, 0,
      nullptr, nullptr, nullptr,
      idtype, idtype, HIPSPARSE_INDEX_BASE_ZERO, dtype));
  // SpGEMM computation
  hipsparseSpGEMMDescr_t spgemmDesc;
  CUSPARSE_CALL(hipsparseSpGEMM_createDescr(&spgemmDesc));
  size_t workspace_size1 = 0, workspace_size2 = 0;
  // ask bufferSize1 bytes for external memory
  CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      NULL));
  void* workspace1 = device->AllocWorkspace(ctx, workspace_size1);
  // inspect the matrices A and B to understand the memory requirement
  // for the next step
  CUSPARSE_CALL(hipsparseSpGEMM_workEstimation(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
      workspace1));
  // ask bufferSize2 bytes for external memory
  CUSPARSE_CALL(hipsparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      NULL));
  void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
  // compute the intermediate product of A * B
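  // (second hipsparseSpGEMM_compute call, now with workspace2 bound; the result
  // stays inside spgemmDesc until hipsparseSpGEMM_copy writes it into our buffers)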
  CUSPARSE_CALL(hipsparseSpGEMM_compute(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
      workspace2));
  // get matrix C non-zero entries C_nnz1
  int64_t C_num_rows1, C_num_cols1, C_nnz1;
  CUSPARSE_CALL(
      hipsparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
  IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
  NDArray dC_weights =
      NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
  IdType* dC_columns_data = dC_columns.Ptr<IdType>();
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  // update matC with the new pointers
  CUSPARSE_CALL(hipsparseCsrSetPointers(
      matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
  // copy the final products to the matrix C
  CUSPARSE_CALL(hipsparseSpGEMM_copy(
      thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
      matC, dtype, HIPSPARSE_SPGEMM_DEFAULT, spgemmDesc));
  device->FreeWorkspace(ctx, workspace1);
  device->FreeWorkspace(ctx, workspace2);
  // destroy matrix/vector descriptors
  CUSPARSE_CALL(hipsparseSpGEMM_destroyDescr(spgemmDesc));
  CUSPARSE_CALL(hipsparseDestroySpMat(matA));
  CUSPARSE_CALL(hipsparseDestroySpMat(matB));
  CUSPARSE_CALL(hipsparseDestroySpMat(matC));
  return {
      CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
                NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
      dC_weights};
}

#else  // __CUDACC_VER_MAJOR__ != 11

/*! \brief Cusparse implementation of SpGEMM on Csr format for older CUDA versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array,
    const CSRMatrix& B, const NDArray B_weights_array) {
  int nnzC;
  csrgemm2Info_t info = nullptr;
  size_t workspace_size;
  const DType alpha = 1.;
  const int nnzA = A.indices->shape[0];
  const int nnzB = B.indices->shape[0];
  const int m = A.num_rows;
  const int n = A.num_cols;
  const int k = B.num_cols;
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  hipStream_t stream = runtime::getCurrentCUDAStream();
  auto idtype = A.indptr->dtype;
  auto dtype = A_weights_array->dtype;
  const DType* A_weights = A_weights_array.Ptr<DType>();
  const DType* B_weights = B_weights_array.Ptr<DType>();
  if (!thr_entry->cusparse_handle) {
    CUSPARSE_CALL(hipsparseCreate(&(thr_entry->cusparse_handle)));
  }
  CUSPARSE_CALL(hipsparseSetStream(thr_entry->cusparse_handle, stream));
  CUSPARSE_CALL(hipsparseSetPointerMode(
      thr_entry->cusparse_handle, HIPSPARSE_POINTER_MODE_HOST));
  CUSPARSE_CALL(hipsparseCreateCsrgemm2Info(&info));

  hipsparseMatDescr_t matA, matB, matC, matD;
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matA));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matB));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matC));
  CUSPARSE_CALL(hipsparseCreateMatDescr(&matD));  // needed even if D is null

  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha,
      matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      nullptr,
      matD, 0, nullptr, nullptr,
      info, &workspace_size));
  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k,
      matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      matD, 0, nullptr, nullptr,
      matC, C_indptr.Ptr<IdType>(), &nnzC,
      info, workspace));
  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
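  // Numeric phase of csrgemm2: the symbolic pass above produced C's row offsets
  // and nnzC, so the output column indices and values can now be filled in.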
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha,
      matA, nnzA, A_weights, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
      matB, nnzB, B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
      nullptr,
      matD, 0, nullptr, nullptr, nullptr,
      matC, C_weights.Ptr<DType>(), C_indptr.Ptr<IdType>(),
      C_indices.Ptr<IdType>(),
      info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(hipsparseDestroyCsrgemm2Info(info));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matA));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matB));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matC));
  CUSPARSE_CALL(hipsparseDestroyMatDescr(matD));

  return {
      CSRMatrix(m, k, C_indptr, C_indices,
                NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}

#endif  // __CUDACC_VER_MAJOR__ == 11

}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights,
    const CSRMatrix& B, NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
  CSRMatrix newA, newB;
  bool cast = false;

  // Cast 64 bit indices to 32 bit.
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
      cast ? newB : B, CSRHasData(B) ? newB_weights : B_weights);

  // Cast 32 bit indices back to 64 bit if necessary
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
}

// Explicit instantiations for the supported index and value types
// (kDGLCUDA device tag assumed, following the upstream DGL csr_mm.cu).
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
    const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);

}  // namespace aten
}  // namespace dgl
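// Usage sketch (illustrative only, not part of the upstream file): given two CSR
// matrices with per-edge float weights already resident on the GPU, the CSRMM
// entry point defined above would be called roughly as follows. The variable
// names lhs, rhs, lhs_weights and rhs_weights are hypothetical placeholders.
//
//   using namespace dgl;
//   using namespace dgl::aten;
//
//   CSRMatrix lhs = ...;            // CSR structure of A (indptr/indices on GPU)
//   CSRMatrix rhs = ...;            // CSR structure of B
//   NDArray lhs_weights = ...;      // float weights of A, shape {nnz(A)}
//   NDArray rhs_weights = ...;      // float weights of B, shape {nnz(B)}
//
//   auto out = CSRMM<kDGLCUDA, int64_t, float>(lhs, lhs_weights, rhs, rhs_weights);
//   CSRMatrix C = out.first;         // sparsity structure of A x B
//   NDArray C_weights = out.second;  // per-nonzero values of A x B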