Unverified Commit a4e19691 authored by Triston, committed by GitHub

[Determinism] Enable environment var to use cusparse spmm deterministic algorithm (#7310)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent a3d20dce
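The patch reads the flag with std::getenv at dispatch time, so determinism is opt-in per process, and only the variable's presence is checked, not its value. A minimal sketch of turning it on from C++ before any SpMM call is issued (setenv is POSIX and is not part of this patch; the variable can equally be exported from the shell or a job script):

#include <cstdlib>

int main() {
  // Any value works: the dispatcher only tests whether the variable is set.
  setenv("USE_DETERMINISTIC_ALG", "1", /*overwrite=*/1);
  // ... construct graphs and run DGL SpMM as usual; the CUDA dispatcher will
  // now request the deterministic cuSPARSE SpMM algorithm.
  return 0;
}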
@@ -5,6 +5,8 @@
  */
 #include <dgl/array.h>
+#include <cstdlib>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./functor.cuh"
 #include "./ge_spmm.cuh"
@@ -28,6 +30,9 @@ void SpMMCsr(
     std::vector<NDArray> out_aux) {
   bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
   bool use_efeat = op != "copy_lhs";
+  bool use_deterministic_alg_only = false;
+  if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
+    use_deterministic_alg_only = true;
   if (reduce == "sum") {
     bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
@@ -37,7 +42,7 @@ void SpMMCsr(
     for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
     CusparseCsrmm2<DType, IdType>(
         ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
-        static_cast<DType*>(out->data), x_length);
+        static_cast<DType*>(out->data), x_length, use_deterministic_alg_only);
   } else if (
       op == "mul" && is_scalar_efeat &&
       cusparse_available<DType, IdType>(more_nnz)) {
@@ -50,7 +55,7 @@ void SpMMCsr(
     CusparseCsrmm2<DType, IdType>(
         ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
         static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
-        x_length);
+        x_length, use_deterministic_alg_only);
   } else {  // general kernel
     SWITCH_OP(op, Op, {
       cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
...
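For context, and not part of the patch: the default cuSPARSE SpMM path may accumulate partial products in a run-dependent order (for example via atomics), and floating-point addition is not associative, so bit-wise results can differ from run to run. A tiny standalone illustration of the non-associativity that makes accumulation order matter:

#include <cstdio>

int main() {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  // The two groupings differ because the small addend is absorbed when it is
  // added to the large-magnitude term first.
  printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints "1 vs 0"
  return 0;
}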
@@ -196,7 +196,8 @@ cusparseStatus_t Xcsrmm2<double>(
 template <typename DType, typename IdType>
 void CusparseCsrmm2(
     const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
-    const DType* A_data, DType* C_data, int x_length) {
+    const DType* A_data, DType* C_data, int x_length,
+    bool use_deterministic_alg_only = false) {
   // We use csrmm2 to perform following operation:
   // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
   // for node feature tensor. However, since cusparse only supports
@@ -244,13 +245,16 @@ void CusparseCsrmm2(
   auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
   auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
   size_t workspace_size;
+  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only
+                                   ? CUSPARSE_SPMM_CSR_ALG3
+                                   : CUSPARSE_SPMM_CSR_ALG2;
   CUSPARSE_CALL(cusparseSpMM_bufferSize(
       thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
-      matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
+      matC, dtype, spmm_alg, &workspace_size));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   CUSPARSE_CALL(cusparseSpMM(
       thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
-      matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
+      matC, dtype, spmm_alg, workspace));
   device->FreeWorkspace(ctx, workspace);
   CUSPARSE_CALL(cusparseDestroySpMat(matA));
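Both cusparseSpMM_bufferSize and cusparseSpMM must be given the same algorithm, which is why the choice is hoisted into spmm_alg before the two calls. Per the cuSPARSE documentation, CUSPARSE_SPMM_CSR_ALG3 provides bit-wise reproducible results across runs, generally at some performance cost relative to CUSPARSE_SPMM_CSR_ALG2. A standalone sketch of the same selection as a helper (the name ToSpmmAlg is hypothetical and not part of the patch):

#include <cusparse.h>

// Map the environment-driven flag onto a cuSPARSE SpMM algorithm:
// deterministic ALG3 when requested, ALG2 otherwise.
inline cusparseSpMMAlg_t ToSpmmAlg(bool use_deterministic_alg_only) {
  return use_deterministic_alg_only ? CUSPARSE_SPMM_CSR_ALG3
                                    : CUSPARSE_SPMM_CSR_ALG2;
}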
@@ -283,8 +287,8 @@ void CusparseCsrmm2(
 template <typename DType, typename IdType>
 void CusparseCsrmm2Hetero(
     const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
-    const DType* A_data, DType* C_data, int64_t x_length,
-    cudaStream_t strm_id) {
+    const DType* A_data, DType* C_data, int64_t x_length, cudaStream_t strm_id,
+    bool use_deterministic_alg_only = false) {
   // We use csrmm2 to perform following operation:
   // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
   // for node feature tensor. However, since cusparse only supports
@@ -335,13 +339,16 @@ void CusparseCsrmm2Hetero(
   auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
   auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
   size_t workspace_size;
+  cusparseSpMMAlg_t spmm_alg = use_deterministic_alg_only
+                                   ? CUSPARSE_SPMM_CSR_ALG3
+                                   : CUSPARSE_SPMM_CSR_ALG2;
   CUSPARSE_CALL(cusparseSpMM_bufferSize(
       thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
-      matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
+      matC, dtype, spmm_alg, &workspace_size));
   void* workspace = device->AllocWorkspace(ctx, workspace_size);
   CUSPARSE_CALL(cusparseSpMM(
       thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
-      matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
+      matC, dtype, spmm_alg, workspace));
   device->FreeWorkspace(ctx, workspace);
   CUSPARSE_CALL(cusparseDestroySpMat(matA));
@@ -562,8 +569,8 @@ __global__ void SpMMCmpCsrHeteroKernel(
   int tx = blockIdx.x * blockDim.x + threadIdx.x;
   while (tx < out_len) {
     using accum_type = typename accum_dtype<DType>::type;
-    accum_type local_accum = static_cast<accum_type>(
-        out[ty * out_len + tx]);  // ReduceOp::zero();
+    accum_type local_accum =
+        static_cast<accum_type>(out[ty * out_len + tx]);  // ReduceOp::zero();
     Idx local_argu = 0, local_arge = 0;
     const int lhs_add = UseBcast ? ubcast_off[tx] : tx;
     const int rhs_add = UseBcast ? ebcast_off[tx] : tx;
@@ -620,7 +627,7 @@ void SpMMCoo(
     NDArray out, NDArray argu, NDArray arge) {
   /**
    * TODO(Xin): Disable half precision for SpMMCoo due to the round-off error.
    * We should use fp32 for the accumulation but it's hard to modify the
    * current implementation.
    */
 #if BF16_ENABLED
...
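One caveat worth noting: CUSPARSE_SPMM_CSR_ALG3 is a relatively recent addition to cuSPARSE, so builds against older toolkits would need a fallback. A sketch of a compile-time guard, with the version threshold stated purely as an assumption to be checked against the cuSPARSE changelog (the macro name is hypothetical and not part of the patch):

#include <cusparse.h>

// Assumption: ALG3 is available from roughly cuSPARSE 11.3; verify the exact
// cutoff before relying on this. Older builds fall back to ALG2.
#if defined(CUSPARSE_VERSION) && CUSPARSE_VERSION >= 11300
#define SPMM_DETERMINISTIC_ALG CUSPARSE_SPMM_CSR_ALG3
#else
#define SPMM_DETERMINISTIC_ALG CUSPARSE_SPMM_CSR_ALG2
#endif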
@@ -5,6 +5,8 @@
  */
 #include <dgl/array.h>
+#include <cstdlib>
 #include "../../runtime/cuda/cuda_common.h"
 #include "./functor.cuh"
 #include "./ge_spmm.cuh"
@@ -35,6 +37,9 @@ void SpMMCsrHetero(
   bool use_efeat = op != "copy_lhs";
   auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
   std::vector<DType*> trans_out((*vec_out).size(), NULL);
+  bool use_deterministic_alg_only = false;
+  if (NULL != std::getenv("USE_DETERMINISTIC_ALG"))
+    use_deterministic_alg_only = true;
   bool use_legacy_cusparsemm =
       (CUDART_VERSION < 11000) && (reduce == "sum") &&
@@ -128,19 +133,19 @@ void SpMMCsrHetero(
               : static_cast<DType*>((*vec_out)[dst_id]->data);
       CusparseCsrmm2Hetero<DType, IdType>(
           csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
-          nullptr, out, x_length, stream);
+          nullptr, out, x_length, stream, use_deterministic_alg_only);
     } else if (
         op == "mul" && is_scalar_efeat &&
         cusparse_available<DType, IdType>(more_nnz)) {  // cusparse
       NDArray efeat = vec_efeat[etype];
-      if (!IsNullArray(csr.data))
-        efeat = IndexSelect(efeat, csr.data);
+      if (!IsNullArray(csr.data)) efeat = IndexSelect(efeat, csr.data);
       CusparseCsrmm2Hetero<DType, IdType>(
           csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
           static_cast<DType*>(efeat->data),
           // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
           // version < 11
-          static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream);
+          static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream,
+          use_deterministic_alg_only);
     } else {  // general kernel
       NDArray ufeat =
           (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
...
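Both SpMMCsr and SpMMCsrHetero repeat the same presence-only check of USE_DETERMINISTIC_ALG. If the check ever grows (for example to accept explicit 0/1 values), it could be factored into a shared helper; a sketch, with a hypothetical name that does not appear in the patch:

#include <cstdlib>

// True when the user exported USE_DETERMINISTIC_ALG (any value), mirroring
// the two std::getenv checks added by this commit.
inline bool UseDeterministicAlgOnly() {
  return std::getenv("USE_DETERMINISTIC_ALG") != nullptr;
}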