/*! * Copyright (c) 2020 by Contributors * \file array/cuda/dispatcher.cuh * \brief Templates to dispatch into different cuSPARSE routines based on the type * argument. */ #ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_ #define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_ #include #include namespace dgl { namespace aten { /*! \brief cusparseXcsrgemm dispatcher */ template struct CSRGEMM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { BUG_IF_FAIL(false) << "This piece of code should not be reached."; return 0; } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgemm2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { BUG_IF_FAIL(false) << "This piece of code should not be reached."; return 0; } }; template <> struct CSRGEMM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { return cusparseScsrgemm2_bufferSizeExt(args...); } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgemm2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { return cusparseScsrgemm2(args...); } }; template <> struct CSRGEMM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { return cusparseDcsrgemm2_bufferSizeExt(args...); } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgemm2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { return cusparseDcsrgemm2(args...); } }; /*! \brief cusparseXcsrgeam dispatcher */ template struct CSRGEAM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { BUG_IF_FAIL(false) << "This piece of code should not be reached."; return 0; } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgeam2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { BUG_IF_FAIL(false) << "This piece of code should not be reached."; return 0; } }; template <> struct CSRGEAM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { return cusparseScsrgeam2_bufferSizeExt(args...); } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgeam2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { return cusparseScsrgeam2(args...); } }; template <> struct CSRGEAM { template static inline cusparseStatus_t bufferSizeExt(Args... args) { return cusparseDcsrgeam2_bufferSizeExt(args...); } template static inline cusparseStatus_t nnz(Args... args) { return cusparseXcsrgeam2Nnz(args...); } template static inline cusparseStatus_t compute(Args... args) { return cusparseDcsrgeam2(args...); } }; }; // namespace aten }; // namespace dgl #endif // DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_