Unverified Commit cb5e3489 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Minor code style fix. (#4843)



* [Misc] Change the max line length for cpp to 80 in lint.

* blabla

* blabla

* blabla

* ablabla
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 4cd0a685
......@@ -212,7 +212,7 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
///////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
......
......@@ -23,9 +23,7 @@ namespace cusparse {
/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
const NDArray B_weights_array) {
// We use Spgemm (SpSpMM) to perform following operation:
// C = A x B, where A, B and C are sparse matrices in csr format.
......@@ -49,70 +47,70 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseSetStream(thr_entry->cusparse_handle, stream));
// all one data array
cusparseSpMatDescr_t matA, matB, matC;
IdArray dC_csrOffsets = IdArray::Empty({A.num_rows+1}, A.indptr->dtype, A.indptr->ctx);
IdArray dC_csrOffsets =
IdArray::Empty({A.num_rows + 1}, A.indptr->dtype, A.indptr->ctx);
IdType* dC_csrOffsets_data = dC_csrOffsets.Ptr<IdType>();
constexpr auto idtype = cusparse_idtype<IdType>::value;
constexpr auto dtype = cuda_dtype<DType>::value;
// Create sparse matrix A, B and C in CSR format
CUSPARSE_CALL(cusparseCreateCsr(&matA,
A.num_rows, A.num_cols, nnzA,
A.indptr.Ptr<DType>(),
CUSPARSE_CALL(cusparseCreateCsr(
&matA, A.num_rows, A.num_cols, nnzA, A.indptr.Ptr<DType>(),
A.indices.Ptr<DType>(),
const_cast<DType*>(A_weights), // cusparseCreateCsr only accepts non-const pointers
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(A_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(&matB,
B.num_rows, B.num_cols, nnzB,
B.indptr.Ptr<DType>(),
CUSPARSE_CALL(cusparseCreateCsr(
&matB, B.num_rows, B.num_cols, nnzB, B.indptr.Ptr<DType>(),
B.indices.Ptr<DType>(),
const_cast<DType*>(B_weights), // cusparseCreateCsr only accepts non-const pointers
// cusparseCreateCsr only accepts non-const pointers.
const_cast<DType*>(B_weights),
idtype, idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(&matC,
A.num_rows, B.num_cols, 0,
nullptr, nullptr, nullptr, idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateCsr(
&matC, A.num_rows, B.num_cols, 0, nullptr, nullptr, nullptr, idtype,
idtype, CUSPARSE_INDEX_BASE_ZERO, dtype));
// SpGEMM Computation
cusparseSpGEMMDescr_t spgemmDesc;
CUSPARSE_CALL(cusparseSpGEMM_createDescr(&spgemmDesc));
size_t workspace_size1 = 0, workspace_size2 = 0;
// ask bufferSize1 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, dtype,
CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
NULL));
void* workspace1 = (device->AllocWorkspace(ctx, workspace_size1));
// inspect the matrices A and B to understand the memory requiremnent
// for the next step
CUSPARSE_CALL(cusparseSpGEMM_workEstimation(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC, dtype,
CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size1,
workspace1));
// ask bufferSize2 bytes for external memory
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
NULL));
void* workspace2 = device->AllocWorkspace(ctx, workspace_size2);
// compute the intermediate product of A * B
CUSPARSE_CALL(cusparseSpGEMM_compute(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
CUSPARSE_CALL(cusparseSpGEMM_compute(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &workspace_size2,
workspace2));
// get matrix C non-zero entries C_nnz1
int64_t C_num_rows1, C_num_cols1, C_nnz1;
CUSPARSE_CALL(cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
CUSPARSE_CALL(
cusparseSpMatGetSize(matC, &C_num_rows1, &C_num_cols1, &C_nnz1));
IdArray dC_columns = IdArray::Empty({C_nnz1}, A.indptr->dtype, A.indptr->ctx);
NDArray dC_weights = NDArray::Empty({C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
NDArray dC_weights = NDArray::Empty(
{C_nnz1}, A_weights_array->dtype, A.indptr->ctx);
IdType* dC_columns_data = dC_columns.Ptr<IdType>();
DType* dC_weights_data = dC_weights.Ptr<DType>();
// update matC with the new pointers
CUSPARSE_CALL(cusparseCsrSetPointers(matC, dC_csrOffsets_data,
dC_columns_data, dC_weights_data));
CUSPARSE_CALL(cusparseCsrSetPointers(
matC, dC_csrOffsets_data, dC_columns_data, dC_weights_data));
// copy the final products to the matrix C
CUSPARSE_CALL(cusparseSpGEMM_copy(thr_entry->cusparse_handle,
transA, transB, &alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc));
CUSPARSE_CALL(cusparseSpGEMM_copy(
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc));
device->FreeWorkspace(ctx, workspace1);
device->FreeWorkspace(ctx, workspace2);
......@@ -122,7 +120,8 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseDestroySpMat(matB));
CUSPARSE_CALL(cusparseDestroySpMat(matC));
return {
CSRMatrix(A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
CSRMatrix(
A.num_rows, B.num_cols, dC_csrOffsets, dC_columns,
NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx)),
dC_weights};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment