Unverified commit 8ac27dad, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
@@ -7,9 +7,11 @@
#define DGL_ARRAY_CUDA_ATOMIC_CUH_

#include <cuda_runtime.h>

#include <cassert>

#include "bf16.cuh"
#include "fp16.cuh"

#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
@@ -20,22 +22,27 @@ namespace aten {
namespace cuda {

// Type trait for selecting code type
template <int Bytes>
struct Code {};

template <>
struct Code<2> {
  typedef unsigned short int Type;  // NOLINT
};

template <>
struct Code<4> {
  typedef unsigned int Type;  // NOLINT
};

template <>
struct Code<8> {
  typedef unsigned long long int Type;  // NOLINT
};

// Helper class for converting to/from atomicCAS compatible types.
template <typename T>
struct Cast {
  typedef typename Code<sizeof(T)>::Type Type;
  static __device__ __forceinline__ Type Encode(T val) {
    return static_cast<Type>(val);
@@ -45,7 +52,8 @@ template <typename T> struct Cast {
  }
};

template <>
struct Cast<half> {
  typedef Code<sizeof(half)>::Type Type;
  static __device__ __forceinline__ Type Encode(half val) {
    return __half_as_ushort(val);
@@ -56,13 +64,15 @@ template <> struct Cast<half> {
};

#if BF16_ENABLED
template <>
struct Cast<__nv_bfloat16> {
  typedef Code<sizeof(__nv_bfloat16)>::Type Type;
  static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __bfloat16_as_ushort(val);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<Type>(0);
@@ -72,7 +82,8 @@ template <> struct Cast<__nv_bfloat16> {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    return __ushort_as_bfloat16(code);
#else
    printf(
        "Atomic operations are not supported for bfloat16 (BF16) "
        "on GPUs with compute capability less than 8.0.\n");
    __trap();
    return static_cast<__nv_bfloat16>(0.0f);
@@ -81,7 +92,8 @@ template <> struct Cast<__nv_bfloat16> {
};
#endif  // BF16_ENABLED

template <>
struct Cast<float> {
  typedef Code<sizeof(float)>::Type Type;
  static __device__ __forceinline__ Type Encode(float val) {
    return __float_as_uint(val);
@@ -91,7 +103,8 @@ template <> struct Cast<float> {
  }
};

template <>
struct Cast<double> {
  typedef Code<sizeof(double)>::Type Type;
  static __device__ __forceinline__ Type Encode(double val) {
    return __double_as_longlong(val);
@@ -102,7 +115,7 @@ template <> struct Cast<double> {
};

static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
    unsigned short int* address,  // NOLINT
    unsigned short int compare,   // NOLINT
    unsigned short int val) {     // NOLINT
  static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
@@ -112,41 +125,45 @@ static __device__ __forceinline__ unsigned short int atomicCASshort(  // NOLINT
  (void)address;
  (void)compare;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
#endif  // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}

#define DEFINE_ATOMIC(NAME)                                    \
  template <typename T>                                        \
  __device__ __forceinline__ T Atomic##NAME(T* addr, T val) {  \
    typedef typename Cast<T>::Type CT;                         \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);              \
    CT old = *addr_as_ui;                                      \
    CT assumed = old;                                          \
    do {                                                       \
      assumed = old;                                           \
      old = atomicCAS(                                         \
          addr_as_ui, assumed,                                 \
          Cast<T>::Encode(OP(val, Cast<T>::Decode(old))));     \
    } while (assumed != old);                                  \
    return Cast<T>::Decode(old);                               \
  }

#define DEFINE_ATOMIC_16BIT(NAME, dtype)                             \
  template <>                                                        \
  __device__ __forceinline__ dtype Atomic##NAME<dtype>(              \
      dtype * addr, dtype val) {                                     \
    typedef uint16_t CT;                                             \
    CT* addr_as_ui = reinterpret_cast<CT*>(addr);                    \
    CT old = *addr_as_ui;                                            \
    CT assumed = old;                                                \
    do {                                                             \
      assumed = old;                                                 \
      old = atomicCASshort(                                          \
          addr_as_ui, assumed,                                       \
          Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old))));   \
    } while (assumed != old);                                        \
    return Cast<dtype>::Decode(old);                                 \
  }

#define OP(a, b) max(a, b)
@@ -169,84 +186,72 @@ DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
DEFINE_ATOMIC(Add)
#undef OP
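For readers new to this pattern: a hand-expanded sketch of what DEFINE_ATOMIC(Max) generates for T = float, where OP(a, b) is max(a, b). The function name is hypothetical and the snippet is illustrative only, not part of this commit.

// Sketch: DEFINE_ATOMIC(Max) hand-expanded for T = float. Illustration only.
__device__ __forceinline__ float AtomicMaxSketch(float* addr, float val) {
  typedef Cast<float>::Type CT;  // unsigned int, via Code<4>
  CT* addr_as_ui = reinterpret_cast<CT*>(addr);
  CT old = *addr_as_ui;
  CT assumed = old;
  do {
    assumed = old;
    // Encode max(val, current value) and try to swap it in atomically;
    // if another thread won the race, reload and retry.
    old = atomicCAS(
        addr_as_ui, assumed,
        Cast<float>::Encode(max(val, Cast<float>::Decode(old))));
  } while (assumed != old);
  return Cast<float>::Decode(old);  // previous value, like CUDA's atomics
}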
/**
 * @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}
/**
 * @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * @param address The address to perform the atomic operation on.
 * @param compare The value to compare to.
 * @param val The new value to conditionally store.
 *
 * @return The old value at the address.
 */
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicCAS(
      reinterpret_cast<Type*>(address), static_cast<Type>(compare),
      static_cast<Type>(val));
}

inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int;  // NOLINT
  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
  return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}

template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
@@ -259,8 +264,8 @@ __device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif  // __CUDA_ARCH__
@@ -278,8 +283,8 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
  CT assumed = old;
  do {
    assumed = old;
    old = atomicCAS(
        addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
  } while (assumed != old);
  return Cast<T>::Decode(old);
#endif
@@ -294,7 +299,8 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for half precision (FP16) "
      "on this GPU.\n");
  __trap();
  return val;
@@ -304,15 +310,16 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16
AtomicAdd<__nv_bfloat16>(__nv_bfloat16* addr, __nv_bfloat16 val) {
  // make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  return atomicAdd(addr, val);
#else
  (void)addr;
  (void)val;
  printf(
      "Atomic operations are not supported for bfloat16 (BF16) "
      "on GPUs with compute capability less than 8.0.\n");
  __trap();
  return val;
@@ -320,7 +327,6 @@ __device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
}
#endif  // BF16_ENABLED

}  // namespace cuda
}  // namespace aten
}  // namespace dgl
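A minimal usage sketch of the templated atomics in this header (hypothetical kernel, not part of the commit):

__global__ void ScatterAddSketch(
    const float* src, const int64_t* index, float* dst, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Several entries may target the same destination slot, so use AtomicAdd.
    dgl::aten::cuda::AtomicAdd(dst + index[i], src[i]);
  }
}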
......
@@ -4,9 +4,11 @@
 * @brief Retrieve entries of a CSR matrix
 */
#include <dgl/array.h>

#include <numeric>
#include <unordered_set>
#include <vector>

#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
@@ -19,56 +21,57 @@ namespace impl {

template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, DType filler) {
  const int64_t rowlen = rows->shape[0];
  const int64_t collen = cols->shape[0];
  CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
      << "Invalid row and col id array.";
  const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
  const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
  const int64_t rstlen = std::max(rowlen, collen);
  IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
  if (rstlen == 0) return rst;
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const int nt = cuda::FindNumThreads(rstlen);
  const int nb = (rstlen + nt - 1) / nt;
  if (return_eids)
    BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
        << "DType does not match row's dtype.";
  const IdType* indptr_data = csr.indptr.Ptr<IdType>();
  const IdType* indices_data = csr.indices.Ptr<IdType>();
  const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
  if (csr.is_pinned) {
    CUDA_CALL(
        cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
    CUDA_CALL(
        cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
    if (CSRHasData(csr)) {
      CUDA_CALL(
          cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
    }
  }
  // TODO(minjie): use binary search for sorted csr
  CUDA_KERNEL_CALL(
      cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,
      data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,
      rstlen, return_eids ? nullptr : weights.Ptr<DType>(), filler,
      rst.Ptr<DType>());
  return rst;
}
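The nt/nb arithmetic above is the usual ceiling-division launch configuration, one thread per output element. A standalone sketch with a hypothetical kernel:

__global__ void FillSketch(float* out, int64_t n, float value) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = value;  // guard handles the partially filled last block
}

// const int nt = 256;                // threads per block
// const int nb = (n + nt - 1) / nt;  // blocks, rounded up: n = 1000 -> nb = 4
// FillSketch<<<nb, nt, 0, stream>>>(out, n, 0.f);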
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, __half filler);
#if BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
@@ -78,19 +81,25 @@ template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
    NDArray weights, __nv_bfloat16 filler);
#endif  // BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, double filler);

// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int32_t filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
    CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
    NDArray weights, int64_t filler);

}  // namespace impl
}  // namespace aten
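An illustrative call site for the instantiations above (hypothetical variables; assumes csr, rows, cols, and weights already live on the same GPU):

// Fetch weights[eid] for every (rows[i], cols[i]) pair; pairs that are not
// edges of `csr` receive the filler value -1.
NDArray vals = dgl::aten::impl::CSRGetData<kDGLCUDA, int32_t, float>(
    csr, rows, cols, /*return_eids=*/false, weights, /*filler=*/-1.f);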
......
@@ -5,9 +5,10 @@
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"

namespace dgl {
@@ -16,7 +17,8 @@ using namespace dgl::runtime;
namespace aten {
namespace cusparse {

#if 0  // disabling CUDA 11.0+ implementation for now because of problems on
       // bigger graphs

/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
@@ -125,14 +127,13 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
      dC_weights};
}

#else  // __CUDACC_VER_MAJOR__ != 11

/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
 * versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  int nnzC;
  csrgemm2Info_t info = nullptr;
@@ -164,36 +165,30 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
  CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matD));  // needed even if D is null
  CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, info, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
      C_indptr.Ptr<IdType>(), &nnzC, info, workspace));

  IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
  NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
  CUSPARSE_CALL(CSRGEMM<DType>::compute(
      thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
      B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
      nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
      C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));

  device->FreeWorkspace(ctx, workspace);
  CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
@@ -203,7 +198,8 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
  CUSPARSE_CALL(cusparseDestroyMatDescr(matD));

  return {
      CSRMatrix(
          m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
      C_weights};
}
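The function above follows cuSPARSE's legacy csrgemm2 "size, structure, values" sequence. An abbreviated outline (comments only; argument lists elided for brevity):

// 1. CSRGEMM<DType>::bufferSizeExt(handle, ..., &workspace_size);
//      query how many scratch bytes csrgemm2 needs for C = alpha * A * B.
// 2. CSRGEMM<DType>::nnz(handle, ..., C_indptr, &nnzC, info, workspace);
//      compute C's row offsets and total nonzero count.
// 3. Allocate C_indices / C_weights of length nnzC, then
//    CSRGEMM<DType>::compute(handle, ..., info, workspace);
//      fill in C's column indices and values.
// The matD descriptor with zero nonzeros stands in for the optional
// "+ beta * D" term of csrgemm2, which is unused here.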
@@ -212,9 +208,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
    const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
    NDArray B_weights) {
  auto ctx = A.indptr->ctx;
  auto device = runtime::DeviceAPI::Get(ctx);
@@ -224,20 +218,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
  // Cast 64 bit indices to 32 bit.
  if (A.indptr->dtype.bits == 64) {
    newA = CSRMatrix(
        A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
        AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
    newB = CSRMatrix(
        B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
        AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
    cast = true;
  }

  // Reorder weights if A or B has edge IDs
  NDArray newA_weights, newB_weights;
  if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
  if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);

  auto result = cusparse::CusparseSpgemm<DType, int32_t>(
      cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
@@ -247,9 +239,10 @@ std::pair<CSRMatrix, NDArray> CSRMM(
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
        result.second};
  } else {
    return result;
  }
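A hypothetical call site for CSRMM, assuming A and B are CSR matrices with one weight per nonzero on the same device:

// C = A * B as a weighted sparse matrix; C_weights[i] is the value of the
// i-th nonzero (edge) of C.
auto product =
    dgl::aten::CSRMM<kDGLCUDA, int32_t, float>(A, A_weights, B, B_weights);
CSRMatrix C = product.first;
NDArray C_weights = product.second;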
......
@@ -5,9 +5,10 @@
 */
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>

#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"

namespace dgl {
@@ -19,9 +20,7 @@ namespace cusparse {

/** Cusparse implementation of SpSum on Csr format. */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
    const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
    const NDArray B_weights_array) {
  const int m = A.num_rows;
  const int n = A.num_cols;
@@ -46,7 +45,8 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
  CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
  CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
  cusparseSetPointerMode(
      thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
  size_t workspace_size = 0;
  /* prepare output C */
  IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
@@ -57,25 +57,16 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
  DType* dC_weights_data = dC_weights.Ptr<DType>();
  /* prepare buffer */
  CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, &workspace_size));

  void* workspace = device->AllocWorkspace(ctx, workspace_size);
  CUSPARSE_CALL(CSRGEAM<DType>::nnz(
      thr_entry->cusparse_handle, m, n, matA, nnzA, A.indptr.Ptr<IdType>(),
      A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
      B.indices.Ptr<IdType>(), matC, dC_csrOffsets_data, &nnzC, workspace));

  dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
  dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
@@ -83,15 +74,10 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
  dC_weights_data = dC_weights.Ptr<DType>();

  CUSPARSE_CALL(CSRGEAM<DType>::compute(
      thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
      A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
      B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
      dC_weights_data, dC_csrOffsets_data, dC_columns_data, workspace));

  device->FreeWorkspace(ctx, workspace);
  // destroy matrix/vector descriptors
@@ -99,16 +85,16 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
  CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
  CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
  return {
      CSRMatrix(
          A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
          NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
      dC_weights};
}
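csrgeam2 computes C = alpha*A + beta*B. Assuming the alpha and beta scalars, which are set outside this excerpt, are both 1 (as expected for an entry-wise SpSum), a worked toy example of the CSR layout it produces:

// A = [[1, 0],   B = [[0, 5],   =>  C = A + B = [[1, 5],
//      [2, 0]]        [0, 3]]                    [2, 3]]
//
// CSR inputs:  A: indptr [0,1,2], indices [0,0], weights [1,2]
//              B: indptr [0,1,2], indices [1,1], weights [5,3]
// CSR output:  C: indptr [0,2,4], indices [0,1,0,1], weights [1,5,2,3]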
}  // namespace cusparse

template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
    const std::vector<CSRMatrix>& As, const std::vector<NDArray>& A_weights) {
  const int64_t M = As[0].num_rows;
  const int64_t N = As[0].num_cols;
  const int64_t n = As.size();
@@ -120,19 +106,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
  if (As[0].indptr->dtype.bits == 64) {
    for (int i = 0; i < n; ++i)
      newAs.emplace_back(
          As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
          AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
    cast = true;
  } else {
    for (int i = 0; i < n; ++i) newAs.push_back(As[i]);
  }

  // cuSPARSE csrgeam2 requires the CSR to be sorted.
  // TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure
  // how to do it.
  for (int i = 0; i < n; ++i) {
    if (!newAs[i].sorted) newAs[i] = CSRSort(newAs[i]);
  }

  // Reorder weights if A[i] has edge IDs
@@ -147,10 +132,11 @@ std::pair<CSRMatrix, NDArray> CSRSum(
  // Loop and sum
  auto result = std::make_pair(
      CSRMatrix(
          newAs[0].num_rows, newAs[0].num_cols, newAs[0].indptr,
          newAs[0].indices,
          NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
      A_weights_reordered[0]);  // Weights already reordered so we don't need
                                // As[0].data
  for (int64_t i = 1; i < n; ++i)
    result = cusparse::CusparseCsrgeam2<DType, int32_t>(
        result.first, result.second, newAs[i], A_weights_reordered[i]);
@@ -159,9 +145,10 @@ std::pair<CSRMatrix, NDArray> CSRSum(
  if (cast) {
    CSRMatrix C = result.first;
    return {
        CSRMatrix(
            C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
            AsNumBits(C.indices, 64), AsNumBits(C.data, 64), true),
        result.second};
  } else {
    return result;
  }
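A hypothetical call site for CSRSum, summing several weighted graphs over the same node set:

std::vector<CSRMatrix> graphs = {A0, A1, A2};  // same shape M x N
std::vector<NDArray> weights = {W0, W1, W2};   // one weight per nonzero
auto summed = dgl::aten::CSRSum<kDGLCUDA, int32_t, float>(graphs, weights);
CSRMatrix C = summed.first;          // union of the sparsity patterns
NDArray C_weights = summed.second;   // entry-wise sums of the weights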
......
/**
 * Copyright (c) 2020 by Contributors
 * @file array/cuda/dispatcher.cuh
 * @brief Templates to dispatch into different cuSPARSE routines based on the
 * type argument.
 */
#ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_

#include <cusparse.h>
#include <dgl/runtime/c_runtime_api.h>

#include "bf16.cuh"
#include "fp16.cuh"

namespace dgl {
namespace aten {
@@ -40,8 +41,8 @@ template <>
struct CSRGEMM<__half> {
  template <typename... Args>
  static inline cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype half (FP16).";
    return static_cast<cusparseStatus_t>(0);
  }
@@ -65,9 +66,10 @@ template <>
struct CSRGEMM<__nv_bfloat16> {
  template <typename... Args>
  static inline cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL)
        << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
    return static_cast<cusparseStatus_t>(0);
  }
@@ -147,8 +149,8 @@ template <>
struct CSRGEAM<__half> {
  template <typename... Args>
  static inline cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype half (FP16).";
    return static_cast<cusparseStatus_t>(0);
  }
@@ -172,9 +174,10 @@ template <>
struct CSRGEAM<__nv_bfloat16> {
  template <typename... Args>
  static inline cusparseStatus_t bufferSizeExt(Args... args) {
    // TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
    // different implementation would be required.
    LOG(FATAL)
        << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
    return static_cast<cusparseStatus_t>(0);
  }
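For context, the specializations for supported dtypes (not shown in this hunk) are thin perfect-forwarding shims. A rough sketch for float, assuming the legacy csrgemm2 API these wrappers target (removed in CUDA 12):

// Sketch only; mirrors the pattern of the real CSRGEMM<float> specialization.
template <>
struct CSRGEMM<float> {
  template <typename... Args>
  static inline cusparseStatus_t bufferSizeExt(Args... args) {
    // Dispatch to the single-precision cuSPARSE routine.
    return cusparseScsrgemm2_bufferSizeExt(args...);
  }
};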
......
@@ -21,8 +21,8 @@
#ifndef DGL_ARRAY_CUDA_FP16_CUH_
#define DGL_ARRAY_CUDA_FP16_CUH_

#include <cuda_fp16.h>

#include <algorithm>

static __device__ __forceinline__ half max(half a, half b) {
@@ -42,45 +42,64 @@ static __device__ __forceinline__ half min(half a, half b) {
}

#ifdef __CUDACC__
// Arithmetic FP16 operations for architecture >= 5.3 are already defined in
// cuda_fp16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)
__device__ __forceinline__ __half
operator+(const __half& lh, const __half& rh) {
  return __half(float(lh) + float(rh));  // NOLINT
}
__device__ __forceinline__ __half
operator-(const __half& lh, const __half& rh) {
  return __half(float(lh) - float(rh));  // NOLINT
}
__device__ __forceinline__ __half
operator*(const __half& lh, const __half& rh) {
  return __half(float(lh) * float(rh));  // NOLINT
}
__device__ __forceinline__ __half
operator/(const __half& lh, const __half& rh) {
  return __half(float(lh) / float(rh));  // NOLINT
}

__device__ __forceinline__ __half& operator+=(
    __half& lh, const __half& rh) {  // NOLINT
  lh = __half(float(lh) + float(rh));  // NOLINT
  return lh;
}
__device__ __forceinline__ __half& operator-=(
    __half& lh, const __half& rh) {  // NOLINT
  lh = __half(float(lh) - float(rh));  // NOLINT
  return lh;
}
__device__ __forceinline__ __half& operator*=(
    __half& lh, const __half& rh) {  // NOLINT
  lh = __half(float(lh) * float(rh));  // NOLINT
  return lh;
}
__device__ __forceinline__ __half& operator/=(
    __half& lh, const __half& rh) {  // NOLINT
  lh = __half(float(lh) / float(rh));  // NOLINT
  return lh;
}

__device__ __forceinline__ __half& operator++(__half& h) {  // NOLINT
  h = __half(float(h) + 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __half& operator--(__half& h) {  // NOLINT
  h = __half(float(h) - 1.0f);  // NOLINT
  return h;
}
__device__ __forceinline__ __half operator++(__half& h, int) {  // NOLINT
  __half ret = h;
  h = __half(float(h) + 1.0f);  // NOLINT
  return ret;
}
__device__ __forceinline__ __half operator--(__half& h, int) {  // NOLINT
  __half ret = h;
  h = __half(float(h) - 1.0f);  // NOLINT
  return ret;
}

__device__ __forceinline__ __half operator+(const __half& h) { return h; }
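These fallbacks let generic kernels keep using ordinary arithmetic on __half below compute capability 5.3. A minimal device-code sketch (hypothetical kernel):

__global__ void AxpyHalfSketch(int n, __half a, const __half* x, __half* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // On sm_53+ these operators come from cuda_fp16.h; on older GPUs the
    // float-roundtrip overloads defined above are picked up instead.
    y[i] += a * x[i];
  }
}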
@@ -94,11 +113,11 @@ __device__ __forceinline__ bool operator==(const __half& lh, const __half& rh) {
__device__ __forceinline__ bool operator!=(const __half& lh, const __half& rh) {
  return float(lh) != float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>(const __half& lh, const __half& rh) {
  return float(lh) > float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator<(const __half& lh, const __half& rh) {
  return float(lh) < float(rh);  // NOLINT
}
__device__ __forceinline__ bool operator>=(const __half& lh, const __half& rh) {
  return float(lh) >= float(rh);  // NOLINT
......
@@ -8,6 +8,7 @@
#include <cmath>
#include <limits>

#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"
@@ -16,99 +17,117 @@ namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////// CUDA binary operators //////////////////////////////
namespace binary {
template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] + rhs[0];
  }
};
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] - rhs[0];
  }
};
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] * rhs[0];
  }
};
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] / rhs[0];
  }
};
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;

template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0];
  }
};
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;

template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    return rhs[0];
  }
};
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;

template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType
  Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
    DType rst = static_cast<DType>(0.0f);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
@@ -116,25 +135,26 @@ struct Dot {
    return rst;
  }
};
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;
}  // end of namespace binary
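A minimal sketch of how these functor types are consumed: the operator is a template parameter, so Call is inlined into the kernel (hypothetical kernel, illustrative only):

template <typename DType, typename BinaryOp>
__global__ void EdgeWiseSketch(
    const DType *lhs, const DType *rhs, DType *out, int64_t num_edges,
    int64_t feat_len) {
  int64_t e = blockIdx.x * blockDim.x + threadIdx.x;
  if (e < num_edges) {
    // Add/Sub/Mul/Div/Copy* read a single element; Dot reduces over feat_len.
    out[e] = BinaryOp::Call(lhs + e * feat_len, rhs + e * feat_len, feat_len);
  }
}
// e.g. EdgeWiseSketch<float, binary::Dot<float>><<<nb, nt>>>(
//          src_feat, dst_feat, scores, num_edges, feat_len);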
/////////////////////////////// CUDA reduce operators /////////////////////////////// /////////////////////////// CUDA reduce operators //////////////////////////////
namespace reduce { namespace reduce {
template <typename Idx, template <typename Idx, typename DType, bool atomic>
typename DType,
bool atomic>
struct _Sum { struct _Sum {
static constexpr __host__ __device__ __forceinline__ DType zero() { static constexpr __host__ __device__ __forceinline__ DType zero() {
return 0.; return 0.;
} }
static constexpr bool require_arg = false; static constexpr bool require_arg = false;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
DType val, Idx uid, Idx eid) { Idx eid) {
if (!atomic) { if (!atomic) {
*out_buf += val; *out_buf += val;
} else { } else {
@@ -142,26 +162,23 @@ struct _Sum {
} }
} }
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf, DType *out_buf, Idx *arg_buf, DType val, Idx id) {
DType val, Idx id) {
if (!atomic) { if (!atomic) {
*out_buf += val; *out_buf += val;
} else { } else {
cuda::AtomicAdd(out_buf, val); cuda::AtomicAdd(out_buf, val);
} }
} }
static __device__ __forceinline__ void CallArg(Idx fid, static __device__ __forceinline__ void CallArg(
Idx *arg_u_buf, Idx *arg_e_buf, Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
DType val, DType val_ref, Idx uid, Idx eid) {} Idx uid, Idx eid) {}
}; };
template <typename Idx, template <typename Idx, typename DType, bool atomic = false>
typename DType, struct Sum : _Sum<Idx, DType, atomic> {};
bool atomic = false>
struct Sum: _Sum<Idx, DType, atomic> { };
template <typename Idx, bool atomic> template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> { struct Sum<Idx, half, atomic> : _Sum<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() { static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(0.); return __float2half_rn(0.);
} }
@@ -169,24 +186,22 @@ struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
#if BF16_ENABLED #if BF16_ENABLED
template <typename Idx, bool atomic> template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic>: _Sum<Idx, __nv_bfloat16, atomic> { struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() { static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(0.); return __float2bfloat16_rn(0.);
} }
}; };
#endif // BF16_ENABLED #endif // BF16_ENABLED
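As a rough sketch of how these reducers are driven (illustrative only; the real call sites are the SpMM kernels that include this header), a kernel feeds neighbor contributions through Call, and the atomic template flag decides whether that is a plain read-modify-write or an atomic update:

// --- illustrative sketch, not part of this commit ---
// ReduceOp is e.g. Sum<int32_t, float, /*atomic=*/false>; out_row is assumed
// to be pre-initialized with ReduceOp::zero().
template <typename ReduceOp, typename Idx, typename DType>
__device__ void ReduceNeighbors(
    const DType* vals, const Idx* uids, const Idx* eids, Idx deg,
    DType* out_row, Idx* arg_u_row, Idx* arg_e_row) {
  for (Idx i = 0; i < deg; ++i) {
    // atomic == false: plain accumulate into a row this thread owns.
    // atomic == true:  falls back to cuda::AtomicAdd / AtomicMax / AtomicMin
    //                  so concurrent threads may target the same output row.
    ReduceOp::Call(out_row, arg_u_row, arg_e_row, vals[i], uids[i], eids[i]);
  }
}
// --- end of sketch ---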
template <typename Idx, template <typename Idx, typename DType, bool atomic>
typename DType,
bool atomic>
struct _Max { struct _Max {
static constexpr __host__ __device__ __forceinline__ DType zero() { static constexpr __host__ __device__ __forceinline__ DType zero() {
return -std::numeric_limits<DType>::infinity(); return -std::numeric_limits<DType>::infinity();
} }
static constexpr bool require_arg = true; static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
DType val, Idx uid, Idx eid) { Idx eid) {
if (!atomic) { if (!atomic) {
if (*out_buf < val) { if (*out_buf < val) {
*out_buf = val; *out_buf = val;
@@ -198,8 +213,7 @@ struct _Max {
} }
} }
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf, DType *out_buf, Idx *arg_buf, DType val, Idx id) {
DType val, Idx id) {
if (!atomic) { if (!atomic) {
if (*out_buf < val) { if (*out_buf < val) {
*out_buf = val; *out_buf = val;
@@ -209,27 +223,22 @@ struct _Max {
cuda::AtomicMax(out_buf, val); cuda::AtomicMax(out_buf, val);
} }
} }
static __device__ __forceinline__ void CallArg(Idx fid, static __device__ __forceinline__ void CallArg(
Idx *arg_u_buf, Idx *arg_e_buf, Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
DType val, DType val_ref, Idx uid, Idx eid) { Idx uid, Idx eid) {
if (atomic) { if (atomic) {
if (val == val_ref) { if (val == val_ref) {
if (arg_u_buf) if (arg_u_buf) arg_u_buf[fid] = uid;
arg_u_buf[fid] = uid; if (arg_e_buf) arg_e_buf[fid] = eid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
} }
} }
} }
}; };
template <typename Idx, template <typename Idx, typename DType, bool atomic = false>
typename DType, struct Max : _Max<Idx, DType, atomic> {};
bool atomic = false>
struct Max : _Max<Idx, DType, atomic> { };
template <typename Idx, template <typename Idx, bool atomic>
bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> { struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() { static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(-6.550400e+04f); return __float2half_rn(-6.550400e+04f);
@@ -237,8 +246,7 @@ struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
}; };
#if BF16_ENABLED #if BF16_ENABLED
template <typename Idx, template <typename Idx, bool atomic>
bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> { struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() { static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(-std::numeric_limits<float>::infinity()); return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
@@ -246,17 +254,15 @@ struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
}; };
#endif // BF16_ENABLED #endif // BF16_ENABLED
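One detail that is easy to miss: in the atomic case the value and its arg index cannot be updated in a single atomic step, so CallArg exists for a second pass that re-scans the inputs and records the indices of whichever element equals the already-reduced value. A hedged sketch of that second pass (illustrative, not this file's actual kernels):

// --- illustrative sketch, not part of this commit ---
// Pass 1 (elsewhere): reduce values into out[fid] with ReduceOp::Call.
// Pass 2 (below):     recover the arg-max / arg-min indices.
template <typename ReduceOp, typename Idx, typename DType>
__device__ void ArgReduceSecondPass(
    const DType* vals, const Idx* uids, const Idx* eids, Idx deg,
    const DType* out, Idx* arg_u, Idx* arg_e, Idx fid) {
  for (Idx i = 0; i < deg; ++i)
    ReduceOp::CallArg(fid, arg_u, arg_e, vals[i], out[fid], uids[i], eids[i]);
}
// --- end of sketch ---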
template <typename Idx, template <typename Idx, typename DType, bool atomic>
typename DType,
bool atomic>
struct _Min { struct _Min {
static constexpr __host__ __device__ __forceinline__ DType zero() { static constexpr __host__ __device__ __forceinline__ DType zero() {
return std::numeric_limits<DType>::infinity(); return std::numeric_limits<DType>::infinity();
} }
static constexpr bool require_arg = true; static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
DType val, Idx uid, Idx eid) { Idx eid) {
if (!atomic) { if (!atomic) {
if (*out_buf > val) { if (*out_buf > val) {
*out_buf = val; *out_buf = val;
@@ -268,8 +274,7 @@ struct _Min {
} }
} }
static __device__ __forceinline__ void Call( static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf, DType *out_buf, Idx *arg_buf, DType val, Idx id) {
DType val, Idx id) {
if (!atomic) { if (!atomic) {
if (*out_buf > val) { if (*out_buf > val) {
*out_buf = val; *out_buf = val;
@@ -279,27 +284,22 @@ struct _Min {
cuda::AtomicMin(out_buf, val); cuda::AtomicMin(out_buf, val);
} }
} }
static __device__ __forceinline__ void CallArg(Idx fid, static __device__ __forceinline__ void CallArg(
Idx *arg_u_buf, Idx *arg_e_buf, Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
DType val, DType val_ref, Idx uid, Idx eid) { Idx uid, Idx eid) {
if (atomic) { if (atomic) {
if (val == val_ref) { if (val == val_ref) {
if (arg_u_buf) if (arg_u_buf) arg_u_buf[fid] = uid;
arg_u_buf[fid] = uid; if (arg_e_buf) arg_e_buf[fid] = eid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
} }
} }
} }
}; };
template <typename Idx, template <typename Idx, typename DType, bool atomic = false>
typename DType, struct Min : _Min<Idx, DType, atomic> {};
bool atomic = false>
struct Min : _Min<Idx, DType, atomic> { };
template <typename Idx, template <typename Idx, bool atomic>
bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> { struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() { static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(6.550400e+04f); return __float2half_rn(6.550400e+04f);
@@ -307,8 +307,7 @@ struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
}; };
#if BF16_ENABLED #if BF16_ENABLED
template <typename Idx, template <typename Idx, bool atomic>
bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> { struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() { static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(std::numeric_limits<float>::infinity()); return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
......
@@ -4,10 +4,12 @@
* @brief GatherMM C APIs and definitions. * @brief GatherMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <algorithm> // std::swap #include <algorithm> // std::swap
#include "./utils.h"
#include "./functor.cuh"
#include "./atomic.cuh" #include "./atomic.cuh"
#include "./functor.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
using namespace cuda; using namespace cuda;
@@ -15,62 +17,58 @@ namespace aten {
namespace { namespace {
/** @brief Call cuBLAS GEMM API for dense matmul operation for float and double. */ /** @brief Call cuBLAS GEMM API for dense matmul operation for float and double.
*/
template <typename DType> template <typename DType>
cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t cublasGemm(
cublasOperation_t transb, int m, int n, int k, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const DType* alpha, const DType* A, int lda, int m, int n, int k, const DType* alpha, const DType* A, int lda,
const DType* B, int ldb, const DType* beta, const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
DType* C, int ldc) {
LOG(INFO) << "Not supported dtype"; LOG(INFO) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED; return CUBLAS_STATUS_EXECUTION_FAILED;
} }
template <> template <>
cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t cublasGemm<__half>(
cublasOperation_t transb, int m, int n, int k, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const __half* alpha, const __half* A, int lda, int m, int n, int k, const __half* alpha, const __half* A, int lda,
const __half* B, int ldb, const __half* beta, const __half* B, int ldb, const __half* beta, __half* C, int ldc) {
__half* C, int ldc) { return cublasHgemm(
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda, handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
B, ldb, beta, C, ldc);
} }
#if BF16_ENABLED #if BF16_ENABLED
template <> template <>
cublasStatus_t cublasGemm<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t cublasGemm<__nv_bfloat16>(
cublasOperation_t transb, int m, int n, int k, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda, int m, int n, int k, const __nv_bfloat16* alpha, const __nv_bfloat16* A,
const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta, int lda, const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
__nv_bfloat16* C, int ldc) { __nv_bfloat16* C, int ldc) {
float alpha_float = __bfloat162float(*alpha); float alpha_float = __bfloat162float(*alpha);
float beta_float = __bfloat162float(*beta); float beta_float = __bfloat162float(*beta);
return cublasGemmEx(handle, transa, transb, m, n, k, return cublasGemmEx(
&alpha_float, A, CUDA_R_16BF, lda, handle, transa, transb, m, n, k, &alpha_float, A, CUDA_R_16BF, lda, B,
B, CUDA_R_16BF, ldb, CUDA_R_16BF, ldb, &beta_float, C, CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F,
&beta_float, C, CUDA_R_16BF, ldc, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} }
#endif // BF16_ENABLED #endif // BF16_ENABLED
template <> template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t cublasGemm<float>(
cublasOperation_t transb, int m, int n, int k, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const float* alpha, const float* A, int lda, int m, int n, int k, const float* alpha, const float* A, int lda,
const float* B, int ldb, const float* beta, const float* B, int ldb, const float* beta, float* C, int ldc) {
float* C, int ldc) { return cublasSgemm(
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda, handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
B, ldb, beta, C, ldc);
} }
template <> template <>
cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t cublasGemm<double>(
cublasOperation_t transb, int m, int n, int k, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const double* alpha, const double* A, int lda, int m, int n, int k, const double* alpha, const double* A, int lda,
const double* B, int ldb, const double* beta, const double* B, int ldb, const double* beta, double* C, int ldc) {
double* C, int ldc) { return cublasDgemm(
return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda, handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
B, ldb, beta, C, ldc);
} }
} // namespace } // namespace
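For context, the wrappers above exist so that the row-major callers below can issue one uniform GEMM call per DType; the usual column-major trick (swap the operands and the m/n extents) is what SegmentMM does with them. A hedged stand-alone example (handle setup and error handling elided):

// --- illustrative sketch, not part of this commit ---
// Row-major C(m x n) = A(m x k) * B(k x n) via the column-major wrapper:
// computing C^T = B^T * A^T lets us pass the row-major buffers unchanged.
void RowMajorGemm(
    cublasHandle_t handle, const float* d_A, const float* d_B, float* d_C,
    int m, int n, int k) {
  float alpha = 1.f, beta = 0.f;
  CUBLAS_CALL(cublasGemm<float>(
      handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, d_B, /*lda=*/n,
      d_A, /*ldb=*/k, &beta, d_C, /*ldc=*/n));
}
// --- end of sketch ---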
@@ -78,122 +76,113 @@ cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t trans
namespace cuda { namespace cuda {
/** /**
* @note Each row of A multiplies a segment of matrix of B of dimension in_len * out_len. * @note Each row of A multiplies a segment of matrix of B of dimension in_len *
* One warp is assigned to process one row of A. Each WARP sequentially multiplies * out_len. One warp is assigned to process one row of A. Each WARP sequentially
* one element of A and a row of B to compute partial result of the output. A * multiplies one element of A and a row of B to compute partial result of the
* is loaded in shared memory in a coalesced way. Output matrix is loaded in * output. A is loaded in shared memory in a coalesced way. Output matrix is
* registers. B should benefit from the L2 cache. * loaded in registers. B should benefit from the L2 cache.
*/ */
template <typename Idx, typename DType> template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel( __global__ void GatherMMScatterKernel(
const DType* __restrict__ A, const DType* __restrict__ A, const DType* __restrict__ B,
const DType* __restrict__ B, DType* __restrict__ C, const Idx* __restrict__ idx_a,
DType* __restrict__ C, const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
const Idx* __restrict__ idx_a, const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
const Idx* __restrict__ idx_b, unsigned int tId = threadIdx.x;
const Idx* __restrict__ idx_c, unsigned int laneId = tId & 31;
const int64_t num_rows, unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
const int64_t in_len, unsigned int warpId = gId >> 5;
const int64_t out_len) { unsigned int row = warpId;
if (row < num_rows) {
const unsigned int local_row =
row & 3; // hardcoded for TB size 128 (4 warps)
const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
const Idx B_offset = cur_rowB * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
// Load A in shared mem in a coalesced way
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
__syncwarp();
unsigned int tId = threadIdx.x; for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
unsigned int laneId = tId & 31; DType out_reg = static_cast<DType>(0.0f); // thread private
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x); const unsigned int l = laneId;
unsigned int warpId = gId >> 5; if (l < out_len) {
unsigned int row = warpId; // iterate over elements of a row of A
if (row < num_rows) { for (unsigned int i = 0; i < a_tile; i++) {
const unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps) const DType a_val = sh_A[local_row * sh_a_tile + i];
const Idx cur_rowA = (idx_a) ? idx_a[row] : row; // iterate over elements of a row of B in parallel
const Idx cur_rowB = (idx_b) ? idx_b[row] : row; out_reg +=
const Idx cur_rowC = (idx_c) ? idx_c[row] : row; a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
const Idx B_offset = cur_rowB * in_len * out_len; }
const int sh_a_tile = 64; if (idx_c) {
__shared__ DType sh_A[4 * sh_a_tile]; AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
int a_tile = sh_a_tile; } else {
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) { C[cur_rowC * out_len + (outloop + l)] += out_reg;
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start; }
// Load A in shared mem in a coalesced way
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
// iterate over elements of a row of A
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
// iterate over elements of a row of B in parallel
out_reg += a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
}
if (idx_c) {
AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
} else {
C[cur_rowC * out_len + (outloop + l)] += out_reg;
}
}
}
} }
}
} }
}
} }
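A small illustrative helper (hypothetical; the real launches below inline the same arithmetic) that makes the thread mapping explicit: one 32-thread warp per gathered row, four warps per 128-thread block, so the grid size follows from num_rows * 32:

// --- illustrative sketch, not part of this commit ---
inline dim3 GatherMMGrid(int64_t num_rows) {
  const int ntx = 128;       // threads per block = 4 warps
  const int warp_size = 32;  // one warp processes one row of A
  const int nbx = static_cast<int>((num_rows * warp_size + ntx - 1) / ntx);
  return dim3(nbx);          // launched as <<<nbx, ntx>>>
}
// --- end of sketch ---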
/** /**
* @note Output matrix is accumulated via atomic operations. Rest of the strategies * @note Output matrix is accumulated via atomic operations. Rest of the
* are similar to GatherMMKernel. One warp is assigned to process one row of A. Each * strategies are similar to GatherMMKernel. One warp is assigned to process one
* WARP sequentially multiplies one element of A and a row of B to compute partial * row of A. Each WARP sequentially multiplies one element of A and a row of B
* result of the output. A is loaded in shared memory in a coalesced way. B should * to compute partial result of the output. A is loaded in shared memory in a
* benefit from the L2 cache. * coalesced way. B should benefit from the L2 cache.
*/ */
template <typename Idx, typename DType> template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2( __global__ void GatherMMScatterKernel2(
const DType* __restrict__ A, const DType* __restrict__ A, const DType* __restrict__ B,
const DType* __restrict__ B, DType* __restrict__ C, const Idx* __restrict__ idx_a,
DType* __restrict__ C, const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
const Idx* __restrict__ idx_a, const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
const Idx* __restrict__ idx_b, unsigned int tId = threadIdx.x;
const Idx* __restrict__ idx_c, unsigned int laneId = tId & 31;
const int64_t num_rows, unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
const int64_t in_len, unsigned int warpId = gId >> 5;
const int64_t out_len) { unsigned int row = warpId;
if (row < num_rows) {
unsigned int tId = threadIdx.x; const unsigned int local_row =
unsigned int laneId = tId & 31; row & 3; // hardcoded for TB size 128 (4 warps)
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x); const Idx row_a = (idx_a) ? idx_a[row] : row;
unsigned int warpId = gId >> 5; const Idx row_b = (idx_b) ? idx_b[row] : row;
unsigned int row = warpId; const Idx row_c = (idx_c) ? idx_c[row] : row;
if (row < num_rows) { const Idx C_offset = row_c * in_len * out_len;
const unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps) const int sh_a_tile = 64;
const Idx row_a = (idx_a) ? idx_a[row] : row; __shared__ DType sh_A[4 * sh_a_tile];
const Idx row_b = (idx_b) ? idx_b[row] : row; int a_tile = sh_a_tile;
const Idx row_c = (idx_c) ? idx_c[row] : row; for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
const Idx C_offset = row_c * in_len * out_len; if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
const int sh_a_tile = 64; /* Load A in shared mem in a coalesced way */
__shared__ DType sh_A[4 * sh_a_tile]; for (unsigned int l = laneId; l < a_tile; l += 32)
int a_tile = sh_a_tile; sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) { __syncwarp();
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
/* Load A in shared mem in a coalesced way */
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) { for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
DType out_reg = static_cast<DType>(0.0f); // thread private DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId; const unsigned int l = laneId;
if (l < out_len) { if (l < out_len) {
const DType b_val = B[row_b * out_len + (outloop + l)]; const DType b_val = B[row_b * out_len + (outloop + l)];
/* iterate over elements of a row of A */ /* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) { for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i]; const DType a_val = sh_A[local_row * sh_a_tile + i];
const Idx C_idx = C_offset + ((i + k_start) * out_len + (outloop + l)); const Idx C_idx =
AtomicAdd(C + C_idx, a_val * b_val); C_offset + ((i + k_start) * out_len + (outloop + l));
} AtomicAdd(C + C_idx, a_val * b_val);
} }
}
} }
}
} }
}
} }
} // namespace cuda } // namespace cuda
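Written out serially, the computation GatherMMScatterKernel2 performs is an outer-product accumulation: row i of the gathered A and row i of the gathered B contribute A[idx_a[i]]^T * B[idx_b[i]] into the (in_len x out_len) slice of C selected by idx_c[i], which is the W_grad[idx_c[i]] += H[i]^T * C.grad[i] pattern mentioned in GatherMMScatter below. A hedged CPU reference:

// --- illustrative CPU reference, not part of this commit ---
template <typename Idx, typename DType>
void GatherMMScatter2Reference(
    const DType* A, const DType* B, DType* C, const Idx* idx_a,
    const Idx* idx_b, const Idx* idx_c, int64_t num_rows, int64_t in_len,
    int64_t out_len) {
  for (int64_t i = 0; i < num_rows; ++i) {
    const Idx ra = idx_a ? idx_a[i] : i;
    const Idx rb = idx_b ? idx_b[i] : i;
    const Idx rc = idx_c ? idx_c[i] : i;
    for (int64_t p = 0; p < in_len; ++p)     // rows of the output slice
      for (int64_t q = 0; q < out_len; ++q)  // cols of the output slice
        C[rc * in_len * out_len + p * out_len + q] +=
            A[ra * in_len + p] * B[rb * out_len + q];
  }
}
// --- end of sketch ---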
@@ -210,103 +199,89 @@ __global__ void GatherMMScatterKernel2(
* @param b_trans Matrix B to be transposed * @param b_trans Matrix B to be transposed
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A, void SegmentMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
NDArray C, bool a_trans, bool b_trans) {
const NDArray seglen_A, auto device = runtime::DeviceAPI::Get(A->ctx);
bool a_trans, bool b_trans) { cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(A->ctx); const DType* A_data = A.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); const DType* B_data = B.Ptr<DType>();
const DType *A_data = A.Ptr<DType>(); const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
const DType *B_data = B.Ptr<DType>(); DType* C_data = C.Ptr<DType>();
const IdType* seglen_A_data = seglen_A.Ptr<IdType>(); int64_t A_offset = 0, B_offset = 0, C_offset = 0;
DType *C_data = C.Ptr<DType>(); int64_t m, n, k;
int64_t A_offset = 0, B_offset = 0, C_offset = 0; int64_t num_rel = seglen_A.NumElements();
int64_t m, n, k; DType alpha = 1., beta = 0.;
int64_t num_rel = seglen_A.NumElements();
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle) if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle))); CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream)); CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
IdType m_offset = 0; IdType m_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) { for (IdType etype = 0; etype < num_rel; ++etype) {
m = seglen_A_data[etype]; // rows of A m = seglen_A_data[etype]; // rows of A
CHECK_LE(m_offset + m, A->shape[0]) << "Segment index out of bound of A->shape[0]."; CHECK_LE(m_offset + m, A->shape[0])
n = B->shape[2]; // cols of B << "Segment index out of bound of A->shape[0].";
k = B->shape[1]; // cols of A == rows of B n = B->shape[2]; // cols of B
int ldb = n, lda = k, ldc = n; k = B->shape[1]; // cols of A == rows of B
cublasOperation_t transB = CUBLAS_OP_N; int ldb = n, lda = k, ldc = n;
cublasOperation_t transA = CUBLAS_OP_N; cublasOperation_t transB = CUBLAS_OP_N;
if (b_trans) { cublasOperation_t transA = CUBLAS_OP_N;
transB = CUBLAS_OP_T; if (b_trans) {
ldb = n, lda = n, ldc = k; transB = CUBLAS_OP_T;
std::swap(n, k); ldb = n, lda = n, ldc = k;
} std::swap(n, k);
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle,
transB,
transA,
n, m, k,
&alpha,
B_data + B_offset, ldb,
A_data + A_offset, lda,
&beta,
C_data + C_offset, ldc));
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
m_offset += m;
} }
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,
B_data + B_offset, ldb, A_data + A_offset, lda, &beta,
C_data + C_offset, ldc));
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
m_offset += m;
}
} }
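For readability, a hedged serial reference of what the loop above computes when neither transpose flag is set (DType assumed to be float/double here): seglen_A partitions the rows of A into per-relation segments, and segment e is multiplied by the e-th (k x n) slice of B.

// --- illustrative CPU reference, not part of this commit ---
template <typename IdType, typename DType>
void SegmentMMReference(
    const DType* A, const DType* B, DType* C, const IdType* seglen,
    int64_t num_rel, int64_t k, int64_t n) {
  int64_t row = 0;
  for (int64_t e = 0; e < num_rel; ++e) {
    for (IdType i = 0; i < seglen[e]; ++i, ++row) {  // rows of segment e
      for (int64_t j = 0; j < n; ++j) {              // cols of B slice e
        DType acc = 0;
        for (int64_t t = 0; t < k; ++t)              // shared dimension
          acc += A[row * k + t] * B[(e * k + t) * n + j];
        C[row * n + j] = acc;
      }
    }
  }
}
// --- end of sketch ---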
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(
const NDArray dC, const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
NDArray dB, auto device = runtime::DeviceAPI::Get(A->ctx);
const NDArray seglen) { cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(A->ctx); const DType* A_data = A.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); const DType* dC_data = dC.Ptr<DType>();
const DType *A_data = A.Ptr<DType>(); const IdType* seglen_data = seglen.Ptr<IdType>();
const DType *dC_data = dC.Ptr<DType>(); DType* dB_data = dB.Ptr<DType>();
const IdType* seglen_data = seglen.Ptr<IdType>(); int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
DType *dB_data = dB.Ptr<DType>(); int64_t m, n, k;
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0; int64_t num_rel = seglen.NumElements();
int64_t m, n, k; DType alpha = 1., beta = 1.;
int64_t num_rel = seglen.NumElements();
DType alpha = 1., beta = 1.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle) if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle))); CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream)); CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
IdType k_offset = 0; IdType k_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) { for (IdType etype = 0; etype < num_rel; ++etype) {
m = dC->shape[1]; m = dC->shape[1];
n = A->shape[1]; n = A->shape[1];
k = seglen_data[etype]; k = seglen_data[etype];
CHECK_LE(k_offset + k, A->shape[0]) << "Segment index out of bound of A->shape[0]."; CHECK_LE(k_offset + k, A->shape[0])
int lddC = m, ldA = n, lddB = m; << "Segment index out of bound of A->shape[0].";
cublasOperation_t trans_dC = CUBLAS_OP_N; int lddC = m, ldA = n, lddB = m;
cublasOperation_t trans_A = CUBLAS_OP_T; cublasOperation_t trans_dC = CUBLAS_OP_N;
CUBLAS_CALL(cublasGemm<DType>( cublasOperation_t trans_A = CUBLAS_OP_T;
thr_entry->cublas_handle, CUBLAS_CALL(cublasGemm<DType>(
trans_dC, thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,
trans_A, dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,
m, n, k, dB_data + dB_offset, lddB));
&alpha, dC_offset += m * k;
dC_data + dC_offset, lddC, A_offset += n * k;
A_data + A_offset, ldA, dB_offset += m * n;
&beta, k_offset += k;
dB_data + dB_offset, lddB)); }
dC_offset += m * k;
A_offset += n * k;
dB_offset += m * n;
k_offset += k;
}
} }
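And the matching hedged reference for the backward pass above: with trans_A = CUBLAS_OP_T and beta = 1, each relation accumulates dB[e] += A_seg^T * dC_seg, where A_seg and dC_seg are that segment's rows of A and dC (here n = cols of A, m = cols of dC, and dB is assumed laid out as (num_rel, n, m)).

// --- illustrative CPU reference, not part of this commit ---
template <typename IdType, typename DType>
void SegmentMMBackwardBReference(
    const DType* A, const DType* dC, DType* dB, const IdType* seglen,
    int64_t num_rel, int64_t n, int64_t m) {
  int64_t row = 0;
  for (int64_t e = 0; e < num_rel; ++e) {
    for (IdType r = 0; r < seglen[e]; ++r, ++row)  // rows of segment e
      for (int64_t i = 0; i < n; ++i)              // cols of A
        for (int64_t j = 0; j < m; ++j)            // cols of dC
          dB[(e * n + i) * m + j] += A[row * n + i] * dC[row * m + j];
  }
}
// --- end of sketch ---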
/** /**
@@ -320,30 +295,23 @@ void SegmentMMBackwardB(const NDArray A,
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C, const NDArray idx_b) {
const NDArray idx_a, auto device = runtime::DeviceAPI::Get(A->ctx);
const NDArray idx_b) { cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(A->ctx); int64_t out_len = B->shape[2]; // cols of B
cudaStream_t stream = runtime::getCurrentCUDAStream(); int64_t in_len = A->shape[1]; // cols of A
int64_t out_len = B->shape[2]; // cols of B const int64_t tot_num_rows = A->shape[0];
int64_t in_len = A->shape[1]; // cols of A const int ntx = 128;
const int64_t tot_num_rows = A->shape[0]; const int warp_size = 32;
const int ntx = 128; const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const int warp_size = 32; const dim3 nblks(nbx);
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx); const dim3 nthrs(ntx);
const dim3 nblks(nbx); CUDA_KERNEL_CALL(
const dim3 nthrs(ntx); (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>), A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
nblks, nthrs, 0, stream, idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);
A.Ptr<DType>(),
B.Ptr<DType>(),
C.Ptr<DType>(),
idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(),
nullptr,
tot_num_rows, in_len, out_len);
} }
/** /**
@@ -360,128 +328,118 @@ void GatherMM(const NDArray A,
* @param b_trans Matrix B to be transposed * @param b_trans Matrix B to be transposed
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C, const NDArray idx_b, const NDArray idx_c) {
const NDArray idx_a, auto device = runtime::DeviceAPI::Get(A->ctx);
const NDArray idx_b, cudaStream_t stream = runtime::getCurrentCUDAStream();
const NDArray idx_c) { const IdType* idx_c_data = idx_c.Ptr<IdType>();
auto device = runtime::DeviceAPI::Get(A->ctx); int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2]; // cols of B
cudaStream_t stream = runtime::getCurrentCUDAStream(); int64_t in_len = A->shape[1]; // cols of A
const IdType *idx_c_data = idx_c.Ptr<IdType>(); int64_t tot_num_rows = A->shape[0];
int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2]; // cols of B const int ntx = 128;
int64_t in_len = A->shape[1]; // cols of A const int warp_size = 32;
int64_t tot_num_rows = A->shape[0]; const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const int ntx = 128; const dim3 nblks(nbx);
const int warp_size = 32; const dim3 nthrs(ntx);
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx); if (B->ndim == 3) {
const dim3 nblks(nbx); CUDA_KERNEL_CALL(
const dim3 nthrs(ntx); (cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
if (B->ndim == 3) { A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>), idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
nblks, nthrs, 0, stream, out_len);
A.Ptr<DType>(), } else {
B.Ptr<DType>(), // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
C.Ptr<DType>(), // This kernel accesses rows of A in a transposed way w/o explicitly
idx_a.Ptr<IdType>(), // converting A
idx_b.Ptr<IdType>(), CUDA_KERNEL_CALL(
idx_c.Ptr<IdType>(), (cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
tot_num_rows, in_len, out_len); A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
} else { idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
// Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i] out_len);
// This kernel accesses rows of A in a transposed way w/o explicitly converting A }
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
nblks, nthrs, 0, stream,
A.Ptr<DType>(),
B.Ptr<DType>(),
C.Ptr<DType>(),
idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(),
idx_c.Ptr<IdType>(),
tot_num_rows, in_len, out_len);
}
} }
template void GatherMM<kDGLCUDA, int32_t, __half>( template void GatherMM<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>( template void GatherMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
#if BF16_ENABLED #if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>( template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>( template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>( template void GatherMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>( template void GatherMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>( template void GatherMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>( template void GatherMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_b);
template void GatherMMScatter<kDGLCUDA, int32_t, __half>( template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>( template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED #if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>( template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>( template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>( template void GatherMMScatter<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>( template void GatherMMScatter<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>( template void GatherMMScatter<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>( template void GatherMMScatter<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCUDA, int32_t, __half>( template void SegmentMM<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>( template void SegmentMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
#if BF16_ENABLED #if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>( template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>( template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>( template void SegmentMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>( template void SegmentMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>( template void SegmentMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>( template void SegmentMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
const NDArray seglen_A, bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
......
@@ -6,18 +6,20 @@
* sampling code rowwise_sampling.cu. * sampling code rowwise_sampling.cu.
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia. * @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/ */
#include <curand_kernel.h>
#include <dgl/random.h> #include <dgl/random.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <numeric> #include <numeric>
#include "./dgl_cub.cuh"
#include "./utils.h"
#include "../../array/cuda/atomic.cuh" #include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
// require CUB 1.17 to use DeviceSegmentedSort // require CUB 1.17 to use DeviceSegmentedSort
static_assert(CUB_VERSION >= 101700, "Require CUB >= 1.17 to use DeviceSegmentedSort"); static_assert(
CUB_VERSION >= 101700, "Require CUB >= 1.17 to use DeviceSegmentedSort");
using namespace dgl::aten::cuda; using namespace dgl::aten::cuda;
@@ -30,26 +32,26 @@ namespace {
constexpr int BLOCK_SIZE = 128; constexpr int BLOCK_SIZE = 128;
/** /**
* @brief Compute the size of each row in the sampled CSR, without replacement. * @brief Compute the size of each row in the sampled CSR, without replacement.
* temp_deg is calculated for rows with deg > num_picks. * temp_deg is calculated for rows with deg > num_picks.
* For these rows, we will calculate their A-Res values and sort them to get top-num_picks. * For these rows, we will calculate their A-Res values and sort them to get
* * top-num_picks.
* @tparam IdType The type of node and edge indexes. *
* @param num_picks The number of non-zero entries to pick per row. * @tparam IdType The type of node and edge indexes.
* @param num_rows The number of rows to pick. * @param num_picks The number of non-zero entries to pick per row.
* @param in_rows The set of rows to pick. * @param num_rows The number of rows to pick.
* @param in_ptr The index where each row's edges start. * @param in_rows The set of rows to pick.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output). * @param in_ptr The index where each row's edges start.
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output). * @param out_deg The size of each row in the sampled matrix, as indexed by
*/ * `in_rows` (output).
template<typename IdType> * @param temp_deg The size of each row in the input matrix, as indexed by
* `in_rows` (output).
*/
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel( __global__ void _CSRRowWiseSampleDegreeKernel(
const int64_t num_picks, const int64_t num_picks, const int64_t num_rows,
const int64_t num_rows, const IdType* const in_rows, const IdType* const in_ptr,
const IdType * const in_rows, IdType* const out_deg, IdType* const temp_deg) {
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x; const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) { if (tIdx < num_rows) {
@@ -69,25 +71,24 @@ __global__ void _CSRRowWiseSampleDegreeKernel(
} }
/** /**
* @brief Compute the size of each row in the sampled CSR, with replacement. * @brief Compute the size of each row in the sampled CSR, with replacement.
* We need the actual in degree of each row to store CDF values. * We need the actual in degree of each row to store CDF values.
* *
* @tparam IdType The type of node and edge indexes. * @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row. * @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick. * @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick. * @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start. * @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output). * @param out_deg The size of each row in the sampled matrix, as indexed by
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output). * `in_rows` (output).
*/ * @param temp_deg The size of each row in the input matrix, as indexed by
template<typename IdType> * `in_rows` (output).
*/
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel( __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
const int64_t num_picks, const int64_t num_picks, const int64_t num_rows,
const int64_t num_rows, const IdType* const in_rows, const IdType* const in_ptr,
const IdType * const in_rows, IdType* const out_deg, IdType* const temp_deg) {
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x; const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) { if (tIdx < num_rows) {
@@ -106,23 +107,20 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
} }
/** /**
* @brief Equivalent to numpy expression: array[idx[off:off + len]] * @brief Equivalent to numpy expression: array[idx[off:off + len]]
* *
* @tparam IdType The ID type used for indices. * @tparam IdType The ID type used for indices.
* @tparam FloatType The float type used for array values. * @tparam FloatType The float type used for array values.
* @param array The array to be selected. * @param array The array to be selected.
* @param idx_data The index mapping array. * @param idx_data The index mapping array.
* @param index The index of value to be selected. * @param index The index of value to be selected.
* @param offset The offset to start. * @param offset The offset to start.
* @param out The selected value (output). * @param out The selected value (output).
*/ */
template<typename IdType, typename FloatType> template <typename IdType, typename FloatType>
__device__ void _DoubleSlice( __device__ void _DoubleSlice(
const FloatType * const array, const FloatType* const array, const IdType* const idx_data,
const IdType * const idx_data, const IdType idx, const IdType offset, FloatType* const out) {
const IdType idx,
const IdType offset,
FloatType* const out) {
if (idx_data) { if (idx_data) {
*out = array[idx_data[offset + idx]]; *out = array[idx_data[offset + idx]];
} else { } else {
@@ -131,39 +129,35 @@ __device__ void _DoubleSlice(
} }
/** /**
* @brief Compute A-Res value. A-Res value needs to be calculated only if deg * @brief Compute A-Res value. A-Res value needs to be calculated only if deg
* is greater than num_picks in weighted rowwise sampling without replacement. * is greater than num_picks in weighted rowwise sampling without replacement.
* *
* @tparam IdType The ID type used for matrices. * @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices. * @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock. * @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use. * @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row. * @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick. * @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick. * @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR. * @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR. * @param data The data array of the input CSR.
* @param prob The probability array of the input CSR. * @param prob The probability array of the input CSR.
* @param ares_ptr The offset to write each row to in the A-res array. * @param ares_ptr The offset to write each row to in the A-res array.
* @param ares_idxs The A-Res value corresponding index array, the index of input CSR (output). * @param ares_idxs The A-Res value corresponding index array, the index of
* @param ares The A-Res value array (output). * input CSR (output).
* @author pengqirong (OPPO) * @param ares The A-Res value array (output).
*/ * @author pengqirong (OPPO)
template<typename IdType, typename FloatType, int TILE_SIZE> */
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRAResValueKernel( __global__ void _CSRAResValueKernel(
const uint64_t rand_seed, const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
const int64_t num_picks, const IdType* const in_rows, const IdType* const in_ptr,
const int64_t num_rows, const IdType* const data, const FloatType* const prob,
const IdType * const in_rows, const IdType* const ares_ptr, IdType* const ares_idxs,
const IdType * const in_ptr, FloatType* const ares) {
const IdType * const data,
const FloatType * const prob,
const IdType * const ares_ptr,
IdType * const ares_idxs,
FloatType * const ares) {
int64_t out_row = blockIdx.x * TILE_SIZE; int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows); const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng; curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng); curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
@@ -181,9 +175,11 @@ __global__ void _CSRAResValueKernel(
const int64_t in_idx = in_row_start + idx; const int64_t in_idx = in_row_start + idx;
const int64_t ares_idx = ares_row_start + idx; const int64_t ares_idx = ares_row_start + idx;
FloatType item_prob; FloatType item_prob;
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &item_prob); _DoubleSlice<IdType, FloatType>(
prob, data, idx, in_row_start, &item_prob);
// compute A-Res value // compute A-Res value
ares[ares_idx] = static_cast<FloatType>(__powf(curand_uniform(&rng), 1.0f / item_prob)); ares[ares_idx] = static_cast<FloatType>(
__powf(curand_uniform(&rng), 1.0f / item_prob));
ares_idxs[ares_idx] = static_cast<IdType>(in_idx); ares_idxs[ares_idx] = static_cast<IdType>(in_idx);
} }
} }
@@ -191,47 +187,42 @@ __global__ void _CSRAResValueKernel(
} }
} }
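The __powf(curand_uniform(&rng), 1.0f / item_prob) line above computes the A-Res key: drawing u ~ Uniform(0, 1) per candidate and keeping the num_picks items with the largest u^(1/w) realizes weighted sampling without replacement, which is what the later segmented sort plus top-num_picks selection implements. A hedged host-side illustration of that selection rule (assumes num_picks <= prob.size(); not the kernel's code path):

// --- illustrative host-side version of the A-Res rule, not part of this commit ---
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <random>
#include <utility>
#include <vector>

std::vector<int64_t> AResSample(
    const std::vector<float>& prob, int64_t num_picks, std::mt19937& rng) {
  std::uniform_real_distribution<float> uni(0.f, 1.f);
  std::vector<std::pair<float, int64_t>> keys;  // (A-Res key, neighbor index)
  keys.reserve(prob.size());
  for (int64_t i = 0; i < static_cast<int64_t>(prob.size()); ++i)
    keys.emplace_back(std::pow(uni(rng), 1.f / prob[i]), i);
  // Keep the num_picks neighbors with the largest keys.
  std::partial_sort(
      keys.begin(), keys.begin() + num_picks, keys.end(),
      [](const auto& a, const auto& b) { return a.first > b.first; });
  std::vector<int64_t> picked;
  for (int64_t i = 0; i < num_picks; ++i) picked.push_back(keys[i].second);
  return picked;
}
// --- end of sketch ---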
/** /**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix, * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO
* without replacement. After sorting, we select top-num_picks items. * matrix, without replacement. After sorting, we select top-num_picks items.
* *
* @tparam IdType The ID type used for matrices. * @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices. * @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock. * @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row. * @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick. * @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick. * @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR. * @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR. * @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR. * @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO. * @param out_ptr The offset to write each row to in the output COO.
* @param ares_ptr The offset to write each row to in the ares array. * @param ares_ptr The offset to write each row to in the ares array.
* @param sort_ares_idxs The sorted A-Res value corresponding index array, the index of input CSR. * @param sort_ares_idxs The sorted A-Res value corresponding index array, the
* @param out_rows The rows of the output COO (output). * index of input CSR.
* @param out_cols The columns of the output COO (output). * @param out_rows The rows of the output COO (output).
* @param out_idxs The data array of the output COO (output). * @param out_cols The columns of the output COO (output).
* @author pengqirong (OPPO) * @param out_idxs The data array of the output COO (output).
*/ * @author pengqirong (OPPO)
template<typename IdType, typename FloatType, int TILE_SIZE> */
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel( __global__ void _CSRRowWiseSampleKernel(
const int64_t num_picks, const int64_t num_picks, const int64_t num_rows,
const int64_t num_rows, const IdType* const in_rows, const IdType* const in_ptr,
const IdType * const in_rows, const IdType* const in_cols, const IdType* const data,
const IdType * const in_ptr, const IdType* const out_ptr, const IdType* const ares_ptr,
const IdType * const in_cols, const IdType* const sort_ares_idxs, IdType* const out_rows,
const IdType * const data, IdType* const out_cols, IdType* const out_idxs) {
const IdType * const out_ptr,
const IdType * const ares_ptr,
const IdType * const sort_ares_idxs,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs) {
// we assign one warp per row // we assign one warp per row
assert(blockDim.x == BLOCK_SIZE); assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE; int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows); const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
while (out_row < last_row) { while (out_row < last_row) {
const int64_t row = in_rows[out_row]; const int64_t row = in_rows[out_row];
@@ -267,70 +258,64 @@ __global__ void _CSRRowWiseSampleKernel(
} }
} }
// A stateful callback functor that maintains a running prefix to be applied // A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations. // during consecutive scan operations.
template<typename FloatType> template <typename FloatType>
struct BlockPrefixCallbackOp { struct BlockPrefixCallbackOp {
// Running prefix // Running prefix
FloatType running_total; FloatType running_total;
// Constructor // Constructor
__device__ BlockPrefixCallbackOp(FloatType running_total) : running_total(running_total) {} __device__ BlockPrefixCallbackOp(FloatType running_total)
// Callback operator to be entered by the first warp of threads in the block. : running_total(running_total) {}
// Thread-0 is responsible for returning a value for seeding the block-wide scan. // Callback operator to be entered by the first warp of threads in the block.
__device__ FloatType operator()(FloatType block_aggregate) { // Thread-0 is responsible for returning a value for seeding the block-wide
FloatType old_prefix = running_total; // scan.
running_total += block_aggregate; __device__ FloatType operator()(FloatType block_aggregate) {
return old_prefix; FloatType old_prefix = running_total;
} running_total += block_aggregate;
return old_prefix;
}
}; };
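This is the standard CUB pattern for scanning a sequence longer than one block-sized tile: cub::BlockScan invokes the functor once per tile with that tile's aggregate and adds the returned running prefix to every element. A hedged sketch of the intended usage (names are illustrative; the kernel later in this file builds each row's unnormalized CDF the same way):

// --- illustrative sketch, not part of this commit ---
template <typename FloatType>
__device__ void TiledExclusiveSum(
    const FloatType* in, FloatType* out, int len) {
  typedef cub::BlockScan<FloatType, BLOCK_SIZE> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  BlockPrefixCallbackOp<FloatType> prefix_op(0);  // running total starts at 0
  for (int start = 0; start < len; start += BLOCK_SIZE) {
    const int idx = start + threadIdx.x;
    FloatType val = (idx < len) ? in[idx] : static_cast<FloatType>(0);
    // Scan this tile; prefix_op seeds it with the sum of all earlier tiles.
    BlockScan(temp_storage).ExclusiveSum(val, val, prefix_op);
    __syncthreads();  // temp_storage is reused in the next iteration
    if (idx < len) out[idx] = val;
  }
}
// --- end of sketch ---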
/** /**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix, * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO
* with replacement. We store the CDF (unnormalized) of all neighbors of a row * matrix, with replacement. We store the CDF (unnormalized) of all neighbors of
* in global memory and use binary search to find inverse indices as selected items. * a row in global memory and use binary search to find inverse indices as
* * selected items.
* @tparam IdType The ID type used for matrices. *
* @tparam FloatType The Float type used for matrices. * @tparam IdType The ID type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock. * @tparam FloatType The Float type used for matrices.
* @param rand_seed The random seed to use. * @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row. * @param rand_seed The random seed to use.
* @param num_rows The number of rows to pick. * @param num_picks The number of non-zeros to pick per row.
* @param in_rows The set of rows to pick. * @param num_rows The number of rows to pick.
* @param in_ptr The indptr array of the input CSR. * @param in_rows The set of rows to pick.
* @param in_cols The columns array of the input CSR. * @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR. * @param in_cols The columns array of the input CSR.
* @param prob The probability array of the input CSR. * @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO. * @param prob The probability array of the input CSR.
* @param cdf_ptr The offset of each cdf segment. * @param out_ptr The offset to write each row to in the output COO.
* @param cdf The global buffer to store cdf segments. * @param cdf_ptr The offset of each cdf segment.
* @param out_rows The rows of the output COO (output). * @param cdf The global buffer to store cdf segments.
* @param out_cols The columns of the output COO (output). * @param out_rows The rows of the output COO (output).
* @param out_idxs The data array of the output COO (output). * @param out_cols The columns of the output COO (output).
* @author pengqirong (OPPO) * @param out_idxs The data array of the output COO (output).
*/ * @author pengqirong (OPPO)
template<typename IdType, typename FloatType, int TILE_SIZE> */
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel( __global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed, const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
const int64_t num_picks, const IdType* const in_rows, const IdType* const in_ptr,
const int64_t num_rows, const IdType* const in_cols, const IdType* const data,
const IdType * const in_rows, const FloatType* const prob, const IdType* const out_ptr,
const IdType * const in_ptr, const IdType* const cdf_ptr, FloatType* const cdf, IdType* const out_rows,
const IdType * const in_cols, IdType* const out_cols, IdType* const out_idxs) {
const IdType * const data,
const FloatType * const prob,
const IdType * const out_ptr,
const IdType * const cdf_ptr,
FloatType * const cdf,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs
) {
// we assign one warp per row // we assign one warp per row
assert(blockDim.x == BLOCK_SIZE); assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE; int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows); const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng; curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng); curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
...@@ -357,12 +342,14 @@ __global__ void _CSRRowWiseSampleReplaceKernel( ...@@ -357,12 +342,14 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
// Load a segment of consecutive items that are blocked across threads // Load a segment of consecutive items that are blocked across threads
FloatType thread_data; FloatType thread_data;
if (idx < deg) if (idx < deg)
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &thread_data); _DoubleSlice<IdType, FloatType>(
prob, data, idx, in_row_start, &thread_data);
else else
thread_data = MIN_THREAD_DATA; thread_data = MIN_THREAD_DATA;
thread_data = max(thread_data, MIN_THREAD_DATA); thread_data = max(thread_data, MIN_THREAD_DATA);
// Collectively compute the block-wide inclusive prefix sum // Collectively compute the block-wide inclusive prefix sum
BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op); BlockScan(temp_storage)
.InclusiveSum(thread_data, thread_data, prefix_op);
__syncthreads(); __syncthreads();
// Store scanned items to cdf array // Store scanned items to cdf array
...@@ -376,7 +363,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel( ...@@ -376,7 +363,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
// get random value // get random value
FloatType sum = cdf[cdf_row_start + deg - 1]; FloatType sum = cdf[cdf_row_start + deg - 1];
FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum); FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum);
// get the offset of the first value within cdf array which is greater than random value. // get the offset of the first value within cdf array which is greater
// than random value.
int64_t item = cub::UpperBound<FloatType*, int64_t, FloatType>( int64_t item = cub::UpperBound<FloatType*, int64_t, FloatType>(
&cdf[cdf_row_start], deg, rand); &cdf[cdf_row_start], deg, rand);
item = min(item, deg - 1); item = min(item, deg - 1);
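      // A host-side sketch (illustrative only) of the inverse-CDF draw done
      // just above with cub::UpperBound: build the inclusive prefix sum of
      // the unnormalized weights, draw u ~ U(0, sum), and take the first
      // index whose CDF entry exceeds u, clamped to the last neighbor.
      //   // assuming <algorithm>, <numeric>, <random>, <vector>
      //   int64_t SampleOnceSketch(
      //       const std::vector<double>& w, std::mt19937_64* rng) {
      //     std::vector<double> cdf(w.size());
      //     std::partial_sum(w.begin(), w.end(), cdf.begin());  // unnormalized CDF
      //     const double u =
      //         std::uniform_real_distribution<double>(0.0, cdf.back())(*rng);
      //     const auto it = std::upper_bound(cdf.begin(), cdf.end(), u);
      //     return std::min<int64_t>(it - cdf.begin(),
      //                              static_cast<int64_t>(cdf.size()) - 1);
      //   }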
...@@ -395,7 +383,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel( ...@@ -395,7 +383,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
template <typename IdType, typename DType, typename BoolType> template <typename IdType, typename DType, typename BoolType>
__global__ void _GenerateFlagsKernel( __global__ void _GenerateFlagsKernel(
int64_t n, const IdType* idx, const DType* values, DType criteria, BoolType* output) { int64_t n, const IdType* idx, const DType* values, DType criteria,
BoolType* output) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < n) { while (tx < n) {
...@@ -413,8 +402,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) { ...@@ -413,8 +402,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const IdType* row = coo.row.Ptr<IdType>(); const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>(); const IdType* col = coo.col.Ptr<IdType>();
const IdArray& eid = COOHasData(coo) ? coo.data : Range( const IdArray& eid =
0, nnz, sizeof(IdType) * 8, ctx); COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);
const IdType* data = coo.data.Ptr<IdType>(); const IdType* data = coo.data.Ptr<IdType>();
IdArray new_row = IdArray::Empty({nnz}, idtype, ctx); IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_col = IdArray::Empty({nnz}, idtype, ctx); IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
...@@ -431,7 +420,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) { ...@@ -431,7 +420,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
maskgen(nb, nt, stream, nnz, data, flags); maskgen(nb, nt, stream, nnz, data, flags);
int64_t* rst = static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t))); int64_t* rst =
static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
MaskSelect(device, ctx, row, flags, new_row_data, nnz, rst, stream); MaskSelect(device, ctx, row, flags, new_row_data, nnz, rst, stream);
MaskSelect(device, ctx, col, flags, new_col_data, nnz, rst, stream); MaskSelect(device, ctx, col, flags, new_col_data, nnz, rst, stream);
MaskSelect(device, ctx, data, flags, new_eid_data, nnz, rst, stream); MaskSelect(device, ctx, data, flags, new_eid_data, nnz, rst, stream);
...@@ -441,24 +431,24 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) { ...@@ -441,24 +431,24 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
device->FreeWorkspace(ctx, flags); device->FreeWorkspace(ctx, flags);
device->FreeWorkspace(ctx, rst); device->FreeWorkspace(ctx, rst);
return COOMatrix( return COOMatrix(
coo.num_rows, coo.num_rows, coo.num_cols, new_row.CreateView({new_len}, idtype, 0),
coo.num_cols,
new_row.CreateView({new_len}, idtype, 0),
new_col.CreateView({new_len}, idtype, 0), new_col.CreateView({new_len}, idtype, 0),
new_eid.CreateView({new_len}, idtype, 0)); new_eid.CreateView({new_len}, idtype, 0));
} }
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criteria) { COOMatrix _COORemoveIf(
const COOMatrix& coo, const NDArray& values, DType criteria) {
const DType* val = values.Ptr<DType>(); const DType* val = values.Ptr<DType>();
auto maskgen = [val, criteria] ( auto maskgen = [val, criteria](
int nb, int nt, cudaStream_t stream, int64_t nnz, const IdType* data, int nb, int nt, cudaStream_t stream, int64_t nnz,
int8_t* flags) { const IdType* data, int8_t* flags) {
CUDA_KERNEL_CALL((_GenerateFlagsKernel<IdType, DType, int8_t>), CUDA_KERNEL_CALL(
nb, nt, 0, stream, (_GenerateFlagsKernel<IdType, DType, int8_t>), nb, nt, 0, stream, nnz,
nnz, data, val, criteria, flags); data, val, criteria, flags);
}; };
return COOGeneralRemoveIf<XPU, IdType, DType, decltype(maskgen)>(coo, maskgen); return COOGeneralRemoveIf<XPU, IdType, DType, decltype(maskgen)>(
coo, maskgen);
} }
} // namespace } // namespace
...@@ -466,42 +456,42 @@ COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criter ...@@ -466,42 +456,42 @@ COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criter
/////////////////////////////// CSR /////////////////////////////// /////////////////////////////// CSR ///////////////////////////////
/** /**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix. * @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO
* Use CDF sampling algorithm for with replacement: * matrix. Use CDF sampling algorithm for with replacement:
* 1) Calculate the CDF of all neighbor's prob. * 1) Calculate the CDF of all neighbor's prob.
* 2) For each [0, num_picks), generate a rand ~ U(0, 1). * 2) For each [0, num_picks), generate a rand ~ U(0, 1). Use binary search to
* Use binary search to find its index in the CDF array as a chosen item. * find its index in the CDF array as a chosen item.
* Use A-Res sampling algorithm for without replacement: * Use A-Res sampling algorithm for without replacement:
* 1) For rows with deg > num_picks, calculate A-Res values for all neighbors. * 1) For rows with deg > num_picks, calculate A-Res values for all neighbors.
* 2) Sort the A-Res array and select top-num_picks as chosen items. * 2) Sort the A-Res array and select top-num_picks as chosen items.
* *
* @tparam XPU The device type used for matrices. * @tparam XPU The device type used for matrices.
* @tparam IdType The ID type used for matrices. * @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices. * @tparam FloatType The Float type used for matrices.
* @param mat The CSR matrix. * @param mat The CSR matrix.
* @param rows The set of rows to pick. * @param rows The set of rows to pick.
* @param num_picks The number of non-zeros to pick per row. * @param num_picks The number of non-zeros to pick per row.
* @param prob The probability array of the input CSR. * @param prob The probability array of the input CSR.
* @param replace Is replacement sampling? * @param replace Is replacement sampling?
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia. * @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/ */
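// A host-side sketch (illustrative only, assuming <algorithm>, <cmath>,
// <numeric>, <random>, <vector>) of the A-Res selection described above:
// each neighbor with weight w draws u ~ U(0, 1), gets the key u^(1/w), and
// the num_picks largest keys are kept. The kernels below compute analogous
// keys on the GPU and use a segmented descending sort to take the top-k.
//
//   std::vector<int64_t> AResPickSketch(
//       const std::vector<double>& w, int64_t num_picks,
//       std::mt19937_64* rng) {
//     std::uniform_real_distribution<double> uni(0.0, 1.0);
//     std::vector<double> keys(w.size());
//     for (size_t i = 0; i < w.size(); ++i)
//       keys[i] = std::pow(uni(*rng), 1.0 / w[i]);  // A-Res key
//     std::vector<int64_t> idx(w.size());
//     std::iota(idx.begin(), idx.end(), 0);
//     std::sort(idx.begin(), idx.end(),
//               [&](int64_t a, int64_t b) { return keys[a] > keys[b]; });
//     if (static_cast<int64_t>(idx.size()) > num_picks) idx.resize(num_picks);
//     return idx;  // chosen neighbor offsets within the row
//   }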
template <DGLDeviceType XPU, typename IdType, typename FloatType> template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix _CSRRowWiseSampling( COOMatrix _CSRRowWiseSampling(
const CSRMatrix& mat, const CSRMatrix& mat, const IdArray& rows, int64_t num_picks,
const IdArray& rows, const FloatArray& prob, bool replace) {
int64_t num_picks,
const FloatArray& prob,
bool replace) {
const auto& ctx = rows->ctx; const auto& ctx = rows->ctx;
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_rows = rows->shape[0]; const int64_t num_rows = rows->shape[0];
const IdType * const slice_rows = static_cast<const IdType*>(rows->data); const IdType* const slice_rows = static_cast<const IdType*>(rows->data);
IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8); IdArray picked_row =
IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8); NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8); IdArray picked_col =
NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx =
NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdType* const out_rows = static_cast<IdType*>(picked_row->data); IdType* const out_rows = static_cast<IdType*>(picked_row->data);
IdType* const out_cols = static_cast<IdType*>(picked_col->data); IdType* const out_cols = static_cast<IdType*>(picked_col->data);
IdType* const out_idxs = static_cast<IdType*>(picked_idx->data); IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
...@@ -511,16 +501,12 @@ COOMatrix _CSRRowWiseSampling( ...@@ -511,16 +501,12 @@ COOMatrix _CSRRowWiseSampling(
const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr; const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
const FloatType* prob_data = prob.Ptr<FloatType>(); const FloatType* prob_data = prob.Ptr<FloatType>();
if (mat.is_pinned) { if (mat.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
&in_ptr, mat.indptr.Ptr<IdType>(), 0)); CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(
&in_cols, mat.indices.Ptr<IdType>(), 0));
if (CSRHasData(mat)) { if (CSRHasData(mat)) {
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
&data, mat.data.Ptr<IdType>(), 0));
} }
CUDA_CALL(cudaHostGetDevicePointer( CUDA_CALL(cudaHostGetDevicePointer(&prob_data, prob.Ptr<FloatType>(), 0));
&prob_data, prob.Ptr<FloatType>(), 0));
} }
// compute degree // compute degree
...@@ -528,41 +514,33 @@ COOMatrix _CSRRowWiseSampling( ...@@ -528,41 +514,33 @@ COOMatrix _CSRRowWiseSampling(
// temp_deg: the size of each row we will manipulate in sampling // temp_deg: the size of each row we will manipulate in sampling
// 1) for w/o replacement: in degree if it's greater than num_picks else 0 // 1) for w/o replacement: in degree if it's greater than num_picks else 0
// 2) for w/ replacement: in degree // 2) for w/ replacement: in degree
IdType * out_deg = static_cast<IdType*>( IdType* out_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType))); device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
IdType * temp_deg = static_cast<IdType*>( IdType* temp_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType))); device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
if (replace) { if (replace) {
const dim3 block(512); const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x); const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeReplaceKernel, _CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,
grid, block, 0, stream, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
} else { } else {
const dim3 block(512); const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x); const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeKernel, _CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,
grid, block, 0, stream, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
} }
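  // Per-row shape of what the two degree kernels above produce, written out
  // as a sketch (the out_deg behavior here is an assumption based on the
  // usual min(deg, num_picks) convention; temp_deg follows the comment above):
  //   if (replace) {
  //     out_deg[i]  = deg == 0 ? 0 : num_picks;   // non-empty rows emit num_picks
  //     temp_deg[i] = deg;                        // CDF segment covers all neighbors
  //   } else {
  //     out_deg[i]  = min(deg, num_picks);
  //     temp_deg[i] = deg > num_picks ? deg : 0;  // A-Res scratch only when deg > num_picks
  //   }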
// fill temp_ptr // fill temp_ptr
IdType * temp_ptr = static_cast<IdType*>( IdType* temp_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1)*sizeof(IdType))); device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
size_t prefix_temp_size = 0; size_t prefix_temp_size = 0;
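  // (The two ExclusiveSum calls below follow the standard CUB idiom: the
  // first call passes a null workspace pointer purely to obtain the required
  // temporary-storage size, and the second call with the allocated buffer
  // performs the actual exclusive prefix sum of temp_deg into temp_ptr. The
  // same pattern appears again for out_ptr and for the segmented sort.)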
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
temp_deg, nullptr, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));
temp_ptr, void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
num_rows + 1, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
stream)); prefix_temp, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));
void * prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
temp_deg,
temp_ptr,
num_rows + 1,
stream));
device->FreeWorkspace(ctx, prefix_temp); device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, temp_deg); device->FreeWorkspace(ctx, temp_deg);
...@@ -570,49 +548,40 @@ COOMatrix _CSRRowWiseSampling( ...@@ -570,49 +548,40 @@ COOMatrix _CSRRowWiseSampling(
// cuda events cannot be ignored. Just use synchronized copy. // cuda events cannot be ignored. Just use synchronized copy.
IdType temp_len; IdType temp_len;
// copy using the internal current stream. // copy using the internal current stream.
device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0, device->CopyDataFromTo(
sizeof(temp_len), temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0, sizeof(temp_len),
ctx, ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
device->StreamSync(ctx, stream); device->StreamSync(ctx, stream);
// fill out_ptr // fill out_ptr
IdType * out_ptr = static_cast<IdType*>( IdType* out_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType))); device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
prefix_temp_size = 0; prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
out_deg, nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
out_ptr,
num_rows+1,
stream));
prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size); prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size, CUDA_CALL(cub::DeviceScan::ExclusiveSum(
out_deg, prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
out_ptr,
num_rows+1,
stream));
device->FreeWorkspace(ctx, prefix_temp); device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, out_deg); device->FreeWorkspace(ctx, out_deg);
cudaEvent_t copyEvent; cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent)); CUDA_CALL(cudaEventCreate(&copyEvent));
// TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on // TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
// a cudaevent // wait on a cudaevent
IdType new_len; IdType new_len;
// copy using the internal current stream. // copy using the internal current stream.
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0, device->CopyDataFromTo(
sizeof(new_len), out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
CUDA_CALL(cudaEventRecord(copyEvent, stream)); CUDA_CALL(cudaEventRecord(copyEvent, stream));
// allocate workspace // allocate workspace
// 1) for w/ replacement, it's a global buffer to store cdf segments (one segment for each row). // 1) for w/ replacement, it's a global buffer to store cdf segments (one
// segment for each row).
// 2) for w/o replacement, it's used to store a-res segments (one segment for // 2) for w/o replacement, it's used to store a-res segments (one segment for
// each row with degree > num_picks) // each row with degree > num_picks)
FloatType * temp = static_cast<FloatType*>( FloatType* temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType))); device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000); const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
...@@ -624,21 +593,9 @@ COOMatrix _CSRRowWiseSampling( ...@@ -624,21 +593,9 @@ COOMatrix _CSRRowWiseSampling(
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE); const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
(_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>), (_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>), grid,
grid, block, 0, stream, block, 0, stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr,
rand_seed, in_cols, data, prob_data, out_ptr, temp_ptr, temp, out_rows, out_cols,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
prob_data,
out_ptr,
temp_ptr,
temp,
out_rows,
out_cols,
out_idxs); out_idxs);
device->FreeWorkspace(ctx, temp); device->FreeWorkspace(ctx, temp);
} else { // without replacement } else { // without replacement
...@@ -646,53 +603,33 @@ COOMatrix _CSRRowWiseSampling( ...@@ -646,53 +603,33 @@ COOMatrix _CSRRowWiseSampling(
device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType))); device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType)));
// Compute A-Res value. A-Res value needs to be calculated only if deg // Compute A-Res value. A-Res value needs to be calculated only if deg
// is greater than num_picks in weighted rowwise sampling without replacement. // is greater than num_picks in weighted rowwise sampling without
// replacement.
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE); const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
(_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>), (_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,
grid, block, 0, stream, stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr, data,
rand_seed, prob_data, temp_ptr, temp_idxs, temp);
num_picks,
num_rows,
slice_rows,
in_ptr,
data,
prob_data,
temp_ptr,
temp_idxs,
temp);
// sort A-Res value array. // sort A-Res value array.
FloatType* sort_temp = static_cast<FloatType*>( FloatType* sort_temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType))); device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
IdType* sort_temp_idxs = static_cast<IdType*>( IdType* sort_temp_idxs = static_cast<IdType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(IdType))); device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));
cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp); cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp);
cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs); cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs);
void *d_temp_storage = nullptr; void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending( CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage, d_temp_storage, temp_storage_bytes, sort_keys, sort_values, temp_len,
temp_storage_bytes, num_rows, temp_ptr, temp_ptr + 1, stream));
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1, stream));
d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes); d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending( CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage, d_temp_storage, temp_storage_bytes, sort_keys, sort_values, temp_len,
temp_storage_bytes, num_rows, temp_ptr, temp_ptr + 1, stream));
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1, stream));
device->FreeWorkspace(ctx, d_temp_storage); device->FreeWorkspace(ctx, d_temp_storage);
device->FreeWorkspace(ctx, temp); device->FreeWorkspace(ctx, temp);
device->FreeWorkspace(ctx, temp_idxs); device->FreeWorkspace(ctx, temp_idxs);
...@@ -701,20 +638,9 @@ COOMatrix _CSRRowWiseSampling( ...@@ -701,20 +638,9 @@ COOMatrix _CSRRowWiseSampling(
// select top-num_picks as results CUDA_KERNEL_CALL(
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
(_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>), (_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,
grid, block, 0, stream, stream, num_picks, num_rows, slice_rows, in_ptr, in_cols, data, out_ptr,
num_picks, temp_ptr, sort_values.Current(), out_rows, out_cols, out_idxs);
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
out_ptr,
temp_ptr,
sort_values.Current(),
out_rows,
out_cols,
out_idxs);
} }
device->FreeWorkspace(ctx, temp_ptr); device->FreeWorkspace(ctx, temp_ptr);
...@@ -728,44 +654,48 @@ COOMatrix _CSRRowWiseSampling( ...@@ -728,44 +654,48 @@ COOMatrix _CSRRowWiseSampling(
picked_col = picked_col.CreateView({new_len}, picked_col->dtype); picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype); picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx); return COOMatrix(
mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
} }
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling( COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob, bool replace) { CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,
bool replace) {
COOMatrix result; COOMatrix result;
if (num_picks == -1) { if (num_picks == -1) {
// Basically this is UnitGraph::InEdges(). // Basically this is UnitGraph::InEdges().
COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false); COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
IdArray sliced_rows = IndexSelect(rows, coo.row); IdArray sliced_rows = IndexSelect(rows, coo.row);
result = COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data); result =
COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
} else { } else {
result = _CSRRowWiseSampling<XPU, IdType, DType>(mat, rows, num_picks, prob, replace); result = _CSRRowWiseSampling<XPU, IdType, DType>(
mat, rows, num_picks, prob, replace);
} }
// NOTE(BarclayII): I'm removing the entries with zero probability after sampling. // NOTE(BarclayII): I'm removing the entries with zero probability after
// Is there a better way? // sampling. Is there a better way?
return _COORemoveIf<XPU, IdType, DType>(result, prob, static_cast<DType>(0)); return _COORemoveIf<XPU, IdType, DType>(result, prob, static_cast<DType>(0));
} }
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
// These are not being called, but we instantiate them anyway to prevent missing // These are not being called, but we instantiate them anyway to prevent missing
// symbols in Debug build // symbols in Debug build
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, int8_t>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, int8_t>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, uint8_t>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, uint8_t>( template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool); CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl } // namespace impl
} // namespace aten } // namespace aten
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
* @brief SDDMM C APIs and definitions. * @brief SDDMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh" #include "./functor.cuh"
#include "./sddmm.cuh"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
...@@ -14,110 +15,85 @@ namespace aten { ...@@ -14,110 +15,85 @@ namespace aten {
* @brief CUDA implementation of g-SDDMM on Csr format. * @brief CUDA implementation of g-SDDMM on Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op, void SDDMMCsr(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
const CSRMatrix& csr, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out); cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}); });
}); });
} }
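// A scalar reference sketch (illustrative only, ignoring broadcasting) of
// what one g-SDDMM edge computes: for a nonzero (u, v) with edge id e, the
// binary op combines the operand selected by lhs_target (src / edge / dst
// feature) with the one selected by rhs_target, reducing over reduce_size
// for the "dot" op, and writes the result to edge e. The function name and
// signature below are not part of DGL.
template <typename DType>
void SDDMMEdgeReferenceSketch(
    const DType* lhs_feat, const DType* rhs_feat, DType* out_edge,
    int64_t out_len, int64_t reduce_size, bool is_dot) {
  for (int64_t k = 0; k < out_len; ++k) {
    if (is_dot) {
      DType acc = static_cast<DType>(0);
      for (int64_t j = 0; j < reduce_size; ++j)
        acc += lhs_feat[k * reduce_size + j] * rhs_feat[k * reduce_size + j];
      out_edge[k] = acc;  // dot product along the reduce dimension
    } else {
      out_edge[k] = lhs_feat[k] * rhs_feat[k];  // e.g. the element-wise "mul" op
    }
  }
}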
/** /**
* @brief CUDA implementation of g-SDDMM on Coo format. * @brief CUDA implementation of g-SDDMM on Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op, void SDDMMCoo(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
const COOMatrix& coo, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out); cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}); });
}); });
} }
template void SDDMMCsr<kDGLCUDA, int32_t, __half>( template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>( template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
#if BF16_ENABLED #if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>( template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>( template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>( template void SDDMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, float>( template void SDDMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, double>( template void SDDMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, double>( template void SDDMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, __half>( template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>( template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
#if BF16_ENABLED #if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>( template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>( template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>( template void SDDMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, float>( template void SDDMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, double>( template void SDDMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, double>( template void SDDMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
int lhs_target, int rhs_target);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -7,15 +7,16 @@ ...@@ -7,15 +7,16 @@
#define DGL_ARRAY_CUDA_SDDMM_CUH_ #define DGL_ARRAY_CUDA_SDDMM_CUH_
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include "macro.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../selector.h"
#include "./functor.cuh"
#include "./utils.h"
#include "atomic.cuh" #include "atomic.cuh"
#include "functor.cuh"
#include "fp16.cuh"
#include "bf16.cuh" #include "bf16.cuh"
#include "./utils.h" #include "fp16.cuh"
#include "./functor.cuh" #include "functor.cuh"
#include "../selector.h" #include "macro.cuh"
#include "../../runtime/cuda/cuda_common.h"
namespace dgl { namespace dgl {
...@@ -24,64 +25,64 @@ using namespace cuda; ...@@ -24,64 +25,64 @@ using namespace cuda;
namespace aten { namespace aten {
namespace cuda { namespace cuda {
#define SWITCH_OP(op, Op, ...) \ #define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \ do { \
if ((rhs_target) == 0) { \ if ((op) == "add") { \
constexpr int RhsTarget = 0; \ typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \ } else if ((op) == "div") { \
constexpr int RhsTarget = 1; \ typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \ } else if ((op) == "copy_lhs") { \
constexpr int RhsTarget = 2; \ typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else { \ } else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \ LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \ } \
} while (0) } while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\ #define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \ do { \
if ((lhs_target) == 0) { \ if ((rhs_target) == 0) { \
constexpr int LhsTarget = 0; \ constexpr int RhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \ { __VA_ARGS__ } \
} else if ((lhs_target) == 1) { \ } else if ((rhs_target) == 1) { \
constexpr int LhsTarget = 1; \ constexpr int RhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \ { __VA_ARGS__ } \
} else if ((lhs_target) == 2) { \ } else if ((rhs_target) == 2) { \
constexpr int LhsTarget = 2; \ constexpr int RhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \ { __VA_ARGS__ } \
} else { \ } else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \ LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \ } \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0) } while (0)
constexpr unsigned int full_mask = 0xffffffff; constexpr unsigned int full_mask = 0xffffffff;
...@@ -90,23 +91,19 @@ constexpr unsigned int full_mask = 0xffffffff; ...@@ -90,23 +91,19 @@ constexpr unsigned int full_mask = 0xffffffff;
* @brief CUDA kernel of g-SDDMM on Coo format. * @brief CUDA kernel of g-SDDMM on Coo format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis) * @note it uses edge parallel strategy, different threadblocks (on y-axis)
* are responsible for the computation on different edges. Threadblocks * are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions * on the x-axis are responsible for the computation on different
* in feature dimension. * positions in feature dimension.
*/ */
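// Concretely, the kernel below realizes the mapping described above as:
//   ty = blockIdx.y * blockDim.y + threadIdx.y   // which edge (nonzero)
//   tx = blockIdx.x * blockDim.x + threadIdx.x   // which feature position
// with grid-stride loops advancing ty by blockDim.y * gridDim.y and tx by
// blockDim.x * gridDim.x, so thread (tx, ty) writes out[eid * out_len + tx].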
template <typename Idx, typename DType, typename BinaryOp, template <
bool UseBcast = false, bool UseIdx = false, typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
int LhsTarget = 0, int RhsTarget = 2> bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel( __global__ void SDDMMCooKernel(
const DType* __restrict__ lhs, const DType* __restrict__ lhs, const DType* __restrict__ rhs,
const DType* __restrict__ rhs, DType* __restrict__ out, const Idx* __restrict__ row,
DType* __restrict__ out, const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
const Idx* __restrict__ row, int64_t M, int64_t E, int64_t reduce_size,
const Idx* __restrict__ col, const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
const Idx* __restrict__ edge_map, int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
// SDDMM with COO. // SDDMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y; Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.y;
...@@ -114,10 +111,14 @@ __global__ void SDDMMCooKernel( ...@@ -114,10 +111,14 @@ __global__ void SDDMMCooKernel(
const Idx src = _ldg(row + ty); const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty); const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty; const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
const DType* lhsoff = BinaryOp::use_lhs ? const DType* lhsoff =
(lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr; BinaryOp::use_lhs
const DType* rhsoff = BinaryOp::use_rhs ? ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
(rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr; : nullptr;
const DType* rhsoff =
BinaryOp::use_rhs
? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
: nullptr;
DType* outoff = out + eid * out_len; DType* outoff = out + eid * out_len;
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = blockDim.x * gridDim.x; const int stride_x = blockDim.x * gridDim.x;
...@@ -125,8 +126,7 @@ __global__ void SDDMMCooKernel( ...@@ -125,8 +126,7 @@ __global__ void SDDMMCooKernel(
const Idx lhs_add = UseBcast ? lhs_off[tx] : tx; const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
const Idx rhs_add = UseBcast ? rhs_off[tx] : tx; const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
DType val = BinaryOp::Call( DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size, lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
reduce_size); reduce_size);
outoff[tx] = val; outoff[tx] = val;
tx += stride_x; tx += stride_x;
...@@ -136,56 +136,58 @@ __global__ void SDDMMCooKernel( ...@@ -136,56 +136,58 @@ __global__ void SDDMMCooKernel(
} }
/** /**
* @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree reduction. * @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree
* reduction.
* @note it uses edge parallel strategy, different threadblocks (on y-axis) * @note it uses edge parallel strategy, different threadblocks (on y-axis)
* are responsible for the computation on different edges. Threadblocks * are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions * on the x-axis are responsible for the computation on different
* in feature dimension. * positions in feature dimension.
*/ */
template <typename Idx, typename DType, template <
bool UseBcast = false, bool UseIdx = false, typename Idx, typename DType, bool UseBcast = false, bool UseIdx = false,
int LhsTarget = 0, int RhsTarget = 2> int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooTreeReduceKernel( __global__ void SDDMMCooTreeReduceKernel(
const DType* __restrict__ lhs, const DType* __restrict__ lhs, const DType* __restrict__ rhs,
const DType* __restrict__ rhs, DType* __restrict__ out, const Idx* __restrict__ row,
DType* __restrict__ out, const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
const Idx* __restrict__ row, int64_t M, int64_t E, int64_t reduce_size,
const Idx* __restrict__ col, const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
const Idx* __restrict__ edge_map, int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
Idx ty = blockIdx.x * blockDim.y + threadIdx.y; Idx ty = blockIdx.x * blockDim.y + threadIdx.y;
if (ty < E) { if (ty < E) {
const Idx src = _ldg(row + ty); const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty); const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty; const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
const DType* lhsoff = lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len; const DType* lhsoff =
const DType* rhsoff = rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len; lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
const DType* rhsoff =
rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
DType* outoff = out + eid * out_len; DType* outoff = out + eid * out_len;
int tx = threadIdx.x; // tx < 32 int tx = threadIdx.x; // tx < 32
for (int i = blockIdx.y; i < out_len; i += gridDim.y) { // over output feature dimension for (int i = blockIdx.y; i < out_len;
i += gridDim.y) { // over output feature dimension
const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i; const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i; const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
DType val = reduce::Sum<Idx, DType>::zero();; DType val = reduce::Sum<Idx, DType>::zero();
for (int j = tx; j < reduce_size; j += 64) { for (int j = tx; j < reduce_size; j += 64) {
val += lhsoff[lhs_add * reduce_size + j] * rhsoff[rhs_add * reduce_size + j]; val += lhsoff[lhs_add * reduce_size + j] *
rhsoff[rhs_add * reduce_size + j];
if (j + 32 < reduce_size) if (j + 32 < reduce_size)
val += lhsoff[lhs_add * reduce_size + j + 32] * rhsoff[rhs_add * reduce_size + j + 32]; val += lhsoff[lhs_add * reduce_size + j + 32] *
rhsoff[rhs_add * reduce_size + j + 32];
} }
#pragma unroll #pragma unroll
for (int offset = 16; offset > 0; offset /= 2) for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(full_mask, val, offset); val += __shfl_down_sync(full_mask, val, offset);
if (tx == 0) if (tx == 0) outoff[i] = val;
outoff[i] = val;
} }
} }
} }
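// A minimal sketch of the warp-level tree reduction used in the kernel
// above: after each lane accumulates a partial dot product, __shfl_down_sync
// folds the 32 lane values in log2(32) = 5 steps so lane 0 holds the full
// sum. The helper name is illustrative and not part of this file.
template <typename DType>
__device__ __forceinline__ DType WarpReduceSumSketch(DType val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(full_mask, val, offset);  // full_mask defined above
  return val;  // valid on lane 0 of the warp
}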
// Binary search the row_offsets to find the source node of the edge id. // Binary search the row_offsets to find the source node of the edge id.
template <typename Idx> template <typename Idx>
__device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) { __device__ __forceinline__ Idx
BinarySearchSrc(const Idx* array, Idx length, Idx eid) {
Idx lo = 0, hi = length - 1; Idx lo = 0, hi = length - 1;
while (lo < hi) { while (lo < hi) {
Idx mid = (lo + hi) >> 1; Idx mid = (lo + hi) >> 1;
...@@ -207,25 +209,21 @@ __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx ...@@ -207,25 +209,21 @@ __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx
* @brief CUDA kernel of g-SDDMM on Csr format. * @brief CUDA kernel of g-SDDMM on Csr format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis) * @note it uses edge parallel strategy, different threadblocks (on y-axis)
* are responsible for the computation on different edges. Threadblocks * are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions * on the x-axis are responsible for the computation on different
* in feature dimension. * positions in feature dimension. To efficiently find the source node idx and
* To efficiently find the source node idx and destination node index of a * destination node index of a given edge on Csr format, it uses binary search
* given edge on Csr format, it uses binary search (time complexity O(log N)). * (time complexity O(log N)).
*/ */
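// Worked example of the indptr binary search mentioned above: for
// indptr = {0, 2, 2, 5, 9}, edge id 4 falls in [indptr[2], indptr[3]) = [2, 5),
// so its source node is 2. A host-side equivalent (illustrative only,
// assuming <algorithm> and <vector>):
//   int64_t SrcOfEdgeSketch(const std::vector<int64_t>& indptr, int64_t eid) {
//     // first offset strictly greater than eid, minus one, is the source row
//     auto it = std::upper_bound(indptr.begin(), indptr.end(), eid);
//     return static_cast<int64_t>(it - indptr.begin()) - 1;
//   }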
template <typename Idx, typename DType, typename BinaryOp, template <
bool UseBcast = false, bool UseIdx = false, typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
int LhsTarget = 0, int RhsTarget = 2> bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel( __global__ void SDDMMCsrKernel(
const DType* __restrict__ lhs, const DType* __restrict__ lhs, const DType* __restrict__ rhs,
const DType* __restrict__ rhs, DType* __restrict__ out, const Idx* __restrict__ indptr,
DType* __restrict__ out, const Idx* __restrict__ indices, const Idx* __restrict__ edge_map,
const Idx* __restrict__ indptr, int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const Idx* __restrict__ indices, const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
const Idx* __restrict__ edge_map, int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
// SDDMM with Csr. // SDDMM with Csr.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y; Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.y;
...@@ -235,17 +233,20 @@ __global__ void SDDMMCsrKernel( ...@@ -235,17 +233,20 @@ __global__ void SDDMMCsrKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty; const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x; const int64_t stride_x = blockDim.x * gridDim.x;
const DType* lhsoff = BinaryOp::use_lhs ? const DType* lhsoff =
(lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr; BinaryOp::use_lhs
const DType* rhsoff = BinaryOp::use_rhs ? ? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
(rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr; : nullptr;
const DType* rhsoff =
BinaryOp::use_rhs
? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
: nullptr;
DType* outoff = out + eid * out_len; DType* outoff = out + eid * out_len;
while (tx < out_len) { while (tx < out_len) {
const Idx lhs_add = UseBcast ? lhs_off[tx] : tx; const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
const Idx rhs_add = UseBcast ? rhs_off[tx] : tx; const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
DType val = BinaryOp::Call( DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size, lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
reduce_size); reduce_size);
outoff[tx] = val; outoff[tx] = val;
tx += stride_x; tx += stride_x;
...@@ -262,26 +263,22 @@ __global__ void SDDMMCsrKernel( ...@@ -262,26 +263,22 @@ __global__ void SDDMMCsrKernel(
* @param rhs The right hand side operand feature. * @param rhs The right hand side operand feature.
* @param out The result feature on edges. * @param out The result feature on edges.
*/ */
template <typename Idx, typename DType, typename Op, template <
int LhsTarget = 0, int RhsTarget = 2> typename Idx, typename DType, typename Op, int LhsTarget = 0,
int RhsTarget = 2>
void SDDMMCoo( void SDDMMCoo(
const BcastOff& bcast, const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
const COOMatrix& coo,
NDArray lhs,
NDArray rhs,
NDArray out) { NDArray out) {
const Idx *row = coo.row.Ptr<Idx>(); const Idx* row = coo.row.Ptr<Idx>();
const Idx *col = coo.col.Ptr<Idx>(); const Idx* col = coo.col.Ptr<Idx>();
const Idx *edge_map = coo.data.Ptr<Idx>(); const Idx* edge_map = coo.data.Ptr<Idx>();
const DType *lhs_data = lhs.Ptr<DType>(); const DType* lhs_data = lhs.Ptr<DType>();
const DType *rhs_data = rhs.Ptr<DType>(); const DType* rhs_data = rhs.Ptr<DType>();
DType *out_data = out.Ptr<DType>(); DType* out_data = out.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *lhs_off = nullptr, *rhs_off = nullptr; int64_t *lhs_off = nullptr, *rhs_off = nullptr;
int64_t len = bcast.out_len, int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size; int64_t reduce_dim = bcast.reduce_size;
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
...@@ -296,13 +293,11 @@ void SDDMMCoo( ...@@ -296,13 +293,11 @@ void SDDMMCoo(
const dim3 nthrs(ntx, nty); const dim3 nthrs(ntx, nty);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, { BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
(SDDMMCooTreeReduceKernel<Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>), (SDDMMCooTreeReduceKernel<
nblks, nthrs, 0, stream, Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
lhs_data, rhs_data, out_data, nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
row, col, edge_map, edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
coo.num_rows, coo.num_cols, nnz, reduce_dim, rhs_off, lhs_len, rhs_len, len);
lhs_off, rhs_off,
lhs_len, rhs_len, len);
}); });
} else { } else {
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
...@@ -312,13 +307,12 @@ void SDDMMCoo( ...@@ -312,13 +307,12 @@ void SDDMMCoo(
const dim3 nblks(nbx, nby); const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty); const dim3 nthrs(ntx, nty);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, { BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL((SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>), CUDA_KERNEL_CALL(
nblks, nthrs, 0, stream, (SDDMMCooKernel<
lhs_data, rhs_data, out_data, Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
row, col, edge_map, nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
coo.num_rows, coo.num_cols, nnz, reduce_dim, edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
lhs_off, rhs_off, rhs_off, lhs_len, rhs_len, len);
lhs_len, rhs_len, len);
}); });
} }
} }
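For orientation, here is a minimal host-side sketch of how a 2D launch configuration like the ones above can be derived, with the x dimension covering the per-edge feature length and the y dimension covering the edges. The helper name find_num_threads, the 1024 threads-per-block budget, and the 65535 grid-y cap are assumptions for illustration, not DGL's actual constants.

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Hypothetical stand-in for a FindNumThreads-style helper: the largest power of
// two that does not exceed len, clamped to at most max_threads.
static int find_num_threads(int64_t len, int max_threads = 1024) {
  int nt = 1;
  while (nt * 2 <= std::min<int64_t>(len, max_threads)) nt *= 2;
  return nt;
}

// Derive a 2D configuration: x covers the feature length, y covers the edges.
static void make_launch_config(
    int64_t feat_len, int64_t num_edges, dim3* nblks, dim3* nthrs) {
  const int ntx = find_num_threads(feat_len);
  const int nty = 1024 / ntx;  // remaining thread budget goes to the y dimension
  const int nbx = static_cast<int>((feat_len + ntx - 1) / ntx);
  const int nby = static_cast<int>(
      std::min<int64_t>((num_edges + nty - 1) / nty, 65535));
  *nblks = dim3(nbx, nby);
  *nthrs = dim3(ntx, nty);
}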
...@@ -331,27 +325,23 @@ void SDDMMCoo( ...@@ -331,27 +325,23 @@ void SDDMMCoo(
 * @param rhs The right hand side operand feature.  * @param rhs The right hand side operand feature.
* @param out The result feature on edges. * @param out The result feature on edges.
*/ */
template <typename Idx, typename DType, typename Op, template <
int LhsTarget = 0, int RhsTarget = 2> typename Idx, typename DType, typename Op, int LhsTarget = 0,
int RhsTarget = 2>
void SDDMMCsr( void SDDMMCsr(
const BcastOff& bcast, const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out) { NDArray out) {
const Idx *indptr = csr.indptr.Ptr<Idx>(); const Idx* indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>(); const Idx* indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>(); const Idx* edge_map = csr.data.Ptr<Idx>();
const DType *lhs_data = lhs.Ptr<DType>(); const DType* lhs_data = lhs.Ptr<DType>();
const DType *rhs_data = rhs.Ptr<DType>(); const DType* rhs_data = rhs.Ptr<DType>();
DType *out_data = out.Ptr<DType>(); DType* out_data = out.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0]; int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
int64_t *lhs_off = nullptr, *rhs_off = nullptr; int64_t *lhs_off = nullptr, *rhs_off = nullptr;
int64_t len = bcast.out_len, int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size; int64_t reduce_dim = bcast.reduce_size;
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
...@@ -363,17 +353,14 @@ void SDDMMCsr( ...@@ -363,17 +353,14 @@ void SDDMMCsr(
const bool use_idx = !IsNullArray(csr.data); const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, { BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>), CUDA_KERNEL_CALL(
nblks, nthrs, 0, stream, (SDDMMCsrKernel<
lhs_data, rhs_data, out_data, Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
indptr, indices, edge_map, nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, indptr, indices,
N, M, E, reduce_dim, edge_map, N, M, E, reduce_dim, lhs_off, rhs_off, lhs_len, rhs_len, len);
lhs_off, rhs_off,
lhs_len, rhs_len, len);
}); });
} }
} // namespace cuda } // namespace cuda
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
......
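As a standalone illustration of the per-edge SDDMM pattern used by the kernels above (this is not the DGL kernel; all names here are hypothetical), the sketch below computes, for each COO edge (u, v), the dot product of row u of a left feature matrix with row v of a right feature matrix, one edge per thread with a grid-stride loop.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// out[e] = dot(lhs[row[e]], rhs[col[e]]) for every edge e (grid-stride loop).
__global__ void SddmmDotKernel(
    const float* lhs, const float* rhs, const int* row, const int* col,
    float* out, int64_t nnz, int64_t dim) {
  for (int64_t e = blockIdx.x * blockDim.x + threadIdx.x; e < nnz;
       e += static_cast<int64_t>(gridDim.x) * blockDim.x) {
    const float* l = lhs + row[e] * dim;
    const float* r = rhs + col[e] * dim;
    float acc = 0.f;
    for (int64_t k = 0; k < dim; ++k) acc += l[k] * r[k];
    out[e] = acc;
  }
}

int main() {
  const int64_t n = 3, dim = 4, nnz = 2;
  float *lhs, *rhs, *out;
  int *row, *col;
  cudaMallocManaged(&lhs, n * dim * sizeof(float));
  cudaMallocManaged(&rhs, n * dim * sizeof(float));
  cudaMallocManaged(&out, nnz * sizeof(float));
  cudaMallocManaged(&row, nnz * sizeof(int));
  cudaMallocManaged(&col, nnz * sizeof(int));
  for (int64_t i = 0; i < n * dim; ++i) {
    lhs[i] = 1.f;
    rhs[i] = 2.f;
  }
  row[0] = 0; col[0] = 1;  // edge 0: (0 -> 1)
  row[1] = 2; col[1] = 0;  // edge 1: (2 -> 0)
  SddmmDotKernel<<<1, 128>>>(lhs, rhs, row, col, out, nnz, dim);
  cudaDeviceSynchronize();
  printf("out = [%f, %f]\n", out[0], out[1]);  // both 8.0 (= 4 * 1 * 2)
  return 0;
}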
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
* @brief SDDMM C APIs and definitions. * @brief SDDMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "./sddmm.cuh" #include "./sddmm.cuh"
namespace dgl { namespace dgl {
...@@ -14,16 +15,12 @@ namespace aten { ...@@ -14,16 +15,12 @@ namespace aten {
 Coo format.  Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
std::vector<NDArray> vec_out, const std::vector<dgl_type_t>& rhs_eid) {
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */ /* Call SDDMM CUDA kernel for each relation type sequentially */
...@@ -33,7 +30,7 @@ void SDDMMCooHetero(const std::string& op, ...@@ -33,7 +30,7 @@ void SDDMMCooHetero(const std::string& op,
NDArray rhs = vec_rhs[rhs_eid[etype]]; NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype]; NDArray out = vec_out[etype];
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>( cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out); bcast, coo, lhs, rhs, out);
} }
}); });
}); });
...@@ -41,61 +38,53 @@ void SDDMMCooHetero(const std::string& op, ...@@ -41,61 +38,53 @@ void SDDMMCooHetero(const std::string& op,
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>( template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>( template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED #if BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>( template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>( template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>( template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, float>( template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, double>( template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, double>( template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
......
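The heterograph wrapper above amounts to a loop over relation types that picks each relation's operands via lhs_eid / rhs_eid and forwards them to the homogeneous SDDMM routine. A compressed host-side sketch of that dispatch shape, with made-up container types rather than the DGL ones:

#include <cstddef>
#include <functional>
#include <vector>

struct Coo {};     // placeholder for one relation's sparse structure
struct Tensor {};  // placeholder for one feature tensor

// For each edge type, select the lhs/rhs tensors through the eid maps and call
// the per-relation kernel sequentially.
void SddmmHeteroSketch(
    const std::vector<Coo>& coos, const std::vector<Tensor>& lhs,
    const std::vector<Tensor>& rhs, std::vector<Tensor>* out,
    const std::vector<size_t>& lhs_eid, const std::vector<size_t>& rhs_eid,
    const std::function<void(const Coo&, const Tensor&, const Tensor&, Tensor&)>&
        sddmm_one_relation) {
  for (size_t etype = 0; etype < coos.size(); ++etype) {
    sddmm_one_relation(
        coos[etype], lhs[lhs_eid[etype]], rhs[rhs_eid[etype]], (*out)[etype]);
  }
}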
...@@ -4,26 +4,22 @@ ...@@ -4,26 +4,22 @@
* @brief SDDMM C APIs and definitions. * @brief SDDMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "./sddmm.cuh" #include "./sddmm.cuh"
namespace dgl { namespace dgl {
namespace aten { namespace aten {
/** /**
* @brief CUDA implementation of g-SDDMM on heterograph using * @brief CUDA implementation of g-SDDMM on heterograph using Csr format.
Csr format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(
const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_lhs, const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
const std::vector<NDArray>& vec_rhs, int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
std::vector<NDArray> vec_out, const std::vector<dgl_type_t>& rhs_eid) {
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */ /* Call SDDMM CUDA kernel for each relation type sequentially */
...@@ -33,7 +29,7 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -33,7 +29,7 @@ void SDDMMCsrHetero(const std::string& op,
NDArray rhs = vec_rhs[rhs_eid[etype]]; NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype]; NDArray out = vec_out[etype];
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>( cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out); bcast, csr, lhs, rhs, out);
} }
}); });
}); });
...@@ -41,61 +37,53 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -41,61 +37,53 @@ void SDDMMCsrHetero(const std::string& op,
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>( template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED #if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>( template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>( template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>( template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> out, int lhs_target, int rhs_target, int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
} // namespace aten } // namespace aten
......
...@@ -5,24 +5,21 @@ ...@@ -5,24 +5,21 @@
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include "./segment_reduce.cuh"
#include "./functor.cuh" #include "./functor.cuh"
#include "./segment_reduce.cuh"
#include "./utils.h" #include "./utils.h"
namespace dgl { namespace dgl {
using namespace cuda; using namespace cuda;
namespace aten { namespace aten {
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op, void SegmentReduce(
NDArray feat, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray offsets, NDArray arg) {
NDArray out,
NDArray arg) {
if (op == "sum") { if (op == "sum") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>( cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
feat, offsets, out, arg); feat, offsets, out, arg);
...@@ -37,119 +34,70 @@ void SegmentReduce(const std::string& op, ...@@ -37,119 +34,70 @@ void SegmentReduce(const std::string& op,
} }
} }
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
NDArray idx,
NDArray out) {
cuda::ScatterAdd<IdType, DType>(feat, idx, out); cuda::ScatterAdd<IdType, DType>(feat, idx, out);
} }
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(
const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx, const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
const std::vector<NDArray>& idx_etype, cuda::UpdateGradMinMax_hetero<IdType, DType>(
std::vector<NDArray>* out) { g, op, feat, idx, idx_etype, out);
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
} }
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
NDArray arg,
NDArray out) {
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out); cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
} }
template void SegmentReduce<kDGLCUDA, int32_t, __half>( template void SegmentReduce<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __half>( template void SegmentReduce<kDGLCUDA, int64_t, __half>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
#if BF16_ENABLED #if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>( template void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>( template void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>( template void SegmentReduce<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>( template void SegmentReduce<kDGLCUDA, int64_t, float>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, double>( template void SegmentReduce<kDGLCUDA, int32_t, double>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, double>( template void SegmentReduce<kDGLCUDA, int64_t, double>(
const std::string &op, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg); NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, __half>( template void ScatterAdd<kDGLCUDA, int32_t, __half>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __half>( template void ScatterAdd<kDGLCUDA, int64_t, __half>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
#if BF16_ENABLED #if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>( template void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>( template void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>( template void ScatterAdd<kDGLCUDA, int32_t, float>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, float>( template void ScatterAdd<kDGLCUDA, int64_t, float>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>( template void ScatterAdd<kDGLCUDA, int32_t, double>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>( template void ScatterAdd<kDGLCUDA, int64_t, double>(
NDArray feat, NDArray feat, NDArray idx, NDArray out);
NDArray idx,
NDArray out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>( template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
...@@ -187,39 +135,23 @@ template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>( ...@@ -187,39 +135,23 @@ template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
#if BF16_ENABLED #if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
NDArray feat, NDArray feat, NDArray arg, NDArray out);
NDArray arg,
NDArray out);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
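SegmentReduce above collapses each contiguous segment of feat (delimited by offsets) into a single output row. The standalone sum-only sketch below illustrates that layout with one thread block per segment and one thread per feature column; the names are hypothetical, and the real DGL kernels additionally track arg indices for the min/max reducers.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// One block per segment; each thread owns a feature column and sums it over the
// rows offsets[seg] .. offsets[seg + 1] - 1.
__global__ void SegmentSumKernel(
    const float* feat, const int64_t* offsets, float* out, int64_t dim) {
  const int64_t seg = blockIdx.x;
  const int64_t begin = offsets[seg], end = offsets[seg + 1];
  for (int64_t j = threadIdx.x; j < dim; j += blockDim.x) {
    float acc = 0.f;
    for (int64_t i = begin; i < end; ++i) acc += feat[i * dim + j];
    out[seg * dim + j] = acc;
  }
}

int main() {
  const int64_t n = 5, dim = 2;
  const int num_seg = 2;
  float *feat, *out;
  int64_t* offsets;
  cudaMallocManaged(&feat, n * dim * sizeof(float));
  cudaMallocManaged(&out, num_seg * dim * sizeof(float));
  cudaMallocManaged(&offsets, (num_seg + 1) * sizeof(int64_t));
  for (int64_t i = 0; i < n * dim; ++i) feat[i] = 1.f;
  offsets[0] = 0; offsets[1] = 3; offsets[2] = 5;  // segment lengths 3 and 2
  SegmentSumKernel<<<num_seg, 64>>>(feat, offsets, out, dim);
  cudaDeviceSynchronize();
  printf("%.0f %.0f %.0f %.0f\n", out[0], out[1], out[2], out[3]);  // 3 3 2 2
  return 0;
}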
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* @brief SPMM C APIs and definitions. * @brief SPMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"
namespace dgl { namespace dgl {
...@@ -21,13 +22,10 @@ namespace aten { ...@@ -21,13 +22,10 @@ namespace aten {
* no broadcast, use dgl's kernel in other cases. * no broadcast, use dgl's kernel in other cases.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, std::vector<NDArray> out_aux) {
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0]; bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
bool use_efeat = op != "copy_lhs"; bool use_efeat = op != "copy_lhs";
...@@ -36,27 +34,22 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -36,27 +34,22 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) { if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse // cusparse
int64_t x_length = 1; int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
x_length *= ufeat->shape[i];
CusparseCsrmm2<DType, IdType>( CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
static_cast<DType*>(ufeat->data), static_cast<DType*>(out->data), x_length);
nullptr, } else if (
static_cast<DType*>(out->data), op == "mul" && is_scalar_efeat &&
x_length); cusparse_available<DType, IdType>(more_nnz)) {
} else if (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse // cusparse
int64_t x_length = 1; int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) { if (!IsNullArray(csr.data)) {
efeat = _IndexSelect<DType, IdType>(efeat, csr.data); efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
} }
CusparseCsrmm2<DType, IdType>( CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr, ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
static_cast<DType*>(ufeat->data), static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(out->data),
x_length); x_length);
} else { // general kernel } else { // general kernel
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
...@@ -79,31 +72,27 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -79,31 +72,27 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
} }
} }
/** /**
* @brief CUDA implementation of g-SpMM on Coo format. * @brief CUDA implementation of g-SpMM on Coo format.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, std::vector<NDArray> out_aux) {
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > ( cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, NullArray(), NullArray()); bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
}); });
} else if (reduce == "max") { } else if (reduce == "max") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > ( cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]); bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
}); });
} else if (reduce == "min") { } else if (reduce == "min") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > ( cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]); bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
}); });
} else { } else {
...@@ -112,74 +101,74 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -112,74 +101,74 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
} }
template void SpMMCsr<kDGLCUDA, int32_t, __half>( template void SpMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>( template void SpMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
#if BF16_ENABLED #if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>( template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>( template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>( template void SpMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>( template void SpMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>( template void SpMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>( template void SpMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const CSRMatrix& csr, const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, __half>( template void SpMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>( template void SpMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
#if BF16_ENABLED #if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>( template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>( template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>( template void SpMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>( template void SpMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>( template void SpMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>( template void SpMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const COOMatrix& coo, const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); std::vector<NDArray> out_aux);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
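The SpMMCsr entry point above is mostly a dispatcher: cuSPARSE is used for "copy_lhs" (unweighted aggregation) and for "mul" with a scalar edge feature whenever the cusparse_available heuristic allows it, and everything else falls through to the templated DGL kernel. A condensed restatement of that decision, with a hypothetical enum in place of the actual call sites:

#include <string>

enum class SpmmBackend { kCusparseAllOnes, kCusparseWeighted, kGenericKernel };

// Mirrors the branch structure of SpMMCsr: cuSPARSE only for "copy_lhs" or for
// scalar-weighted "mul", and only when the availability heuristic says yes.
SpmmBackend ChooseSpmmBackend(
    const std::string& op, bool is_scalar_efeat, bool cusparse_ok) {
  if (op == "copy_lhs" && cusparse_ok) return SpmmBackend::kCusparseAllOnes;
  if (op == "mul" && is_scalar_efeat && cusparse_ok)
    return SpmmBackend::kCusparseWeighted;
  return SpmmBackend::kGenericKernel;
}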
...@@ -7,13 +7,15 @@ ...@@ -7,13 +7,15 @@
#define DGL_ARRAY_CUDA_SPMM_CUH_ #define DGL_ARRAY_CUDA_SPMM_CUH_
#include <dgl/bcast.h> #include <dgl/bcast.h>
#include <limits> #include <limits>
#include "macro.cuh"
#include "fp16.cuh"
#include "bf16.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "macro.cuh"
namespace dgl { namespace dgl {
...@@ -32,9 +34,11 @@ inline bool cusparse_available(bool more_nnz_than_matrix_size) { ...@@ -32,9 +34,11 @@ inline bool cusparse_available(bool more_nnz_than_matrix_size) {
return true; return true;
return false; return false;
#else #else
if (std::is_same<DType, __half>::value || std::is_same<DType, __nv_bfloat16>::value) if (std::is_same<DType, __half>::value ||
std::is_same<DType, __nv_bfloat16>::value)
    return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.     return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1. // If the CSR matrix has more NNZ than matrix size, we should not use
// cuSPARSE 11.1.
return !more_nnz_than_matrix_size; return !more_nnz_than_matrix_size;
#endif #endif
} }
...@@ -43,21 +47,19 @@ namespace { ...@@ -43,21 +47,19 @@ namespace {
/** @brief Call cuBLAS geam API for transpose operation for float and double. */ /** @brief Call cuBLAS geam API for transpose operation for float and double. */
template <typename DType> template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam(
cublasOperation_t transb, int m, int n, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const DType* alpha, const DType* A, int lda, int m, int n, const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb, const DType* beta, const DType* B, int ldb, DType* C, int ldc) {
DType* C, int ldc) {
LOG(FATAL) << "Not supported dtype"; LOG(FATAL) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED; return CUBLAS_STATUS_EXECUTION_FAILED;
} }
template <> template <>
cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<__half>(
cublasOperation_t transb, int m, int n, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const __half* alpha, const __half* A, int lda, int m, int n, const __half* alpha, const __half* A, int lda,
const __half* beta, const __half* B, int ldb, const __half* beta, const __half* B, int ldb, __half* C, int ldc) {
__half* C, int ldc) {
// TODO(ndickson): There is no cublasHgeam, so a different // TODO(ndickson): There is no cublasHgeam, so a different
// implementation would be required. // implementation would be required.
LOG(FATAL) << "Xgeam does not support dtype half (FP16)"; LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
...@@ -66,9 +68,9 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa, ...@@ -66,9 +68,9 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
#if BF16_ENABLED #if BF16_ENABLED
template <> template <>
cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<__nv_bfloat16>(
cublasOperation_t transb, int m, int n, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda, int m, int n, const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb,
__nv_bfloat16* C, int ldc) { __nv_bfloat16* C, int ldc) {
// TODO(ndickson): There is no cublasHgeam, so a different // TODO(ndickson): There is no cublasHgeam, so a different
...@@ -79,23 +81,21 @@ cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t tra ...@@ -79,23 +81,21 @@ cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t tra
#endif // BF16_ENABLED #endif // BF16_ENABLED
template <> template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<float>(
cublasOperation_t transb, int m, int n, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const float* alpha, const float* A, int lda, int m, int n, const float* alpha, const float* A, int lda,
const float* beta, const float* B, int ldb, const float* beta, const float* B, int ldb, float* C, int ldc) {
float* C, int ldc) { return cublasSgeam(
return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
beta, B, ldb, C, ldc);
} }
template <> template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<double>(
cublasOperation_t transb, int m, int n, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
const double* alpha, const double* A, int lda, int m, int n, const double* alpha, const double* A, int lda,
const double* beta, const double* B, int ldb, const double* beta, const double* B, int ldb, double* C, int ldc) {
double* C, int ldc) { return cublasDgeam(
return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
beta, B, ldb, C, ldc);
} }
/** /**
...@@ -104,10 +104,8 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa, ...@@ -104,10 +104,8 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
*/ */
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _IndexSelectKernel( __global__ void _IndexSelectKernel(
const DType* __restrict__ in, const DType* __restrict__ in, const IdType* __restrict__ idx,
const IdType* __restrict__ idx, DType* __restrict__ out, int n, int m) {
DType* __restrict__ out,
int n, int m) {
int i = blockIdx.x; int i = blockIdx.x;
for (int j = threadIdx.x; j < m; j += blockDim.x) for (int j = threadIdx.x; j < m; j += blockDim.x)
out[i * m + j] = in[idx[i] * m + j]; out[i * m + j] = in[idx[i] * m + j];
...@@ -119,9 +117,7 @@ __global__ void _IndexSelectKernel( ...@@ -119,9 +117,7 @@ __global__ void _IndexSelectKernel(
*/ */
template <typename DType> template <typename DType>
__global__ void _TransposeKernel( __global__ void _TransposeKernel(
const DType* __restrict__ in, const DType* __restrict__ in, DType* __restrict__ out, int n, int m) {
DType* __restrict__ out,
int n, int m) {
int i = blockIdx.x; int i = blockIdx.x;
for (int j = threadIdx.x; j < m; j += blockDim.x) for (int j = threadIdx.x; j < m; j += blockDim.x)
out[i * m + j] = in[j * n + i]; out[i * m + j] = in[j * n + i];
...@@ -133,8 +129,7 @@ __global__ void _TransposeKernel( ...@@ -133,8 +129,7 @@ __global__ void _TransposeKernel(
* @param col number of columns of input matrix. * @param col number of columns of input matrix.
*/ */
template <typename DType> template <typename DType>
void _Transpose(const DType* in, DType* out, void _Transpose(const DType* in, DType* out, int row, int col) {
int row, int col) {
DType alpha = 1., beta = 0.; DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
...@@ -142,13 +137,8 @@ void _Transpose(const DType* in, DType* out, ...@@ -142,13 +137,8 @@ void _Transpose(const DType* in, DType* out,
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle))); CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream)); CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
CUBLAS_CALL(Xgeam<DType>( CUBLAS_CALL(Xgeam<DType>(
thr_entry->cublas_handle, thr_entry->cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, row, col, &alpha, in,
CUBLAS_OP_T, col, &beta, nullptr, row, out, row));
CUBLAS_OP_N,
row, col,
&alpha, in, col,
&beta, nullptr, row,
out, row));
} }
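The generic _Transpose above leans on cuBLAS geam, which computes C = alpha * op(A) + beta * op(B); with alpha = 1, beta = 0 and op(A) = A^T it degenerates into a pure transpose. A minimal float-only sketch against the public cuBLAS API (error handling omitted; this bypasses the cached handle shown above):

#include <cublas_v2.h>
#include <cuda_runtime.h>

// Transpose a row-major `row` x `col` matrix `in` into a row-major `col` x `row`
// matrix `out`. Viewed column-major, `in` is a col x row matrix with leading
// dimension `col`; requesting op(A) = A^T therefore writes the transpose.
void TransposeWithGeam(const float* in, float* out, int row, int col) {
  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.f, beta = 0.f;
  // With beta == 0 the B operand is not read, so nullptr is passed, matching
  // the call made inside _Transpose above.
  cublasSgeam(handle, CUBLAS_OP_T, CUBLAS_OP_N, row, col, &alpha, in, col,
              &beta, nullptr, row, out, row);
  cublasDestroy(handle);
}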
/** /**
...@@ -156,8 +146,7 @@ void _Transpose(const DType* in, DType* out, ...@@ -156,8 +146,7 @@ void _Transpose(const DType* in, DType* out,
* @note cuBLAS has no geam API for half data type, fallback to our kernel. * @note cuBLAS has no geam API for half data type, fallback to our kernel.
*/ */
template <> template <>
void _Transpose<half>(const half* in, half* out, void _Transpose<half>(const half* in, half* out, int row, int col) {
int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row); int nt = FindNumThreads(row);
int nb = col; int nb = col;
...@@ -170,8 +159,8 @@ void _Transpose<half>(const half* in, half* out, ...@@ -170,8 +159,8 @@ void _Transpose<half>(const half* in, half* out,
* @note cuBLAS has no geam API for bf16 data type, fallback to our kernel. * @note cuBLAS has no geam API for bf16 data type, fallback to our kernel.
*/ */
template <> template <>
void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out, void _Transpose<__nv_bfloat16>(
int row, int col) { const __nv_bfloat16* in, __nv_bfloat16* out, int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row); int nt = FindNumThreads(row);
int nb = col; int nb = col;
...@@ -183,8 +172,8 @@ void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out, ...@@ -183,8 +172,8 @@ void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out,
* @brief * @brief
*/ */
template <typename DType, typename IdType> template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(const DType* array, const IdType* index, __global__ void _IndexSelectKernel(
int64_t length, DType* out) { const DType* array, const IdType* index, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x; int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
...@@ -197,7 +186,7 @@ __global__ void _IndexSelectKernel(const DType* array, const IdType* index, ...@@ -197,7 +186,7 @@ __global__ void _IndexSelectKernel(const DType* array, const IdType* index,
* @note duplicate of IndexSelect defined in array_op.h but it can * @note duplicate of IndexSelect defined in array_op.h but it can
* not be applied to float16 dtype. * not be applied to float16 dtype.
*/ */
template<typename DType, typename IdType> template <typename DType, typename IdType>
NDArray _IndexSelect(NDArray array, NDArray index) { NDArray _IndexSelect(NDArray array, NDArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data); const DType* array_data = static_cast<DType*>(array->data);
...@@ -205,65 +194,65 @@ NDArray _IndexSelect(NDArray array, NDArray index) { ...@@ -205,65 +194,65 @@ NDArray _IndexSelect(NDArray array, NDArray index) {
const int64_t arr_len = array->shape[0]; const int64_t arr_len = array->shape[0];
const int64_t len = index->shape[0]; const int64_t len = index->shape[0];
NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx); NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
if (len == 0) if (len == 0) return ret;
return ret;
DType* ret_data = static_cast<DType*>(ret->data); DType* ret_data = static_cast<DType*>(ret->data);
const int nt = FindNumThreads(len); const int nt = FindNumThreads(len);
const int nb = (len + nt - 1) / nt; const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
array_data, idx_data, len, ret_data); _IndexSelectKernel, nb, nt, 0, stream, array_data, idx_data, len,
ret_data);
return ret; return ret;
} }
#if CUDART_VERSION < 11000 #if CUDART_VERSION < 11000
template <typename DType> template <typename DType>
cusparseStatus_t Xcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA, cusparseStatus_t Xcsrmm2(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz, cusparseOperation_t transB, int m, int n, int k, int nnz,
const DType* alpha, const cusparseMatDescr_t descrA, const DType* alpha, const cusparseMatDescr_t descrA, const DType* csrValA,
const DType* csrValA, const int* csrRowPtrA, const int* csrColIndA, const int* csrRowPtrA, const int* csrColIndA, const DType* B, int ldb,
const DType* B, int ldb, const DType* beta, DType* C, int ldc) { const DType* beta, DType* C, int ldc) {
LOG(INFO) << "Not supported dtype"; LOG(INFO) << "Not supported dtype";
return CUSPARSE_STATUS_EXECUTION_FAILED; return CUSPARSE_STATUS_EXECUTION_FAILED;
} }
template <> template <>
cusparseStatus_t Xcsrmm2<float>(cusparseHandle_t handle, cusparseOperation_t transA, cusparseStatus_t Xcsrmm2<float>(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz, cusparseOperation_t transB, int m, int n, int k, int nnz,
const float* alpha, const cusparseMatDescr_t descrA, const float* alpha, const cusparseMatDescr_t descrA, const float* csrValA,
const float* csrValA, const int* csrRowPtrA, const int* csrColIndA, const int* csrRowPtrA, const int* csrColIndA, const float* B, int ldb,
const float* B, int ldb, const float* beta, float* C, int ldc) { const float* beta, float* C, int ldc) {
return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz, return cusparseScsrmm2(
alpha, descrA, csrValA, csrRowPtrA, csrColIndA, handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
B, ldb, beta, C, ldc); csrColIndA, B, ldb, beta, C, ldc);
} }
template <> template <>
cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t transA, cusparseStatus_t Xcsrmm2<double>(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz, cusparseOperation_t transB, int m, int n, int k, int nnz,
const double* alpha, const cusparseMatDescr_t descrA, const double* alpha, const cusparseMatDescr_t descrA, const double* csrValA,
const double* csrValA, const int* csrRowPtrA, const int* csrColIndA, const int* csrRowPtrA, const int* csrColIndA, const double* B, int ldb,
const double* B, int ldb, const double* beta, double* C, int ldc) { const double* beta, double* C, int ldc) {
return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz, return cusparseDcsrmm2(
alpha, descrA, csrValA, csrRowPtrA, csrColIndA, handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
B, ldb, beta, C, ldc); csrColIndA, B, ldb, beta, C, ldc);
} }
#endif #endif
/** Cusparse implementation of SpMM on Csr format. */ /** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType> template <typename DType, typename IdType>
void CusparseCsrmm2( void CusparseCsrmm2(
const DGLContext& ctx, const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
const CSRMatrix& csr, const DType* A_data, DType* C_data, int x_length) {
const DType* B_data, const DType* A_data,
DType* C_data,
int x_length) {
// We use csrmm2 to perform following operation: // We use csrmm2 to perform following operation:
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
// feature tensor. However, since cusparse only supports column-major, while our tensor // for node feature tensor. However, since cusparse only supports
// is stored in row-major, the actual computation is: // column-major, while our tensor is stored in row-major, the actual
// C = trans(A x trans(B)). // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
// Currently, we use cublasXgeam to implement transposition and allocate intermediate // implement transposition and allocate intermediate workspace memory for
// workspace memory for this. // this.
const int m = csr.num_rows; const int m = csr.num_rows;
const int n = x_length; const int n = x_length;
const int k = csr.num_cols; const int k = csr.num_cols;
...@@ -282,7 +271,8 @@ void CusparseCsrmm2( ...@@ -282,7 +271,8 @@ void CusparseCsrmm2(
// all one data array // all one data array
DType* valptr = nullptr; DType* valptr = nullptr;
if (!A_data) { if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType))); valptr =
static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.)); _Fill(valptr, nnz, static_cast<DType>(1.));
} }
#if CUDART_VERSION >= 11000 #if CUDART_VERSION >= 11000
...@@ -290,34 +280,26 @@ void CusparseCsrmm2( ...@@ -290,34 +280,26 @@ void CusparseCsrmm2(
cusparseDnMatDescr_t matB, matC; cusparseDnMatDescr_t matB, matC;
constexpr auto dtype = cuda_dtype<DType>::value; constexpr auto dtype = cuda_dtype<DType>::value;
constexpr auto idtype = cusparse_idtype<IdType>::value; constexpr auto idtype = cusparse_idtype<IdType>::value;
CUSPARSE_CALL(cusparseCreateCsr(&matA, CUSPARSE_CALL(cusparseCreateCsr(
m, k, nnz, &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indices->data), static_cast<IdType*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data), const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype)); CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB, CUSPARSE_CALL(cusparseCreateDnMat(
k, n, n, &matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW)); CUSPARSE_CALL(
CUSPARSE_CALL(cusparseCreateDnMat(&matC, cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));
m, n, n,
C_data, dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE; auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE; auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size; size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize( CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
&alpha, matA, matB, &beta, matC, matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM( CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
&alpha, matA, matB, &beta, matC, matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroySpMat(matA)); CUSPARSE_CALL(cusparseDestroySpMat(matA));
...@@ -325,46 +307,40 @@ void CusparseCsrmm2( ...@@ -325,46 +307,40 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseDestroyDnMat(matC)); CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else #else
// allocate matrix for temporary transposed output // allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType))); DType* trans_out =
static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
cusparseMatDescr_t descr; cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr)); CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL)); CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CALL(Xcsrmm2<DType>( CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
CUSPARSE_OPERATION_TRANSPOSE, (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
m, n, k, nnz, &alpha, static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, trans_out,
descr, (valptr)? valptr : A_data, m));
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr)); CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
// transpose the output matrix // transpose the output matrix
_Transpose(trans_out, C_data, n, m); _Transpose(trans_out, C_data, n, m);
device->FreeWorkspace(ctx, trans_out); device->FreeWorkspace(ctx, trans_out);
#endif #endif
if (valptr) if (valptr) device->FreeWorkspace(ctx, valptr);
device->FreeWorkspace(ctx, valptr);
} }
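The legacy-cuSPARSE branch above leans on a dense-layout identity: a row-major k x n buffer with leading dimension n, reinterpreted as column-major, is exactly trans(B), so csrmm2 called with transB = CUSPARSE_OPERATION_TRANSPOSE effectively multiplies by B itself, and the column-major m x n product it writes is the row-major transpose of C, which _Transpose then flips back. Below is a minimal host-side sketch of that reasoning (a hypothetical reference, not DGL code; SpmmRowMajor and the toy matrices are made up for illustration).

#include <cassert>
#include <vector>

// Reference: C (row-major, m x n) = A (CSR, m x k) * B (row-major, k x n).
void SpmmRowMajor(int m, int n, const std::vector<int>& indptr,
                  const std::vector<int>& indices, const std::vector<double>& val,
                  const std::vector<double>& B, std::vector<double>* C) {
  for (int i = 0; i < m; ++i)
    for (int e = indptr[i]; e < indptr[i + 1]; ++e)
      for (int j = 0; j < n; ++j)
        (*C)[i * n + j] += val[e] * B[indices[e] * n + j];
}

int main() {
  // A = [[1, 0], [2, 3]] in CSR; B = [[1, 2, 3], [4, 5, 6]] stored row-major.
  const int m = 2, n = 3;
  std::vector<int> indptr = {0, 1, 3}, indices = {0, 0, 1};
  std::vector<double> val = {1, 2, 3};
  std::vector<double> B = {1, 2, 3, 4, 5, 6};

  std::vector<double> C_ref(m * n, 0.0);
  SpmmRowMajor(m, n, indptr, indices, val, B, &C_ref);

  // Emulate csrmm2(transB = TRANSPOSE): B's buffer read column-major is B^T,
  // op(B) = (B^T)^T = B, and the product is stored column-major with ldc = m.
  std::vector<double> C_colmajor(m * n, 0.0);
  for (int i = 0; i < m; ++i)
    for (int e = indptr[i]; e < indptr[i + 1]; ++e)
      for (int j = 0; j < n; ++j)
        C_colmajor[j * m + i] += val[e] * B[indices[e] * n + j];

  // The _Transpose step: reading the column-major result back as row-major
  // recovers the reference C.
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < n; ++j)
      assert(C_colmajor[j * m + i] == C_ref[i * n + j]);
  return 0;
}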
/** Cusparse implementation of SpMM on Csr format. */ /** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType> template <typename DType, typename IdType>
void CusparseCsrmm2Hetero( void CusparseCsrmm2Hetero(
const DGLContext& ctx, const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
const CSRMatrix& csr, const DType* A_data, DType* C_data, int64_t x_length,
const DType* B_data, const DType* A_data,
DType* C_data,
int64_t x_length,
cudaStream_t strm_id) { cudaStream_t strm_id) {
  // We use csrmm2 to perform the following operation:  // We use csrmm2 to perform the following operation:
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node // C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
// feature tensor. However, since cusparse only supports column-major, while our tensor // for node feature tensor. However, since cusparse only supports
// is stored in row-major, the actual computation is: // column-major, while our tensor is stored in row-major, the actual
// C = trans(A x trans(B)). // computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
// Currently, we use cublasXgeam to implement transposition and allocate intermediate // implement transposition and allocate intermediate workspace memory for
// workspace memory for this. // this.
int int_maxlimit = std::numeric_limits<int>::max(); int int_maxlimit = std::numeric_limits<int>::max();
CHECK_GE(int_maxlimit, (csr.num_rows)); CHECK_GE(int_maxlimit, (csr.num_rows));
CHECK_GE(int_maxlimit, csr.num_cols); CHECK_GE(int_maxlimit, csr.num_cols);
...@@ -386,7 +362,8 @@ void CusparseCsrmm2Hetero( ...@@ -386,7 +362,8 @@ void CusparseCsrmm2Hetero(
// all one data array // all one data array
DType* valptr = nullptr; DType* valptr = nullptr;
if (!A_data) { if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType))); valptr =
static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.)); _Fill(valptr, nnz, static_cast<DType>(1.));
} }
#if CUDART_VERSION >= 11000 #if CUDART_VERSION >= 11000
...@@ -394,34 +371,26 @@ void CusparseCsrmm2Hetero( ...@@ -394,34 +371,26 @@ void CusparseCsrmm2Hetero(
cusparseDnMatDescr_t matB, matC; cusparseDnMatDescr_t matB, matC;
constexpr auto dtype = cuda_dtype<DType>::value; constexpr auto dtype = cuda_dtype<DType>::value;
constexpr auto idtype = cusparse_idtype<IdType>::value; constexpr auto idtype = cusparse_idtype<IdType>::value;
CUSPARSE_CALL(cusparseCreateCsr(&matA, CUSPARSE_CALL(cusparseCreateCsr(
m, k, nnz, &matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indices->data), static_cast<IdType*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data), const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype)); CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB, CUSPARSE_CALL(cusparseCreateDnMat(
k, n, n, &matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW)); CUSPARSE_CALL(
CUSPARSE_CALL(cusparseCreateDnMat(&matC, cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));
m, n, n,
C_data, dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE; auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE; auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size; size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize( CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
&alpha, matA, matB, &beta, matC, matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM( CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB, thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
&alpha, matA, matB, &beta, matC, matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroySpMat(matA)); CUSPARSE_CALL(cusparseDestroySpMat(matA));
...@@ -434,74 +403,63 @@ void CusparseCsrmm2Hetero( ...@@ -434,74 +403,63 @@ void CusparseCsrmm2Hetero(
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO)); CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CHECK_EQ(sizeof(IdType), sizeof(int32_t)); CHECK_EQ(sizeof(IdType), sizeof(int32_t));
CUSPARSE_CALL(Xcsrmm2<DType>( CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle, thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_NON_TRANSPOSE, CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
CUSPARSE_OPERATION_TRANSPOSE, (valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
m, n, k, nnz, &alpha, static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, C_data, m));
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, C_data, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr)); CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif #endif
if (valptr) if (valptr) device->FreeWorkspace(ctx, valptr);
device->FreeWorkspace(ctx, valptr);
} }
} // namespace } // namespace
#define SWITCH_OP(op, Op, ...) \ #define SWITCH_OP(op, Op, ...) \
do { \ do { \
if ((op) == "add") { \ if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \ typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((op) == "sub") { \ } else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \ typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((op) == "mul") { \ } else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \ typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((op) == "div") { \ } else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \ typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \ } else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \ typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \ } else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \ typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \ { __VA_ARGS__ } \
} else { \ } else { \
LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \ LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
} \ } \
} while (0) } while (0)
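SWITCH_OP maps a runtime operator string onto a compile-time functor type: each branch introduces a typedef with the name passed as the macro's second argument and then expands the caller's block, so the same block is compiled once per supported operator. The following self-contained sketch (an assumed, simplified host-side analogue; SWITCH_OP_SKETCH and its tiny binary namespace are illustrations, not DGL code) shows the same dispatch pattern.

#include <iostream>
#include <string>

namespace binary {
template <typename T>
struct Add { static T Call(T a, T b) { return a + b; } };
template <typename T>
struct Mul { static T Call(T a, T b) { return a * b; } };
}  // namespace binary

// Bind `Op` to the functor matching `op`, then run the trailing block.
#define SWITCH_OP_SKETCH(op, Op, ...)                  \
  do {                                                 \
    if ((op) == "add") {                               \
      typedef binary::Add<double> Op;                  \
      { __VA_ARGS__ }                                  \
    } else if ((op) == "mul") {                        \
      typedef binary::Mul<double> Op;                  \
      { __VA_ARGS__ }                                  \
    } else {                                           \
      std::cerr << "Unsupported op: " << (op) << "\n"; \
    }                                                  \
  } while (0)

int main() {
  for (const std::string op : {"add", "mul"}) {
    SWITCH_OP_SKETCH(op, Op, {
      // Inside the block, Op is the selected functor type; real callers pass
      // it on to kernel templates such as SpMMCsr<..., Op, ...>.
      std::cout << op << "(2, 3) = " << Op::Call(2.0, 3.0) << "\n";
    });
  }
  return 0;
}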
namespace cuda { namespace cuda {
/** /**
* @brief CUDA kernel of g-SpMM on Coo format. * @brief CUDA kernel of g-SpMM on Coo format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis) * @note it uses edge parallel strategy, different threadblocks (on y-axis)
 * are responsible for the computation on different edges. Threadblocks * are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions * on the x-axis are responsible for the computation on different
* in feature dimension. * positions in feature dimension. To avoid possible data hazards, it uses
* To avoid possible data hazards, it uses atomic operators for reduction. * atomic operators for reduction.
*/ */
template <typename Idx, typename DType, template <
typename BinaryOp, typename ReduceOp, typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false> bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCooKernel( __global__ void SpMMCooKernel(
const DType* __restrict__ ufeat, const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
const DType* __restrict__ efeat, DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
DType* __restrict__ out, const Idx* __restrict__ row, const Idx* __restrict__ col,
Idx* __restrict__ arg_u, const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
Idx* __restrict__ arg_e, const int64_t* __restrict__ ubcast_off,
const Idx* __restrict__ row, const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
const Idx* __restrict__ col, int64_t efeat_len, int64_t out_len) {
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with COO. // SPMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y; Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.y;
...@@ -511,8 +469,8 @@ __global__ void SpMMCooKernel( ...@@ -511,8 +469,8 @@ __global__ void SpMMCooKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty; const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x; const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr; const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr; const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType* outoff = out + dst * out_len; DType* outoff = out + dst * out_len;
while (tx < out_len) { while (tx < out_len) {
const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx; const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
...@@ -531,25 +489,20 @@ __global__ void SpMMCooKernel( ...@@ -531,25 +489,20 @@ __global__ void SpMMCooKernel(
* @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format. * @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis) * @note it uses edge parallel strategy, different threadblocks (on y-axis)
 * are responsible for the computation on different edges. Threadblocks * are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions * on the x-axis are responsible for the computation on different
* in feature dimension. * positions in feature dimension.
*/ */
template <typename Idx, typename DType, template <
typename BinaryOp, typename ReduceOp, typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false> bool UseBcast = false, bool UseIdx = false>
__global__ void ArgSpMMCooKernel( __global__ void ArgSpMMCooKernel(
const DType* __restrict__ ufeat, const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
const DType* __restrict__ efeat, DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
DType* __restrict__ out, const Idx* __restrict__ row, const Idx* __restrict__ col,
Idx* __restrict__ arg_u, const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
Idx* __restrict__ arg_e, const int64_t* __restrict__ ubcast_off,
const Idx* __restrict__ row, const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
const Idx* __restrict__ col, int64_t efeat_len, int64_t out_len) {
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with COO arg max/min. // SPMM with COO arg max/min.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y; Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.y;
...@@ -559,11 +512,11 @@ __global__ void ArgSpMMCooKernel( ...@@ -559,11 +512,11 @@ __global__ void ArgSpMMCooKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty; const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x; int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x; const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr; const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr; const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
const DType* outoff = out + dst * out_len; const DType* outoff = out + dst * out_len;
Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len): nullptr; Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len) : nullptr;
Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len): nullptr; Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len) : nullptr;
while (tx < out_len) { while (tx < out_len) {
int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx; int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx; int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
...@@ -582,22 +535,17 @@ __global__ void ArgSpMMCooKernel( ...@@ -582,22 +535,17 @@ __global__ void ArgSpMMCooKernel(
* Threadblocks on the x-axis are responsible for the computation on * Threadblocks on the x-axis are responsible for the computation on
* different positions in feature dimension. * different positions in feature dimension.
*/ */
template <typename Idx, typename DType, template <
typename BinaryOp, typename ReduceOp, typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false> bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCsrKernel( __global__ void SpMMCsrKernel(
const DType* __restrict__ ufeat, const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
const DType* __restrict__ efeat, DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
DType* __restrict__ out, const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
Idx* __restrict__ arg_u, const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
Idx* __restrict__ arg_e, const int64_t* __restrict__ ubcast_off,
const Idx* __restrict__ indptr, const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
const Idx* __restrict__ indices, int64_t efeat_len, int64_t out_len) {
const Idx* __restrict__ edge_map,
int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
// SPMM with CSR. // SPMM with CSR.
int ty = blockIdx.x * blockDim.y + threadIdx.y; int ty = blockIdx.x * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.x; const Idx stride_y = blockDim.y * gridDim.x;
...@@ -612,16 +560,19 @@ __global__ void SpMMCsrKernel( ...@@ -612,16 +560,19 @@ __global__ void SpMMCsrKernel(
for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) { for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
const Idx eid = UseIdx ? _ldg(edge_map + i) : i; const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
const Idx cid = _ldg(indices + i); const Idx cid = _ldg(indices + i);
const DType* uoff = BinaryOp::use_lhs ? (ufeat + cid * ufeat_len): nullptr; const DType* uoff =
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr; BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
const DType* eoff =
BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add); DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid); ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);
} }
      // The use of += is to compute cross-type reduction on heterogeneous      // The use of += is to compute cross-type reduction on heterogeneous
      // graphs when the reduce op is `sum`:      // graphs when the reduce op is `sum`:
      // C = SpMM(SpA, B) + C      // C = SpMM(SpA, B) + C
      // A separate kernel, `SpMMCmpCsrHeteroKernel`, is used for the max and      // A separate kernel, `SpMMCmpCsrHeteroKernel`, is used for the max and
      // min reducers. It does not affect the output on homogeneous graphs as      // min reducers. It does not affect the output on homogeneous graphs as
      // `out` is initialized to zero.      // `out` is initialized to zero.
out[ty * out_len + tx] += local_accum; out[ty * out_len + tx] += local_accum;
if (ReduceOp::require_arg && BinaryOp::use_lhs) if (ReduceOp::require_arg && BinaryOp::use_lhs)
arg_u[ty * out_len + tx] = local_argu; arg_u[ty * out_len + tx] = local_argu;
...@@ -640,23 +591,18 @@ __global__ void SpMMCsrKernel( ...@@ -640,23 +591,18 @@ __global__ void SpMMCsrKernel(
* Threadblocks on the x-axis are responsible for the computation on * Threadblocks on the x-axis are responsible for the computation on
* different positions in feature dimension. * different positions in feature dimension.
*/ */
template <typename Idx, typename DType, template <
typename BinaryOp, typename ReduceOp, typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false> bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCmpCsrHeteroKernel( __global__ void SpMMCmpCsrHeteroKernel(
const DType* __restrict__ ufeat, const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
const DType* __restrict__ efeat, DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
DType* __restrict__ out, Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,
Idx* __restrict__ arg_u, Idx* __restrict__ arg_e, const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype, const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
const Idx* __restrict__ indptr, const int64_t* __restrict__ ubcast_off,
const Idx* __restrict__ indices, const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
const Idx* __restrict__ edge_map, int64_t efeat_len, int64_t out_len, const int src_type, const int etype) {
int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len,
const int src_type, const int etype) {
// SPMM with CSR. // SPMM with CSR.
int ty = blockIdx.y * blockDim.y + threadIdx.y; int ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y; const Idx stride_y = blockDim.y * gridDim.y;
...@@ -671,12 +617,15 @@ __global__ void SpMMCmpCsrHeteroKernel( ...@@ -671,12 +617,15 @@ __global__ void SpMMCmpCsrHeteroKernel(
for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) { for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
const Idx eid = UseIdx ? _ldg(edge_map + i) : i; const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
const Idx cid = _ldg(indices + i); const Idx cid = _ldg(indices + i);
const DType* uoff = BinaryOp::use_lhs ? (ufeat + cid * ufeat_len): nullptr; const DType* uoff =
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr; BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
const DType* eoff =
BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add); DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&new_out, &local_argu, &local_arge, tmp_out, cid, eid); ReduceOp::Call(&new_out, &local_argu, &local_arge, tmp_out, cid, eid);
} }
      // Update output only when max/min values are different from the original output      // Update output only when max/min values are different from the
      // original output
if (out[ty * out_len + tx] != new_out) { if (out[ty * out_len + tx] != new_out) {
out[ty * out_len + tx] = new_out; out[ty * out_len + tx] = new_out;
if (ReduceOp::require_arg && BinaryOp::use_lhs) { if (ReduceOp::require_arg && BinaryOp::use_lhs) {
...@@ -703,17 +652,16 @@ __global__ void SpMMCmpCsrHeteroKernel( ...@@ -703,17 +652,16 @@ __global__ void SpMMCmpCsrHeteroKernel(
* @param out The result feature on destination nodes. * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node * @param argu Arg-Min/Max on source nodes, which refers to the source node
 * indices corresponding to the minimum/maximum values of the reduction * indices corresponding to the minimum/maximum values of the reduction
 * result on destination nodes. It's useful in computing gradients of the * result on destination nodes. It's useful in computing gradients of the
 * Min/Max reducer. * Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the source node indices * @param arge Arg-Min/Max on edges, which refers to the source node indices
 * corresponding to the minimum/maximum values of the reduction result on * corresponding to the minimum/maximum values of the reduction result on
 * destination nodes. It's useful in computing gradients of the Min/Max * destination nodes. It's useful in computing gradients of the Min/Max
 * reducer. * reducer.
*/ */
template <typename Idx, typename DType, template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
typename BinaryOp, typename ReduceOp>
void SpMMCoo( void SpMMCoo(
const BcastOff& bcast, const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
const COOMatrix& coo,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) { NDArray out, NDArray argu, NDArray arge) {
#if defined(CUDART_VERSION) && CUDART_VERSION <= 10000 #if defined(CUDART_VERSION) && CUDART_VERSION <= 10000
if (std::is_same<DType, half>::value) if (std::is_same<DType, half>::value)
...@@ -721,27 +669,23 @@ void SpMMCoo( ...@@ -721,27 +669,23 @@ void SpMMCoo(
<< "for float16 in CUDA 10.0. Please upgrade your CUDA " << "for float16 in CUDA 10.0. Please upgrade your CUDA "
<< "to later versions."; << "to later versions.";
#endif #endif
const Idx *row = coo.row.Ptr<Idx>(), const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),
*col = coo.col.Ptr<Idx>(),
*edge_map = coo.data.Ptr<Idx>(); *edge_map = coo.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>(), const DType *ufeat_data = ufeat.Ptr<DType>(),
*efeat_data = efeat.Ptr<DType>(); *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>(); DType* out_data = out.Ptr<DType>();
Idx *argu_data = argu.Ptr<Idx>(), Idx *argu_data = argu.Ptr<Idx>(), *arge_data = arge.Ptr<Idx>();
*arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0]; const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr; int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len, int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t out_size = out.NumElements(); int64_t out_size = out.NumElements();
const int nt = FindNumThreads(out_size); const int nt = FindNumThreads(out_size);
const int nb = (out_size + nt - 1) / nt; const int nb = (out_size + nt - 1) / nt;
CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, stream, CUDA_KERNEL_CALL(
out_data, out_size, ReduceOp::zero()); _FillKernel, nb, nt, 0, stream, out_data, out_size, ReduceOp::zero());
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx; const int nty = CUDA_MAX_NUM_THREADS / ntx;
...@@ -752,20 +696,16 @@ void SpMMCoo( ...@@ -752,20 +696,16 @@ void SpMMCoo(
const bool use_idx = !IsNullArray(coo.data); const bool use_idx = !IsNullArray(coo.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, { BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>), CUDA_KERNEL_CALL(
nblks, nthrs, 0, stream, (SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
ufeat_data, efeat_data, out_data, argu_data, arge_data, nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
row, col, edge_map, arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off, lhs_len,
N, M, E, rhs_len, len);
ubcast_off, ebcast_off,
lhs_len, rhs_len, len);
if (ReduceOp::require_arg) { if (ReduceOp::require_arg) {
CUDA_KERNEL_CALL((ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>), CUDA_KERNEL_CALL(
nblks, nthrs, 0, stream, (ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
ufeat_data, efeat_data, out_data, argu_data, arge_data, nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
row, col, edge_map, arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off,
N, M, E,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len); lhs_len, rhs_len, len);
} }
}); });
...@@ -780,33 +720,30 @@ void SpMMCoo( ...@@ -780,33 +720,30 @@ void SpMMCoo(
* @param out The result feature on destination nodes. * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node * @param argu Arg-Min/Max on source nodes, which refers to the source node
 * indices corresponding to the minimum/maximum values of the reduction * indices corresponding to the minimum/maximum values of the reduction
 * result on destination nodes. It's useful in computing gradients of the * result on destination nodes. It's useful in computing gradients of the
 * Min/Max reducer. * Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the source node indices * @param arge Arg-Min/Max on edges, which refers to the source node indices
 * corresponding to the minimum/maximum values of the reduction result on * corresponding to the minimum/maximum values of the reduction result on
 * destination nodes. It's useful in computing gradients of the Min/Max * destination nodes. It's useful in computing gradients of the Min/Max
 * reducer. * reducer.
*/ */
template <typename Idx, typename DType, template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
typename BinaryOp, typename ReduceOp>
void SpMMCsr( void SpMMCsr(
const BcastOff& bcast, const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) { NDArray out, NDArray argu, NDArray arge) {
const Idx *indptr = csr.indptr.Ptr<Idx>(); const Idx* indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>(); const Idx* indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>(); const Idx* edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>(); const DType* ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>(); const DType* efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>(); DType* out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>(); Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>(); Idx* arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr; int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len, int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx; const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nby = (len + ntx - 1) / ntx; const int nby = (len + ntx - 1) / ntx;
...@@ -815,15 +752,13 @@ void SpMMCsr( ...@@ -815,15 +752,13 @@ void SpMMCsr(
const dim3 nthrs(ntx, nty); const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data); const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, { BCAST_IDX_CTX_SWITCH(
CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>), bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
nblks, nthrs, 0, stream, {CUDA_KERNEL_CALL(
ufeat_data, efeat_data, out_data, argu_data, arge_data, (SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
indptr, indices, edge_map, nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
csr.num_rows, csr.num_cols, arge_data, indptr, indices, edge_map, csr.num_rows, csr.num_cols,
ubcast_off, ebcast_off, ubcast_off, ebcast_off, lhs_len, rhs_len, len)});
lhs_len, rhs_len, len)
});
} }
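The argu/arge buffers described in the comment above record, for each output position, which source node and which edge produced the winning value of a max/min reduction, which is what the backward pass of the Min/Max reducer needs. A minimal host-side sketch of that bookkeeping for a single CSR row with a scalar feature (hypothetical code, not DGL; copy_lhs message plus max reduction):

#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

int main() {
  // One destination row with three incoming edges (source ids and edge ids).
  std::vector<int64_t> src = {4, 7, 2};
  std::vector<int64_t> eid = {10, 11, 12};
  std::vector<double> ufeat(8, 0.0);  // scalar feature per source node
  ufeat[4] = 0.3; ufeat[7] = 0.9; ufeat[2] = 0.1;

  double acc = -std::numeric_limits<double>::infinity();  // ReduceOp::zero() for max
  int64_t argu = -1, arge = -1;
  for (size_t i = 0; i < src.size(); ++i) {
    const double msg = ufeat[src[i]];  // copy_lhs: the message is the source feature
    if (msg > acc) {                   // max reduction, tracking its arg indices
      acc = msg;
      argu = src[i];   // source node that holds the maximum
      arge = eid[i];   // edge that carried the maximum
    }
  }
  std::cout << "max=" << acc << " argu=" << argu << " arge=" << arge << "\n";
  // Prints: max=0.9 argu=7 arge=11
  return 0;
}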
/** /**
...@@ -835,43 +770,41 @@ void SpMMCsr( ...@@ -835,43 +770,41 @@ void SpMMCsr(
* @param out The result feature on destination nodes. * @param out The result feature on destination nodes.
 * @param argu Arg-Min/Max on source nodes, which refers to the source node * @param argu Arg-Min/Max on source nodes, which refers to the source node
 * indices corresponding to the minimum/maximum values of the reduction * indices corresponding to the minimum/maximum values of the reduction
 * result on destination nodes. It's useful in computing gradients of the * result on destination nodes. It's useful in computing gradients of the
 * Min/Max reducer. * Min/Max reducer.
 * @param arge Arg-Min/Max on edges, which refers to the source node indices * @param arge Arg-Min/Max on edges, which refers to the source node indices
 * corresponding to the minimum/maximum values of the reduction result on * corresponding to the minimum/maximum values of the reduction result on
 * destination nodes. It's useful in computing gradients of the Min/Max * destination nodes. It's useful in computing gradients of the Min/Max
 * reducer. * reducer.
 * @param argu_ntype Node type of the arg-Min/Max on source nodes, which * @param argu_ntype Node type of the arg-Min/Max on source nodes, which
 * refers to the source node types corresponding to the minimum/maximum * refers to the source node types corresponding to the minimum/maximum
 * values of the reduction result on destination nodes. It's useful in * values of the reduction result on destination nodes. It's useful in
 * computing gradients of the Min/Max reducer. * computing gradients of the Min/Max reducer.
 * @param arge_etype Edge type of the arg-Min/Max on edges, which refers to * @param arge_etype Edge type of the arg-Min/Max on edges, which refers to
 * the source node indices corresponding to the minimum/maximum values of * the source node indices corresponding to the minimum/maximum values of
 * the reduction result on destination nodes. It's useful in computing * the reduction result on destination nodes. It's useful in computing
 * gradients of the Min/Max reducer. * gradients of the Min/Max reducer.
* @param src_type Node type of the source nodes of an etype * @param src_type Node type of the source nodes of an etype
* @param etype Edge type * @param etype Edge type
*/ */
template <typename Idx, typename DType, template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
typename BinaryOp, typename ReduceOp>
void SpMMCmpCsrHetero( void SpMMCmpCsrHetero(
const BcastOff& bcast, const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
const CSRMatrix& csr, NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
NDArray ufeat, NDArray efeat, NDArray arge_etype, const int src_type, const int etype) {
NDArray out, NDArray argu, NDArray arge, const Idx* indptr = csr.indptr.Ptr<Idx>();
NDArray argu_ntype, NDArray arge_etype, const Idx* indices = csr.indices.Ptr<Idx>();
const int src_type, const int etype) { const Idx* edge_map = csr.data.Ptr<Idx>();
const Idx *indptr = csr.indptr.Ptr<Idx>(); const DType* ufeat_data = ufeat.Ptr<DType>();
const Idx *indices = csr.indices.Ptr<Idx>(); const DType* efeat_data = efeat.Ptr<DType>();
const Idx *edge_map = csr.data.Ptr<Idx>(); DType* out_data = out.Ptr<DType>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>(); Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>(); Idx* arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr; int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len, int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len); const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx; const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx; const int nbx = (len + ntx - 1) / ntx;
...@@ -880,20 +813,18 @@ void SpMMCmpCsrHetero( ...@@ -880,20 +813,18 @@ void SpMMCmpCsrHetero(
const dim3 nthrs(ntx, nty); const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data); const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, { BCAST_IDX_CTX_SWITCH(
CUDA_KERNEL_CALL((SpMMCmpCsrHeteroKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>), bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
nblks, nthrs, 0, stream, {CUDA_KERNEL_CALL(
ufeat_data, efeat_data, out_data, argu_data, arge_data, (SpMMCmpCsrHeteroKernel<
static_cast<Idx*>(argu_ntype->data), Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
static_cast<Idx*>(arge_etype->data), nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
indptr, indices, edge_map, arge_data, static_cast<Idx*>(argu_ntype->data),
csr.num_rows, csr.num_cols, static_cast<Idx*>(arge_etype->data), indptr, indices, edge_map,
ubcast_off, ebcast_off, csr.num_rows, csr.num_cols, ubcast_off, ebcast_off, lhs_len, rhs_len,
lhs_len, rhs_len, len, src_type, etype) len, src_type, etype)});
});
} }
} // namespace cuda } // namespace cuda
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -4,10 +4,11 @@ ...@@ -4,10 +4,11 @@
* @brief SPMM C APIs and definitions. * @brief SPMM C APIs and definitions.
*/ */
#include <dgl/array.h> #include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"
namespace dgl { namespace dgl {
...@@ -21,16 +22,16 @@ namespace aten { ...@@ -21,16 +22,16 @@ namespace aten {
* no broadcast, use dgl's kernel in other cases. * no broadcast, use dgl's kernel in other cases.
*/ */
template <int XPU, typename IdType, typename DType> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(
const BcastOff& bcast, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat, const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat, const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
std::vector<NDArray>* vec_out, std::vector<std::vector<NDArray>>* out_aux,
std::vector<std::vector<NDArray>>* out_aux, const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id const std::vector<dgl_type_t>& out_ntids) { // output node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id bool is_scalar_efeat =
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0]; vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs"; bool use_efeat = op != "copy_lhs";
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx); auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
std::vector<DType*> trans_out((*vec_out).size(), NULL); std::vector<DType*> trans_out((*vec_out).size(), NULL);
...@@ -39,15 +40,16 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -39,15 +40,16 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
(CUDART_VERSION < 11000) && (reduce == "sum") && (CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false". // legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) || ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(false))); (op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(false)));
// Create temporary output buffer to store non-transposed output // Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) { if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) { for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0]; const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1]; const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue; if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx, DType* out = static_cast<DType*>(device->AllocWorkspace(
m * n * sizeof(DType))); vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType))); CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out; trans_out[ntype] = out;
} }
...@@ -57,43 +59,53 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -57,43 +59,53 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) { for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]]; NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]]; NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes"; CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
<< "Input features have different shapes";
for (int i = 1; i < ufeat->ndim; ++i) { for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] != next_ufeat->shape[i]) { if (ufeat->shape[i] != next_ufeat->shape[i]) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1) if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
LOG(FATAL) << LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
"Homogenized message passing on heterogeneous graphs does not support " << "does not support "
"automatic broadcasting. Please manually broadcast it before calling " << << "automatic broadcasting. Please manually broadcast it "
"message passing functions."; "before calling "
<< "message passing functions.";
else else
LOG(FATAL) << "Input features have different shapes."; LOG(FATAL) << "Input features have different shapes.";
return; return;
} }
if (etype == 0) if (etype == 0) x_length *= ufeat->shape[i];
x_length *= ufeat->shape[i];
} }
} }
// TODO(Israt): Can python do the following initializations while creating the tensors? // TODO(Israt): Can python do the following initializations while creating the
if (reduce == "max" || reduce == "min") { // tensors?
if (reduce == "max" || reduce == "min") {
const int64_t dim = bcast.out_len; const int64_t dim = bcast.out_len;
std::vector<bool> updated((*vec_out).size(), false); std::vector<bool> updated((*vec_out).size(), false);
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) { for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>(); DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
if (reduce == "max") if (reduce == "max")
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero()); _Fill(
out_off, vec_csr[etype].num_rows * dim,
cuda::reduce::Max<IdType, DType>::zero());
else // min else // min
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero()); _Fill(
out_off, vec_csr[etype].num_rows * dim,
cuda::reduce::Min<IdType, DType>::zero());
const dgl_type_t dst_id = out_ntids[etype]; const dgl_type_t dst_id = out_ntids[etype];
if (!updated[dst_id]) { if (!updated[dst_id]) {
updated[dst_id] = true; updated[dst_id] = true;
if (op == "copy_lhs") { if (op == "copy_lhs") {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>(); IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1)); _Fill(
argu_ntype, vec_csr[etype].num_rows * dim,
static_cast<IdType>(-1));
} }
if (op == "copy_rhs") { if (op == "copy_rhs") {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>(); IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1)); _Fill(
arge_etype, vec_csr[etype].num_rows * dim,
static_cast<IdType>(-1));
} }
} }
} }
...@@ -106,60 +118,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -106,60 +118,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
CSRMatrix csr = vec_csr[etype]; CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") { if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols); bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */ /* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) { // cusparse if (op == "copy_lhs" &&
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */ cusparse_available<DType, IdType>(more_nnz)) { // cusparse
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] : /* If CUDA is less than 11.0, put the output in trans_out for later
static_cast<DType*>((*vec_out)[dst_id]->data); * transposition */
DType* out = (CUDART_VERSION < 11000)
? trans_out[dst_id]
: static_cast<DType*>((*vec_out)[dst_id]->data);
CusparseCsrmm2Hetero<DType, IdType>( CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(vec_ufeat[src_id]->data), nullptr, out, x_length, stream);
nullptr, } else if (
out, op == "mul" && is_scalar_efeat &&
x_length, stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) { // cusparse cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype]; NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data)) if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data); efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>( CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11 // TODO(Israt): Change (*vec_out) to trans_out to support CUDA
static_cast<DType*>((*vec_out)[dst_id]->data), // version < 11
x_length, stream); static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream);
} else { // general kernel } else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat =
NullArray() : vec_ufeat[src_id]; (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat =
NullArray() : vec_efeat[etype]; (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >( cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray()); bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
NullArray());
}); });
} }
} else if (reduce == "max") { } else if (reduce == "max") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat =
NullArray() : vec_ufeat[src_id]; (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat =
NullArray() : vec_efeat[etype]; (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >( cuda::SpMMCmpCsrHetero<
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id], IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
src_id, etype); (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
}); src_id, etype);
});
} else if (reduce == "min") { } else if (reduce == "min") {
SWITCH_OP(op, Op, { SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat =
NullArray() : vec_ufeat[src_id]; (vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat =
NullArray() : vec_efeat[etype]; (vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >( cuda::SpMMCmpCsrHetero<
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id], IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id], bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
src_id, etype); (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
}); });
} else { } else {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
...@@ -172,7 +187,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -172,7 +187,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const int m = (*vec_out)[ntype]->shape[0]; const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1]; const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue; if (m == 0) continue;
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data); DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m); _Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]); device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
} }
...@@ -180,55 +195,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -180,55 +195,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
} }
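For the sum reducer, the relation loop above boils down to: for every edge type, pick the feature matrix of its source node type, run one SpMM over that relation's CSR block, and accumulate into the output buffer of its destination node type. A compact host-side sketch of that flow (hypothetical code, not DGL; copy_lhs + sum only, with cuSPARSE/kernel dispatch and alpha/beta handling omitted):

#include <iostream>
#include <vector>

struct Csr {  // one relation's block: destination-type rows, source-type columns
  int num_rows;
  std::vector<int> indptr, indices;
};

// out[dst_ntype[etype]] += A_etype * ufeat[src_ntype[etype]]   (copy_lhs, sum)
void HeteroSpmmSum(const std::vector<Csr>& csrs, const std::vector<int>& src_ntype,
                   const std::vector<int>& dst_ntype,
                   const std::vector<std::vector<double>>& ufeat, int dim,
                   std::vector<std::vector<double>>* out) {
  for (size_t etype = 0; etype < csrs.size(); ++etype) {
    const Csr& csr = csrs[etype];
    const std::vector<double>& u = ufeat[src_ntype[etype]];
    std::vector<double>& o = (*out)[dst_ntype[etype]];
    for (int i = 0; i < csr.num_rows; ++i)
      for (int e = csr.indptr[i]; e < csr.indptr[i + 1]; ++e)
        for (int d = 0; d < dim; ++d)
          o[i * dim + d] += u[csr.indices[e] * dim + d];
  }
}

int main() {
  // Two node types (type 0 with two nodes, type 1 with two nodes) and a single
  // relation 0 -> 1 whose CSR has type-1 rows and type-0 columns.
  std::vector<Csr> csrs = {{2, {0, 1, 2}, {0, 1}}};
  std::vector<std::vector<double>> ufeat = {{1.0, 2.0}, {0.0, 0.0}};  // dim = 1
  std::vector<std::vector<double>> out = {{0.0, 0.0}, {0.0, 0.0}};
  HeteroSpmmSum(csrs, /*src_ntype=*/{0}, /*dst_ntype=*/{1}, ufeat, /*dim=*/1, &out);
  std::cout << out[1][0] << " " << out[1][1] << "\n";  // prints: 1 2
  return 0;
}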
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>( template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>( template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED #if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>( template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>( template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
#endif // BF16_ENABLED #endif // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>( template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>( template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>( template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>( template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce, const BcastOff& bcast,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
@@ -4,9 +4,9 @@
* @brief Utilities for CUDA kernels. * @brief Utilities for CUDA kernels.
*/ */
#include "./utils.h"
#include "./dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl { namespace dgl {
namespace cuda { namespace cuda {
@@ -17,9 +17,11 @@ bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {
// Call CUB's reduction // Call CUB's reduction
size_t workspace_size = 0; size_t workspace_size = 0;
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length, stream)); CUDA_CALL(cub::DeviceReduce::Min(
nullptr, workspace_size, flags, rst, length, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size); void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length, stream)); CUDA_CALL(cub::DeviceReduce::Min(
workspace, workspace_size, flags, rst, length, stream));
int8_t cpu_rst = GetCUDAScalar(device, ctx, rst); int8_t cpu_rst = GetCUDAScalar(device, ctx, rst);
device->FreeWorkspace(ctx, workspace); device->FreeWorkspace(ctx, workspace);
device->FreeWorkspace(ctx, rst); device->FreeWorkspace(ctx, rst);
...
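AllTrue above relies on CUB's standard two-pass idiom: the first DeviceReduce::Min call receives a null workspace pointer and only reports the required temporary-storage size, and the second call performs the actual reduction on the allocated workspace. A minimal standalone sketch of that idiom (hypothetical name MinReduceSketch, raw cudaMalloc instead of DGL's workspace allocator, error checking omitted):

    #include <cstdint>
    #include <cub/cub.cuh>

    // Two-pass CUB reduction: query the workspace size, allocate, then reduce.
    void MinReduceSketch(const int8_t* d_in, int8_t* d_out, int64_t n,
                         cudaStream_t stream) {
      size_t temp_bytes = 0;
      cub::DeviceReduce::Min(nullptr, temp_bytes, d_in, d_out, n, stream);
      void* d_temp = nullptr;
      cudaMalloc(&d_temp, temp_bytes);
      cub::DeviceReduce::Min(d_temp, temp_bytes, d_in, d_out, n, stream);
      cudaFree(d_temp);
    }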
@@ -6,10 +6,11 @@
#ifndef DGL_ARRAY_CUDA_UTILS_H_ #ifndef DGL_ARRAY_CUDA_UTILS_H_
#define DGL_ARRAY_CUDA_UTILS_H_ #define DGL_ARRAY_CUDA_UTILS_H_
#include <dmlc/logging.h>
#include <dgl/runtime/c_runtime_api.h> #include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h> #include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h> #include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "dgl_cub.cuh" #include "dgl_cub.cuh"
@@ -22,7 +23,6 @@ namespace cuda {
// The max number of threads per block // The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256 #define CUDA_MAX_NUM_THREADS 256
/** @brief Calculate the number of threads needed given the dimension length. /** @brief Calculate the number of threads needed given the dimension length.
* *
* It finds the biggest number that is smaller than min(dim, max_nthrs) * It finds the biggest number that is smaller than min(dim, max_nthrs)
@@ -30,8 +30,7 @@ namespace cuda {
*/ */
inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) { inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
CHECK_GE(dim, 0); CHECK_GE(dim, 0);
if (dim == 0) if (dim == 0) return 1;
return 1;
int ret = max_nthrs; int ret = max_nthrs;
while (ret > dim) { while (ret > dim) {
ret = ret >> 1; ret = ret >> 1;
@@ -60,11 +59,9 @@ inline int FindNumBlocks(int nblks, int max_nblks = -1) {
LOG(FATAL) << "Axis " << axis << " not recognized"; LOG(FATAL) << "Axis " << axis << " not recognized";
break; break;
} }
if (max_nblks == -1) if (max_nblks == -1) max_nblks = default_max_nblks;
max_nblks = default_max_nblks;
CHECK_NE(nblks, 0); CHECK_NE(nblks, 0);
if (nblks < max_nblks) if (nblks < max_nblks) return nblks;
return nblks;
return max_nblks; return max_nblks;
} }
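FindNumThreads above rounds the per-block thread count down to the largest power of two not exceeding min(dim, max_nthrs), and FindNumBlocks clamps the requested block count against a per-axis maximum. A host-only sketch of the thread-count rule (hypothetical name, default of 256 taken from CUDA_MAX_NUM_THREADS in this header), with a few worked values:

    #include <cassert>

    // Largest power of two <= min(dim, max_nthrs); dim == 0 still gets one thread.
    inline int FindNumThreadsSketch(int dim, int max_nthrs = 256) {
      assert(dim >= 0);
      if (dim == 0) return 1;
      int ret = max_nthrs;
      while (ret > dim) ret >>= 1;
      return ret;
    }

    // FindNumThreadsSketch(1000) == 256, FindNumThreadsSketch(100) == 64,
    // FindNumThreadsSketch(3) == 2.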
@@ -108,7 +105,8 @@ template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) { void _Fill(DType* ptr, size_t length, DType val) {
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(length); int nt = FindNumThreads(length);
int nb = (length + nt - 1) / nt; // on x-axis, no need to worry about upperbound. int nb =
(length + nt - 1) / nt; // on x-axis, no need to worry about upperbound.
CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val); CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);
} }
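_Fill above picks nt with FindNumThreads and takes a ceiling division for the block count before launching _FillKernel. The kernel body is not shown in this hunk; a plausible grid-stride formulation of such a fill, as a self-contained CUDA sketch (hypothetical name FillKernelSketch):

    #include <cuda_runtime.h>

    // Grid-stride fill: any (nb, nt) launch configuration covers all length elements.
    template <typename DType>
    __global__ void FillKernelSketch(DType* ptr, size_t length, DType val) {
      size_t tx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      const size_t stride_x = static_cast<size_t>(gridDim.x) * blockDim.x;
      while (tx < length) {
        ptr[tx] = val;
        tx += stride_x;
      }
    }

    // Launched with the same arithmetic as _Fill:
    //   int nt = FindNumThreads(length);
    //   int nb = (length + nt - 1) / nt;
    //   FillKernelSketch<<<nb, nt, 0, stream>>>(d_ptr, length, val);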
@@ -123,9 +121,9 @@ void _Fill(DType* ptr, size_t length, DType val) {
template <typename IdType, typename DType> template <typename IdType, typename DType>
__global__ void _LinearSearchKernel( __global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data, const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col, const IdType* row, const IdType* col, int64_t row_stride,
int64_t row_stride, int64_t col_stride, int64_t col_stride, int64_t length, const DType* weights, DType filler,
int64_t length, const DType* weights, DType filler, DType* out) { DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
@@ -148,7 +146,7 @@ __global__ void _LinearSearchKernel(
// constructor for __half. // constructor for __half.
// The using statement is to avoid a linter error about using // The using statement is to avoid a linter error about using
// long or long long. // long or long long.
using LongLong = long long; // NOLINT using LongLong = long long; // NOLINT
out[tx] = weights ? weights[v] : DType(LongLong(v)); out[tx] = weights ? weights[v] : DType(LongLong(v));
} }
tx += stride_x; tx += stride_x;
@@ -163,9 +161,9 @@ __global__ void _LinearSearchKernel(
template <typename IdType> template <typename IdType>
__global__ void _LinearSearchKernel( __global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data, const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col, const IdType* row, const IdType* col, int64_t row_stride,
int64_t row_stride, int64_t col_stride, int64_t length, int64_t col_stride, int64_t length, const __nv_bfloat16* weights,
const __nv_bfloat16* weights, __nv_bfloat16 filler, __nv_bfloat16* out) { __nv_bfloat16 filler, __nv_bfloat16* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x; int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x; const int stride_x = gridDim.x * blockDim.x;
while (tx < length) { while (tx < length) {
@@ -181,7 +179,8 @@ __global__ void _LinearSearchKernel(
if (v == -1) { if (v == -1) {
out[tx] = filler; out[tx] = filler;
} else { } else {
// If the result is saved in bf16, it should be fine to convert it to float first // If the result is saved in bf16, it should be fine to convert it to
// float first
out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v)); out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));
} }
tx += stride_x; tx += stride_x;
@@ -191,16 +190,10 @@ __global__ void _LinearSearchKernel(
template <typename DType> template <typename DType>
inline DType GetCUDAScalar( inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api, runtime::DeviceAPI* device_api, DGLContext ctx, const DType* cuda_ptr) {
DGLContext ctx,
const DType* cuda_ptr) {
DType result; DType result;
device_api->CopyDataFromTo( device_api->CopyDataFromTo(
cuda_ptr, 0, cuda_ptr, 0, &result, 0, sizeof(result), ctx, DGLContext{kDGLCPU, 0},
&result, 0,
sizeof(result),
ctx,
DGLContext{kDGLCPU, 0},
DGLDataTypeTraits<DType>::dtype); DGLDataTypeTraits<DType>::dtype);
return result; return result;
} }
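GetCUDAScalar above copies a single device-resident value back to the host through DGL's DeviceAPI. Outside that abstraction the same round trip is a one-element cudaMemcpy; a minimal synchronous sketch (hypothetical name, error checking omitted):

    #include <cuda_runtime.h>

    // Read one value of type T from device memory into a host variable.
    template <typename T>
    T ReadDeviceScalarSketch(const T* d_ptr) {
      T host_val;
      cudaMemcpy(&host_val, d_ptr, sizeof(T), cudaMemcpyDeviceToHost);
      return host_val;
    }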
@@ -217,12 +210,12 @@ inline DType GetCUDAScalar(
* if x<A[0] then it returns 0. * if x<A[0] then it returns 0.
*/ */
template <typename IdType> template <typename IdType>
__device__ IdType _UpperBound(const IdType *A, int64_t n, IdType x) { __device__ IdType _UpperBound(const IdType* A, int64_t n, IdType x) {
IdType l = 0, r = n, m = 0; IdType l = 0, r = n, m = 0;
while (l < r) { while (l < r) {
m = l + (r-l)/2; m = l + (r - l) / 2;
if (x >= A[m]) { if (x >= A[m]) {
l = m+1; l = m + 1;
} else { } else {
r = m; r = m;
} }
@@ -241,17 +234,17 @@ __device__ IdType _UpperBound(const IdType *A, int64_t n, IdType x) {
* @return index, i, st. A[i]==x. If such an index not exists returns 'n'. * @return index, i, st. A[i]==x. If such an index not exists returns 'n'.
*/ */
template <typename IdType> template <typename IdType>
__device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) { __device__ IdType _BinarySearch(const IdType* A, int64_t n, IdType x) {
IdType l = 0, r = n-1, m = 0; IdType l = 0, r = n - 1, m = 0;
while (l <= r) { while (l <= r) {
m = l + (r-l)/2; m = l + (r - l) / 2;
if (A[m] == x) { if (A[m] == x) {
return m; return m;
} }
if (A[m] < x) { if (A[m] < x) {
l = m+1; l = m + 1;
} else { } else {
r = m-1; r = m - 1;
} }
} }
return n; // not found return n; // not found
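_UpperBound above returns how many leading elements of the sorted array are <= x (so x < A[0] yields 0), while _BinarySearch returns the position of an exact match or n when no match exists. A host-side sketch of the upper-bound variant (hypothetical name) with concrete values:

    #include <cstdint>

    // Classic upper bound on a sorted array: first index whose element is > x.
    int64_t UpperBoundSketch(const int64_t* A, int64_t n, int64_t x) {
      int64_t l = 0, r = n;
      while (l < r) {
        const int64_t m = l + (r - l) / 2;
        if (x >= A[m]) l = m + 1; else r = m;
      }
      return l;
    }

    // For A = {1, 3, 3, 7}: UpperBoundSketch(A, 4, 3) == 3,
    // UpperBoundSketch(A, 4, 0) == 0, and UpperBoundSketch(A, 4, 9) == 4.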
@@ -259,9 +252,9 @@ __device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) {
template <typename DType, typename BoolType> template <typename DType, typename BoolType>
void MaskSelect( void MaskSelect(
runtime::DeviceAPI* device, const DGLContext& ctx, runtime::DeviceAPI* device, const DGLContext& ctx, const DType* input,
const DType* input, const BoolType* mask, DType* output, int64_t n, const BoolType* mask, DType* output, int64_t n, int64_t* rst,
int64_t* rst, cudaStream_t stream) { cudaStream_t stream) {
size_t workspace_size = 0; size_t workspace_size = 0;
CUDA_CALL(cub::DeviceSelect::Flagged( CUDA_CALL(cub::DeviceSelect::Flagged(
nullptr, workspace_size, input, mask, output, rst, n, stream)); nullptr, workspace_size, input, mask, output, rst, n, stream));
...
@@ -3,33 +3,28 @@
* @file array/kernel.cc * @file array/kernel.cc
* @brief New kernels * @brief New kernels
*/ */
#include <dgl/packed_func_ext.h>
#include <dgl/base_heterograph.h> #include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#ifdef USE_TVM #ifdef USE_TVM
#include <featgraph.h>
#include <dgl/runtime/dlpack_convert.h> #include <dgl/runtime/dlpack_convert.h>
#include <featgraph.h>
#endif // USE_TVM #endif // USE_TVM
#include "kernel_decl.h"
#include "../c_api_common.h" #include "../c_api_common.h"
#include "./check.h" #include "./check.h"
#include "kernel_decl.h"
using namespace dgl::runtime; using namespace dgl::runtime;
namespace dgl { namespace dgl {
namespace aten { namespace aten {
namespace { namespace {} // namespace
} // namespace
/** @brief Generalized Sparse Matrix-Matrix Multiplication. */ /** @brief Generalized Sparse Matrix-Matrix Multiplication. */
void SpMM(const std::string& op, const std::string& reduce, void SpMM(
HeteroGraphPtr graph, const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
// TODO(zihao): format tuning // TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, CSC_CODE); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
const auto& bcast = CalcBcastOff(op, ufeat, efeat); const auto& bcast = CalcBcastOff(op, ufeat, efeat);
@@ -39,12 +34,12 @@ void SpMM(const std::string& op, const std::string& reduce,
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
SpMMCsr<XPU, IdType, Dtype>( SpMMCsr<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCSCMatrix(0), op, reduce, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out,
ufeat, efeat, out, out_aux); out_aux);
} else if (format == SparseFormat::kCOO) { } else if (format == SparseFormat::kCOO) {
SpMMCoo<XPU, IdType, Dtype>( SpMMCoo<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCOOMatrix(0), op, reduce, bcast, graph->GetCOOMatrix(0), ufeat, efeat, out,
ufeat, efeat, out, out_aux); out_aux);
} else { } else {
LOG(FATAL) << "SpMM only supports CSC and COO formats"; LOG(FATAL) << "SpMM only supports CSC and COO formats";
} }
@@ -53,27 +48,27 @@ void SpMM(const std::string& op, const std::string& reduce,
}); });
} }
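SpMM above dispatches on sparse format, ID type, and feature dtype; for op = "mul" with reduce = "sum", the underlying kernel accumulates, for every destination node, the source-node features weighted by the incident edge features. An illustrative CPU-only sketch of that case (not DGL's kernel; hypothetical name, scalar weight per edge, CSC layout whose indptr groups edges by destination):

    #include <cstdint>
    #include <vector>

    // out[dst, :] = sum over incoming edges e = (src -> dst) of ufeat[src, :] * efeat[e].
    void SpMMSumMulSketch(
        const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
        const std::vector<float>& ufeat, const std::vector<float>& efeat,
        std::vector<float>* out, int64_t feat_dim) {
      const int64_t num_dst = static_cast<int64_t>(indptr.size()) - 1;
      out->assign(num_dst * feat_dim, 0.f);
      for (int64_t dst = 0; dst < num_dst; ++dst) {
        for (int64_t e = indptr[dst]; e < indptr[dst + 1]; ++e) {
          const int64_t src = indices[e];
          for (int64_t k = 0; k < feat_dim; ++k)
            (*out)[dst * feat_dim + k] += ufeat[src * feat_dim + k] * efeat[e];
        }
      }
    }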
/** @brief Generalized segmented dense Matrix-Matrix Multiplication. */ /** @brief Generalized segmented dense Matrix-Matrix Multiplication. */
void SegmentMM(const NDArray A, void SegmentMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
NDArray C, bool A_trans, bool B_trans) {
const NDArray seglen_A,
bool A_trans, bool B_trans) {
CHECK_EQ(A->ndim, 2) << "segment_mm expects a 2D tensor for the first input."; CHECK_EQ(A->ndim, 2) << "segment_mm expects a 2D tensor for the first input.";
CHECK_EQ(B->ndim, 3) << "segment_mm expects a 3D tensor for the second input."; CHECK_EQ(B->ndim, 3)
<< "segment_mm expects a 3D tensor for the second input.";
CHECK(!A_trans); CHECK(!A_trans);
if (B_trans) { if (B_trans) {
CHECK_EQ(A->shape[1], B->shape[2]) CHECK_EQ(A->shape[1], B->shape[2])
<< "segment_mm expects A.shape[1] == B.shape[2] when B_trans=True"; << "segment_mm expects A.shape[1] == B.shape[2] when B_trans=True";
} else { } else {
CHECK_EQ(A->shape[1], B->shape[1]) << "segment_mm expects A.shape[1] == B.shape[1]"; CHECK_EQ(A->shape[1], B->shape[1])
<< "segment_mm expects A.shape[1] == B.shape[1]";
} }
CHECK_EQ(B->shape[0], seglen_A.NumElements()) CHECK_EQ(B->shape[0], seglen_A.NumElements())
<< "segment_mm expects len(seglen_A) == B.shape[0]"; << "segment_mm expects len(seglen_A) == B.shape[0]";
CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU) CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen_A to be on CPU."; << "segment_mm expects seglen_A to be on CPU.";
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device"; CHECK(A->ctx == B->ctx)
<< "segment_mm expects A and B to be of the same device";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, { ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
@@ -83,15 +78,14 @@ void SegmentMM(const NDArray A,
}); });
} }
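The checks in SegmentMM above pin down its contract for the non-transposed case: A is a 2-D (sum(seglen_A), k) matrix whose rows are split into segments by seglen_A, B is a 3-D (num_segments, k, n) tensor, and segment i of A is multiplied by B[i]. A CPU-only illustration of that contract (hypothetical name, row-major buffers, C pre-sized to (sum(seglen), n)):

    #include <cstdint>
    #include <vector>

    void SegmentMMSketch(
        const std::vector<float>& A,           // (sum(seglen), k)
        const std::vector<float>& B,           // (num_segments, k, n)
        std::vector<float>* C,                 // (sum(seglen), n), pre-sized
        const std::vector<int64_t>& seglen, int64_t k, int64_t n) {
      int64_t row_off = 0;
      for (size_t s = 0; s < seglen.size(); ++s) {
        for (int64_t r = 0; r < seglen[s]; ++r) {
          for (int64_t j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int64_t t = 0; t < k; ++t)
              acc += A[(row_off + r) * k + t] * B[(s * k + t) * n + j];
            (*C)[(row_off + r) * n + j] = acc;
          }
        }
        row_off += seglen[s];
      }
    }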
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(
const NDArray dC, const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
NDArray dB, CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
const NDArray seglen) { "for the first input.";
CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor for the first input."; CHECK_EQ(dC->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
CHECK_EQ(dC->ndim, 2) "for the second input.";
<< "segment_mm_backward operator expects a 2D tensor for the second input.";
CHECK_EQ(seglen->ctx.device_type, kDGLCPU) CHECK_EQ(seglen->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen to be on CPU."; << "segment_mm expects seglen to be on CPU.";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, { ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
@@ -101,34 +95,35 @@ void SegmentMMBackwardB(const NDArray A,
}); });
} }
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */ * types. */
void GatherMM(const NDArray A, void GatherMM(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C, const NDArray idx_b) {
const NDArray idx_a, CHECK_EQ(A->ndim, 2)
const NDArray idx_b) { << "gather_mm operator expects a 2D tensor for the first input.";
CHECK_EQ(A->ndim, 2) << "gather_mm operator expects a 2D tensor for the first input."; CHECK_EQ(B->ndim, 3)
CHECK_EQ(B->ndim, 3) << "gather_mm operator expects a 3D tensor for the second input."; << "gather_mm operator expects a 3D tensor for the second input.";
CHECK(A->ctx == B->ctx) CHECK(A->ctx == B->ctx)
<< "gather_mm expects all arguments to be on the same device."; << "gather_mm expects all arguments to be on the same device.";
if (aten::IsNullArray(idx_a)) { if (aten::IsNullArray(idx_a)) {
CHECK_EQ(A->shape[0], idx_b->shape[0]) CHECK_EQ(A->shape[0], idx_b->shape[0])
<< "gather_mm expects len(idx_b) == A.shape[0] when idx_a is None."; << "gather_mm expects len(idx_b) == A.shape[0] when idx_a is None.";
CHECK(A->ctx == idx_b->ctx) CHECK(A->ctx == idx_b->ctx)
<< "gather_mm expects all arguments to be on the same device."; << "gather_mm expects all arguments to be on the same device.";
} else if (aten::IsNullArray(idx_b)) { } else if (aten::IsNullArray(idx_b)) {
CHECK_EQ(B->shape[0], idx_a->shape[0]) CHECK_EQ(B->shape[0], idx_a->shape[0])
<< "gather_mm expects len(idx_a) == B.shape[0] when idx_b is None."; << "gather_mm expects len(idx_a) == B.shape[0] when idx_b is None.";
CHECK(A->ctx == idx_a->ctx) CHECK(A->ctx == idx_a->ctx)
<< "gather_mm expects all arguments to be on the same device."; << "gather_mm expects all arguments to be on the same device.";
} else { } else {
CHECK_EQ(idx_a->shape[0], idx_b->shape[0]) CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
<< "gather_mm expects len(idx_a) == len(idx_b) when both idx_a and idx_b are given."; << "gather_mm expects len(idx_a) == len(idx_b) when both idx_a and "
"idx_b are given.";
CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx) CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
<< "gather_mm expects all arguments to be on the same device."; << "gather_mm expects all arguments to be on the same device.";
} }
const auto idtype = aten::IsNullArray(idx_a)? idx_b->dtype : idx_a->dtype; const auto idtype = aten::IsNullArray(idx_a) ? idx_b->dtype : idx_a->dtype;
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, { ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
@@ -138,36 +133,36 @@ void GatherMM(const NDArray A,
}); });
} }
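The index checks in GatherMM above imply the following semantics: output row i is the product of one row of A and one matrix of B, selected by idx_a[i] and idx_b[i] respectively, with a missing ("null") index array meaning the identity mapping. A CPU-only sketch of that behaviour (hypothetical name; empty vectors stand in for null index arrays):

    #include <cstdint>
    #include <vector>

    void GatherMMSketch(
        const std::vector<float>& A,          // (num_rows_A, k)
        const std::vector<float>& B,          // (num_mats_B, k, n)
        std::vector<float>* C,                // (num_out, n), pre-sized
        const std::vector<int64_t>& idx_a,    // empty == None
        const std::vector<int64_t>& idx_b,    // empty == None
        int64_t num_out, int64_t k, int64_t n) {
      for (int64_t i = 0; i < num_out; ++i) {
        const int64_t ra = idx_a.empty() ? i : idx_a[i];
        const int64_t mb = idx_b.empty() ? i : idx_b[i];
        for (int64_t j = 0; j < n; ++j) {
          float acc = 0.f;
          for (int64_t t = 0; t < k; ++t)
            acc += A[ra * k + t] * B[(mb * k + t) * n + j];
          (*C)[i * n + j] = acc;
        }
      }
    }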
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */ * types. */
void GatherMMScatter(const NDArray A, void GatherMMScatter(
const NDArray B, const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
NDArray C, const NDArray idx_b, const NDArray idx_c) {
const NDArray idx_a, CHECK_EQ(A->ndim, 2)
const NDArray idx_b, << "gather_mm_scatter expects a 2D tensor for the first input.";
const NDArray idx_c) {
CHECK_EQ(A->ndim, 2) << "gather_mm_scatter expects a 2D tensor for the first input.";
CHECK(A->ctx == B->ctx) CHECK(A->ctx == B->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device."; << "gather_mm_scatter expects all arguments to be on the same device.";
if (!aten::IsNullArray(idx_c)) if (!aten::IsNullArray(idx_c))
CHECK(A->ctx == idx_c->ctx) CHECK(A->ctx == idx_c->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device."; << "gather_mm_scatter expects all arguments to be on the same device.";
if (aten::IsNullArray(idx_a) && !aten::IsNullArray(idx_b)) { if (aten::IsNullArray(idx_a) && !aten::IsNullArray(idx_b)) {
CHECK_EQ(A->shape[0], idx_b->shape[0]) CHECK_EQ(A->shape[0], idx_b->shape[0])
<< "gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is None."; << "gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is "
"None.";
CHECK(A->ctx == idx_b->ctx) CHECK(A->ctx == idx_b->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device."; << "gather_mm_scatter expects all arguments to be on the same device.";
} else if (aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) { } else if (aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
CHECK_EQ(B->shape[0], idx_a->shape[0]) CHECK_EQ(B->shape[0], idx_a->shape[0])
<< "gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is None."; << "gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is "
"None.";
CHECK(A->ctx == idx_a->ctx) CHECK(A->ctx == idx_a->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device."; << "gather_mm_scatter expects all arguments to be on the same device.";
} else if (!aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) { } else if (!aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
CHECK_EQ(idx_a->shape[0], idx_b->shape[0]) CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
<< "gather_mm_scatter expects len(idx_a) == len(idx_b) " << "gather_mm_scatter expects len(idx_a) == len(idx_b) "
<< "when both idx_a and idx_b are given."; << "when both idx_a and idx_b are given.";
CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx) CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device."; << "gather_mm_scatter expects all arguments to be on the same device.";
} }
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, { ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {
@@ -178,14 +173,13 @@ void GatherMMScatter(const NDArray A,
}); });
} }
/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph
/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */ * support. */
void SpMMHetero(const std::string& op, const std::string& reduce, void SpMMHetero(
HeteroGraphPtr graph, const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
const std::vector<NDArray>& ufeat_vec, const std::vector<NDArray>& ufeat_vec,
const std::vector<NDArray>& efeat_vec, const std::vector<NDArray>& efeat_vec, std::vector<NDArray>* out,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux) {
std::vector<std::vector<NDArray>>* out_aux) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
std::vector<CSRMatrix> vec_graph; std::vector<CSRMatrix> vec_graph;
@@ -193,7 +187,8 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
std::vector<dgl_type_t> efeat_eid; std::vector<dgl_type_t> efeat_eid;
std::vector<dgl_type_t> out_eid; std::vector<dgl_type_t> out_eid;
auto pair = graph->meta_graph()->FindEdge(0); // first etype auto pair = graph->meta_graph()->FindEdge(0); // first etype
NDArray ufeat_etype0 = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first]; NDArray ufeat_etype0 =
(ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0]; NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_graph.push_back(graph->GetCSCMatrix(etype)); vec_graph.push_back(graph->GetCSCMatrix(etype));
@@ -202,54 +197,53 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
efeat_eid.push_back(etype); efeat_eid.push_back(etype);
out_eid.push_back(pair.second); out_eid.push_back(pair.second);
if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1]) if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])
LOG(FATAL) << "Column width of the input node features of all etypes must be same."; LOG(FATAL) << "Column width of the input node features of all etypes "
"must be same.";
if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1]) if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])
LOG(FATAL) << "Column width of the input edge features of all etypes must be same."; LOG(FATAL) << "Column width of the input edge features of all etypes "
"must be same.";
} }
const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0); const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(
ATEN_FLOAT_TYPE_SWITCH_16BITS((*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", { graph->DataType(), IdType, {
if (format == SparseFormat::kCSC) { ATEN_FLOAT_TYPE_SWITCH_16BITS(
SpMMCsrHetero<XPU, IdType, Dtype>( (*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
op, reduce, bcast, vec_graph, if (format == SparseFormat::kCSC) {
ufeat_vec, efeat_vec, out, out_aux, SpMMCsrHetero<XPU, IdType, Dtype>(
ufeat_eid, out_eid); op, reduce, bcast, vec_graph, ufeat_vec, efeat_vec, out,
} else { out_aux, ufeat_eid, out_eid);
// TODO(Israt): Add support for COO format } else {
LOG(FATAL) << "SpMM only supports CSC format for graphs with number " // TODO(Israt): Add support for COO format
<< "of relation types > 1"; LOG(FATAL)
} << "SpMM only supports CSC format for graphs with number "
}); << "of relation types > 1";
}); }
});
});
}); });
} }
/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */ /** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMM(const std::string& op, void SDDMM(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, NDArray lhs, NDArray rhs,
NDArray lhs, NDArray out, int lhs_target, int rhs_target) {
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
// TODO(zihao): format tuning // TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, COO_CODE); SparseFormat format = graph->SelectFormat(0, COO_CODE);
const auto &bcast = CalcBcastOff(op, lhs, rhs); const auto& bcast = CalcBcastOff(op, lhs, rhs);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) { if (format == SparseFormat::kCSR) {
SDDMMCsr<XPU, IdType, Dtype>( SDDMMCsr<XPU, IdType, Dtype>(
op, bcast, graph->GetCSRMatrix(0), op, bcast, graph->GetCSRMatrix(0), lhs, rhs, out, lhs_target,
lhs, rhs, out, lhs_target, rhs_target); rhs_target);
} else if (format == SparseFormat::kCOO) { } else if (format == SparseFormat::kCOO) {
SDDMMCoo<XPU, IdType, Dtype>( SDDMMCoo<XPU, IdType, Dtype>(
op, bcast, graph->GetCOOMatrix(0), op, bcast, graph->GetCOOMatrix(0), lhs, rhs, out, lhs_target,
lhs, rhs, out, lhs_target, rhs_target); rhs_target);
} else { } else {
LOG(FATAL) << "SDDMM only supports CSR and COO formats"; LOG(FATAL) << "SDDMM only supports CSR and COO formats";
} }
@@ -267,21 +261,16 @@ void SDDMM(const std::string& op,
*/ */
int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) { int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
auto pair = graph->meta_graph()->FindEdge(etype); auto pair = graph->meta_graph()->FindEdge(etype);
if (target == 0) if (target == 0) return pair.first;
return pair.first; if (target == 2) return pair.second;
if (target == 2)
return pair.second;
return etype; return etype;
} }
/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */ /** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMMHetero(const std::string& op, void SDDMMHetero(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, std::vector<NDArray> lhs,
std::vector<NDArray> lhs, std::vector<NDArray> rhs, std::vector<NDArray> out, int lhs_target,
std::vector<NDArray> rhs, int rhs_target) {
std::vector<NDArray> out,
int lhs_target,
int rhs_target) {
SparseFormat format = graph->SelectFormat(0, COO_CODE); SparseFormat format = graph->SelectFormat(0, COO_CODE);
std::vector<dgl_type_t> lhs_eid; std::vector<dgl_type_t> lhs_eid;
@@ -290,79 +279,74 @@ void SDDMMHetero(const std::string& op,
lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype)); lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype)); rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
} }
const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]); const auto& bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(
if (format == SparseFormat::kCSR) { out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
std::vector<CSRMatrix> vec_csr; if (format == SparseFormat::kCSR) {
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { std::vector<CSRMatrix> vec_csr;
vec_csr.push_back(graph->GetCSRMatrix(etype)); for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
} ++etype) {
SDDMMCsrHetero<XPU, IdType, Dtype>( vec_csr.push_back(graph->GetCSRMatrix(etype));
op, bcast, vec_csr, }
lhs, rhs, out, lhs_target, rhs_target, SDDMMCsrHetero<XPU, IdType, Dtype>(
lhs_eid, rhs_eid); op, bcast, vec_csr, lhs, rhs, out, lhs_target, rhs_target,
} else if (format == SparseFormat::kCOO) { lhs_eid, rhs_eid);
std::vector<COOMatrix> vec_coo; } else if (format == SparseFormat::kCOO) {
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { std::vector<COOMatrix> vec_coo;
vec_coo.push_back(graph->GetCOOMatrix(etype)); for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
} ++etype) {
SDDMMCooHetero<XPU, IdType, Dtype>( vec_coo.push_back(graph->GetCOOMatrix(etype));
op, bcast, vec_coo, }
lhs, rhs, out, lhs_target, rhs_target, SDDMMCooHetero<XPU, IdType, Dtype>(
lhs_eid, rhs_eid); op, bcast, vec_coo, lhs, rhs, out, lhs_target, rhs_target,
} else { lhs_eid, rhs_eid);
LOG(FATAL) << "SDDMM only supports CSR and COO formats"; } else {
} LOG(FATAL) << "SDDMM only supports CSR and COO formats";
}); }
});
}); });
}); });
} }
/** @brief Generalized Edge_softmax op for forward */ /** @brief Generalized Edge_softmax op for forward */
void Edge_softmax_forward(const std::string& op, void Edge_softmax_forward(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray ufeat, NDArray out) {
NDArray efeat,
NDArray out) {
// TODO(zhejiang): add gpu op for edge_softmax // TODO(zhejiang): add gpu op for edge_softmax
const auto& bcast = CalcBcastOff(op, ufeat, efeat); const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", { ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(
Edge_softmax_csr_forward<XPU, IdType, Dtype>( out->dtype, Dtype, XPU, "edge_softmax out data", {
op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out); Edge_softmax_csr_forward<XPU, IdType, Dtype>(
}); op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
});
}); });
}); });
} }
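Edge_softmax_forward above normalizes the score of every edge against the other edges that share its destination node. A CPU-only sketch of that forward pass (not the DGL kernel; hypothetical name, scalar scores laid out in CSC order so indptr groups edges by destination, edge-id permutation ignored), with the usual max-subtraction for numerical stability:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    void EdgeSoftmaxSketch(
        const std::vector<int64_t>& indptr,  // CSC: edges grouped by dst node
        const std::vector<float>& score,     // one scalar score per edge
        std::vector<float>* out) {
      out->resize(score.size());
      for (size_t dst = 0; dst + 1 < indptr.size(); ++dst) {
        const int64_t beg = indptr[dst], end = indptr[dst + 1];
        if (beg == end) continue;
        float m = score[beg];
        for (int64_t e = beg; e < end; ++e) m = std::max(m, score[e]);
        float denom = 0.f;
        for (int64_t e = beg; e < end; ++e) denom += std::exp(score[e] - m);
        for (int64_t e = beg; e < end; ++e)
          (*out)[e] = std::exp(score[e] - m) / denom;
      }
    }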
/** @brief Generalized Edge_softmax op for backward */ /** @brief Generalized Edge_softmax op for backward */
void Edge_softmax_backward(const std::string& op, void Edge_softmax_backward(
HeteroGraphPtr graph, const std::string& op, HeteroGraphPtr graph, NDArray out, NDArray sds,
NDArray out, NDArray back_out, NDArray ufeat) {
NDArray sds,
NDArray back_out,
NDArray ufeat) {
// TODO(zhejiang): add gpu op for edge_softmax // TODO(zhejiang): add gpu op for edge_softmax
const auto& bcast = CalcBcastOff(op, ufeat, sds); const auto& bcast = CalcBcastOff(op, ufeat, sds);
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", { ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data_back", { ATEN_FLOAT_TYPE_SWITCH_16BITS(
Edge_softmax_csr_backward<XPU, IdType, Dtype>( out->dtype, Dtype, XPU, "edge_softmax out data_back", {
op, bcast, graph->GetCSCMatrix(0), out, sds, back_out); Edge_softmax_csr_backward<XPU, IdType, Dtype>(
}); op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
});
}); });
}); });
} }
NDArray GetEdgeMapping(HeteroGraphRef graph) { NDArray GetEdgeMapping(HeteroGraphRef graph) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE); SparseFormat format = graph->SelectFormat(0, CSC_CODE);
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
@@ -373,15 +357,13 @@ NDArray GetEdgeMapping(HeteroGraphRef graph) {
} }
/** @brief Segment reduce dispatch function. */ /** @brief Segment reduce dispatch function. */
void SegmentReduceDispatch(const std::string& op, void SegmentReduceDispatch(
NDArray feat, const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray offsets, NDArray arg) {
NDArray out,
NDArray arg) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", { ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, { ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg); SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
}); });
}); });
}); });
@@ -398,20 +380,21 @@ void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
}); });
} }
/** @brief Update gradients (reduce op max/min) dispatch function on heterogeneous graph. */ /** @brief Update gradients (reduce op max/min) dispatch function on
void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph, * heterogeneous graph. */
const std::string& op, void UpdateGradMinMaxDispatchHetero(
const std::vector<NDArray>& feat, const HeteroGraphPtr& graph, const std::string& op,
const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
std::vector<NDArray>* out) {
auto pair = graph->meta_graph()->FindEdge(0); // checking the first etype auto pair = graph->meta_graph()->FindEdge(0); // checking the first etype
auto src_id = pair.first; auto src_id = pair.first;
ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", { ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, { ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat[src_id]->dtype, Dtype, XPU, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(
UpdateGradMinMax_hetero<XPU, IdType, Dtype>(graph, op, feat, idx, idx_etype, out); feat[src_id]->dtype, Dtype, XPU, "Feature data", {
}); UpdateGradMinMax_hetero<XPU, IdType, Dtype>(
graph, op, feat, idx, idx_etype, out);
});
}); });
}); });
} }
@@ -428,20 +411,19 @@ void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
} }
std::pair<CSRMatrix, NDArray> CSRMM( std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A, CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights) {
NDArray A_weights, CHECK_EQ(A.num_cols, B.num_rows)
CSRMatrix B, << "The number of nodes of destination node type of the first graph must "
NDArray B_weights) { "be the "
CHECK_EQ(A.num_cols, B.num_rows) << "same as the number of nodes of source node type of the second graph.";
"The number of nodes of destination node type of the first graph must be the "
"same as the number of nodes of source node type of the second graph.";
CheckCtx( CheckCtx(
A.indptr->ctx, A.indptr->ctx, {A_weights, B_weights},
{A_weights, B_weights},
{"A's edge weights", "B's edge weights"}); {"A's edge weights", "B's edge weights"});
CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match."; CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype) << "ID types of two graphs must match."; CHECK_EQ(A.indptr->dtype, B.indptr->dtype)
CHECK_EQ(A_weights->dtype, B_weights->dtype) << "Data types of two edge weights must match."; << "ID types of two graphs must match.";
CHECK_EQ(A_weights->dtype, B_weights->dtype)
<< "Data types of two edge weights must match.";
std::pair<CSRMatrix, NDArray> ret; std::pair<CSRMatrix, NDArray> ret;
ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, "CSRMM", { ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, "CSRMM", {
@@ -455,27 +437,31 @@ std::pair<CSRMatrix, NDArray> CSRMM(
} }
std::pair<CSRMatrix, NDArray> CSRSum( std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A, const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {
const std::vector<NDArray>& A_weights) {
CHECK(A.size() > 0) << "The list of graphs must not be empty."; CHECK(A.size() > 0) << "The list of graphs must not be empty.";
CHECK_EQ(A.size(), A_weights.size()) << CHECK_EQ(A.size(), A_weights.size())
"The list of edge weights must have the same length as the list of graphs."; << "The list of edge weights must have the same length as the list of "
"graphs.";
const auto ctx = A[0].indptr->ctx; const auto ctx = A[0].indptr->ctx;
const auto idtype = A[0].indptr->dtype; const auto idtype = A[0].indptr->dtype;
const auto dtype = A_weights[0]->dtype; const auto dtype = A_weights[0]->dtype;
const auto num_rows = A[0].num_rows; const auto num_rows = A[0].num_rows;
const auto num_cols = A[0].num_cols; const auto num_cols = A[0].num_cols;
for (size_t i = 0; i < A.size(); ++i) { for (size_t i = 0; i < A.size(); ++i) {
CHECK_EQ(A[i].indptr->ctx, ctx) << "The devices of all graphs must be equal."; CHECK_EQ(A[i].indptr->ctx, ctx)
CHECK_EQ(A[i].indptr->dtype, idtype) << "The ID types of all graphs must be equal."; << "The devices of all graphs must be equal.";
CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0]) << CHECK_EQ(A[i].indptr->dtype, idtype)
"Shape of edge weights does not match the number of edges."; << "The ID types of all graphs must be equal.";
CHECK_EQ(A_weights[i]->ctx, ctx) << CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0])
"The devices of edge weights must be the same as that of the graphs."; << "Shape of edge weights does not match the number of edges.";
CHECK_EQ(A_weights[i]->dtype, dtype) << CHECK_EQ(A_weights[i]->ctx, ctx) << "The devices of edge weights must be "
"The data types of all edge weights must be equal."; "the same as that of the graphs.";
CHECK_EQ(A[i].num_rows, num_rows) << "Graphs must have the same number of nodes."; CHECK_EQ(A_weights[i]->dtype, dtype)
CHECK_EQ(A[i].num_cols, num_cols) << "Graphs must have the same number of nodes."; << "The data types of all edge weights must be equal.";
CHECK_EQ(A[i].num_rows, num_rows)
<< "Graphs must have the same number of nodes.";
CHECK_EQ(A[i].num_cols, num_cols)
<< "Graphs must have the same number of nodes.";
} }
std::pair<CSRMatrix, NDArray> ret; std::pair<CSRMatrix, NDArray> ret;
@@ -490,238 +476,246 @@ std::pair<CSRMatrix, NDArray> CSRSum(
} }
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
const std::string reduce_op = args[2]; const std::string reduce_op = args[2];
NDArray U = args[3]; NDArray U = args[3];
NDArray E = args[4]; NDArray E = args[4];
NDArray V = args[5]; NDArray V = args[5];
NDArray ArgU = args[6]; NDArray ArgU = args[6];
NDArray ArgE = args[7]; NDArray ArgE = args[7];
CheckCtx(graph->Context(), {U, E, V, ArgU, ArgE}, CheckCtx(
{"U_data", "E_data", "out", "Arg_U", "Arg_E"}); graph->Context(), {U, E, V, ArgU, ArgE},
CheckContiguous({U, E, V, ArgU, ArgE}, {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
{"U_data", "E_data", "out", "Arg_U", "Arg_E"}); CheckContiguous(
CHECK_EQ(graph->NumEdgeTypes(), 1); {U, E, V, ArgU, ArgE}, {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph. CHECK_EQ(graph->NumEdgeTypes(), 1);
const dgl_type_t src_vtype = pair.first; auto pair =
const dgl_type_t dst_vtype = pair.second; graph->meta_graph()->FindEdge(0); // only one etype in the graph.
CheckShape( const dgl_type_t src_vtype = pair.first;
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)}, const dgl_type_t dst_vtype = pair.second;
{0, 1, 2, 2, 2}, CheckShape(
{U, E, V, ArgU, ArgE}, {graph->NumVertices(src_vtype), graph->NumEdges(0),
{"U_data", "E_data", "out", "Arg_U", "Arg_E"}); graph->NumVertices(dst_vtype)},
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE}); {0, 1, 2, 2, 2}, {U, E, V, ArgU, ArgE},
}); {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0]; NDArray A = args[0];
NDArray B = args[1]; NDArray B = args[1];
NDArray C = args[2]; NDArray C = args[2];
NDArray idx_a = args[3]; NDArray idx_a = args[3];
NDArray idx_b = args[4]; NDArray idx_b = args[4];
GatherMM(A, B, C, idx_a, idx_b); GatherMM(A, B, C, idx_a, idx_b);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0]; NDArray A = args[0];
NDArray B = args[1]; NDArray B = args[1];
NDArray C = args[2]; NDArray C = args[2];
NDArray idx_a = args[3]; NDArray idx_a = args[3];
NDArray idx_b = args[4]; NDArray idx_b = args[4];
NDArray idx_c = args[5]; NDArray idx_c = args[5];
GatherMMScatter(A, B, C, idx_a, idx_b, idx_c); GatherMMScatter(A, B, C, idx_a, idx_b, idx_c);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0]; NDArray A = args[0];
NDArray B = args[1]; NDArray B = args[1];
NDArray C = args[2]; NDArray C = args[2];
NDArray seglen_A = args[3]; NDArray seglen_A = args[3];
bool A_trans = args[4]; bool A_trans = args[4];
bool B_trans = args[5]; bool B_trans = args[5];
SegmentMM(A, B, C, seglen_A, A_trans, B_trans); SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMMBackwardB") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMMBackwardB")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0]; NDArray A = args[0];
NDArray dC = args[1]; NDArray dC = args[1];
NDArray dB = args[2]; NDArray dB = args[2];
NDArray seglen = args[3]; NDArray seglen = args[3];
SegmentMMBackwardB(A, dC, dB, seglen); SegmentMMBackwardB(A, dC, dB, seglen);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
NDArray U = args[2]; NDArray U = args[2];
NDArray E = args[3]; NDArray E = args[3];
NDArray V = args[4]; NDArray V = args[4];
Edge_softmax_forward(op, graph.sptr(), U, E, V); Edge_softmax_forward(op, graph.sptr(), U, E, V);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_backward") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_backward")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
NDArray out = args[2]; NDArray out = args[2];
NDArray sds = args[3]; NDArray sds = args[3];
NDArray back_out = args[4]; NDArray back_out = args[4];
NDArray ufeat = args[5]; NDArray ufeat = args[5];
Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat); Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
const std::string reduce_op = args[2]; const std::string reduce_op = args[2];
List<Value> list_U = args[3]; List<Value> list_U = args[3];
List<Value> list_E = args[4]; List<Value> list_E = args[4];
List<Value> list_V = args[5]; List<Value> list_V = args[5];
List<Value> list_ArgU = args[6]; List<Value> list_ArgU = args[6];
List<Value> list_ArgE = args[7]; List<Value> list_ArgE = args[7];
List<Value> list_ArgU_ntype = args[8]; List<Value> list_ArgU_ntype = args[8];
List<Value> list_ArgE_etype = args[9]; List<Value> list_ArgE_etype = args[9];
std::vector<std::vector<NDArray>> Arg_vec; // ArgU + ArgE std::vector<std::vector<NDArray>> Arg_vec; // ArgU + ArgE
for (int i = 0; i < 4; ++i) { // ArgU + ArgE + ArgU_ntype + ArgE_etype for (int i = 0; i < 4; ++i) { // ArgU + ArgE + ArgU_ntype + ArgE_etype
Arg_vec.push_back(std::vector<NDArray>()); Arg_vec.push_back(std::vector<NDArray>());
} }
std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U); std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V); std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E); std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU); Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE); Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype); Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype); Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype); auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t src_id = pair.first; const dgl_id_t src_id = pair.first;
const dgl_id_t dst_id = pair.second; const dgl_id_t dst_id = pair.second;
NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id]; NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype]; NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
CheckCtx(graph->Context(), {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]}, CheckCtx(
{"U_data", "E_data", "out", "Arg_U", "Arg_E"}); graph->Context(),
CheckContiguous({U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]}, {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"}); {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
} CheckContiguous(
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec); {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
}); {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
}
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
NDArray lhs = args[2]; NDArray lhs = args[2];
NDArray rhs = args[3]; NDArray rhs = args[3];
NDArray out = args[4]; NDArray out = args[4];
int lhs_target = args[5]; int lhs_target = args[5];
int rhs_target = args[6]; int rhs_target = args[6];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"}); CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"}); CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1); CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph. auto pair =
const dgl_type_t src_vtype = pair.first; graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t dst_vtype = pair.second; const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)}, CheckShape(
{lhs_target, rhs_target, 1}, {graph->NumVertices(src_vtype), graph->NumEdges(0),
{lhs, rhs, out}, graph->NumVertices(dst_vtype)},
{"U_data", "E_data", "V_data"}); {lhs_target, rhs_target, 1}, {lhs, rhs, out},
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target); {"U_data", "E_data", "V_data"});
}); SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMMHetero") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
List<Value> list_lhs = args[2]; List<Value> list_lhs = args[2];
List<Value> list_rhs = args[3]; List<Value> list_rhs = args[3];
List<Value> list_out = args[4]; List<Value> list_out = args[4];
int lhs_target = args[5]; int lhs_target = args[5];
int rhs_target = args[6]; int rhs_target = args[6];
std::vector<NDArray> vec_lhs; std::vector<NDArray> vec_lhs;
std::vector<NDArray> vec_rhs; std::vector<NDArray> vec_rhs;
std::vector<NDArray> vec_out; std::vector<NDArray> vec_out;
vec_lhs.reserve(list_lhs.size()); vec_lhs.reserve(list_lhs.size());
vec_rhs.reserve(list_rhs.size()); vec_rhs.reserve(list_rhs.size());
vec_out.reserve(list_out.size()); vec_out.reserve(list_out.size());
for (Value val : list_lhs) { for (Value val : list_lhs) {
vec_lhs.push_back(val->data); vec_lhs.push_back(val->data);
} }
for (Value val : list_rhs) { for (Value val : list_rhs) {
vec_rhs.push_back(val->data); vec_rhs.push_back(val->data);
} }
for (Value val : list_out) { for (Value val : list_out) {
vec_out.push_back(val->data); vec_out.push_back(val->data);
} }
SDDMMHetero(op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target); SDDMMHetero(
}); op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string op = args[0]; const std::string op = args[0];
NDArray feat = args[1]; NDArray feat = args[1];
NDArray offsets = args[2]; NDArray offsets = args[2];
NDArray out = args[3]; NDArray out = args[3];
NDArray arg = args[4]; NDArray arg = args[4];
CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"}); CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"});
CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"}); CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"});
SegmentReduceDispatch(op, feat, offsets, out, arg); SegmentReduceDispatch(op, feat, offsets, out, arg);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray feat = args[0]; NDArray feat = args[0];
NDArray idx = args[1]; NDArray idx = args[1];
NDArray out = args[2]; NDArray out = args[2];
CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"}); CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
CheckContiguous({feat, idx, out}, {"feat", "idx", "out"}); CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
ScatterAddDispatch(feat, idx, out); ScatterAddDispatch(feat, idx, out);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
const std::string op = args[1]; const std::string op = args[1];
List<Value> list_feat = args[2]; List<Value> list_feat = args[2];
List<Value> list_idx = args[3]; List<Value> list_idx = args[3];
List<Value> list_idx_etype = args[4]; List<Value> list_idx_etype = args[4];
List<Value> list_out = args[5]; List<Value> list_out = args[5];
std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat); std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx); std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
std::vector<NDArray> vec_idx_etype = ListValueToVector<NDArray>(list_idx_etype); std::vector<NDArray> vec_idx_etype =
std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out); ListValueToVector<NDArray>(list_idx_etype);
// CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"}); std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
// CheckContiguous({feat, idx, out}, {"feat", "idx", "out"}); // CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
UpdateGradMinMaxDispatchHetero(graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out); // CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
}); UpdateGradMinMaxDispatchHetero(
graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray feat = args[0]; NDArray feat = args[0];
NDArray arg = args[1]; NDArray arg = args[1];
NDArray out = args[2]; NDArray out = args[2];
CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"}); CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"});
CheckContiguous({feat, arg, out}, {"feat", "arg", "out"}); CheckContiguous({feat, arg, out}, {"feat", "arg", "out"});
BackwardSegmentCmpDispatch(feat, arg, out); BackwardSegmentCmpDispatch(feat, arg, out);
}); });
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
.set_body([](DGLArgs args, DGLRetValue *rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
*rv = GetEdgeMapping(graph); *rv = GetEdgeMapping(graph);
}); });
/**
 * @brief Sparse matrix multiplication with graph interface.

@@ -734,107 +728,111 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")

 * @return A pair consisting of the new graph as well as its edge weights.
 */
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0]; const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1]; NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2]; const HeteroGraphRef B_ref = args[2];
NDArray B_weights = args[3]; NDArray B_weights = args[3];
int num_vtypes = args[4]; int num_vtypes = args[4];
const HeteroGraphPtr A = A_ref.sptr();
const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "The first graph must have only one edge type.";
CHECK_EQ(B->NumEdgeTypes(), 1) << "The second graph must have only one edge type.";
const auto A_csr = A->GetCSRMatrix(0);
const auto B_csr = B->GetCSRMatrix(0);
auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<HeteroGraphRef> A_refs = args[0];
List<Value> A_weights = args[1];
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
std::vector<CSRMatrix> mats;
mats.reserve(A_refs.size());
int num_vtypes = 0;
for (auto A_ref : A_refs) {
const HeteroGraphPtr A = A_ref.sptr(); const HeteroGraphPtr A = A_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "Graphs must have only one edge type."; const HeteroGraphPtr B = B_ref.sptr();
mats.push_back(A->GetCSRMatrix(0)); CHECK_EQ(A->NumEdgeTypes(), 1)
if (num_vtypes == 0) << "The first graph must have only one edge type.";
num_vtypes = A->NumVertexTypes(); CHECK_EQ(B->NumEdgeTypes(), 1)
} << "The second graph must have only one edge type.";
auto result = CSRSum(mats, weights); const auto A_csr = A->GetCSRMatrix(0);
const auto B_csr = B->GetCSRMatrix(0);
List<ObjectRef> ret; auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);
ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second))); List<ObjectRef> ret;
*rv = ret; ret.push_back(
}); HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([](DGLArgs args, DGLRetValue* rv) {
List<HeteroGraphRef> A_refs = args[0];
List<Value> A_weights = args[1];
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
std::vector<CSRMatrix> mats;
mats.reserve(A_refs.size());
int num_vtypes = 0;
for (auto A_ref : A_refs) {
const HeteroGraphPtr A = A_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1)
<< "Graphs must have only one edge type.";
mats.push_back(A->GetCSRMatrix(0));
if (num_vtypes == 0) num_vtypes = A->NumVertexTypes();
}
auto result = CSRSum(mats, weights);
List<ObjectRef> ret;
ret.push_back(
HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask") DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0]; const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1]; NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2]; const HeteroGraphRef B_ref = args[2];
const HeteroGraphPtr A = A_ref.sptr(); const HeteroGraphPtr A = A_ref.sptr();
const HeteroGraphPtr B = B_ref.sptr(); const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "Both graphs must have only one edge type."; CHECK_EQ(A->NumEdgeTypes(), 1)
CHECK_EQ(B->NumEdgeTypes(), 1) << "Both graphs must have only one edge type."; << "Both graphs must have only one edge type.";
const CSRMatrix& A_csr = A->GetCSRMatrix(0); CHECK_EQ(B->NumEdgeTypes(), 1)
const COOMatrix& B_coo = B->GetCOOMatrix(0); << "Both graphs must have only one edge type.";
CHECK_EQ(A_csr.num_rows, B_coo.num_rows) << const CSRMatrix& A_csr = A->GetCSRMatrix(0);
"Both graphs must have the same number of nodes."; const COOMatrix& B_coo = B->GetCOOMatrix(0);
CHECK_EQ(A_csr.num_cols, B_coo.num_cols) << CHECK_EQ(A_csr.num_rows, B_coo.num_rows)
"Both graphs must have the same number of nodes."; << "Both graphs must have the same number of nodes.";
CHECK_EQ(A_csr.num_cols, B_coo.num_cols)
NDArray result; << "Both graphs must have the same number of nodes.";
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
result = aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.); NDArray result;
}); ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
*rv = result; result =
}); aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
});
*rv = result;
});
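For context (editor's sketch, not part of this file): the ATEN_FLOAT_TYPE_SWITCH macro above dispatches on the runtime dtype of A_weights and instantiates its body with a concrete DType. Conceptually, and ignoring the macro's additional type-code checks and error messages, the expansion behaves roughly like the hand-rolled branch below.

// Rough, hand-written equivalent of the dtype dispatch (assumes DLPack-style
// dtype fields on NDArray; the real macro also validates the type code).
NDArray result;
if (A_weights->dtype.bits == 32) {
  // float32 edge weights
  result = aten::CSRGetData<float>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
} else if (A_weights->dtype.bits == 64) {
  // float64 edge weights
  result = aten::CSRGetData<double>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
} else {
  LOG(FATAL) << "Edge weights must be float32 or float64.";
}
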
#ifdef USE_TVM
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_LoadModule")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      const std::string path = args[0];
      dgl::featgraph::LoadFeatGraphModule(path);
    });

DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction") DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
.set_body([] (DGLArgs args, DGLRetValue* rv) { .set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0]; HeteroGraphRef graph = args[0];
NDArray lhs = args[1]; NDArray lhs = args[1];
NDArray rhs = args[2]; NDArray rhs = args[2];
NDArray out = args[3]; NDArray out = args[3];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"}); CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"}); CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1); CHECK_EQ(graph->NumEdgeTypes(), 1);
// auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph. // auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the
// const dgl_type_t src_vtype = pair.first; // graph. const dgl_type_t src_vtype = pair.first; const dgl_type_t
// const dgl_type_t dst_vtype = pair.second; // dst_vtype = pair.second; CheckShape(
// CheckShape( // {graph->NumVertices(src_vtype), graph->NumEdges(0),
// {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)}, // graph->NumVertices(dst_vtype)}, {lhs_target, rhs_target, 1}, {lhs,
// {lhs_target, rhs_target, 1}, // rhs, out},
// {lhs, rhs, out}, // {"U_data", "E_data", "V_data"});
// {"U_data", "E_data", "V_data"}); COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
COOMatrix coo = graph.sptr()->GetCOOMatrix(0); dgl::featgraph::SDDMMTreeReduction(
dgl::featgraph::SDDMMTreeReduction( DLPackConvert::ToDLPack(coo.row), DLPackConvert::ToDLPack(coo.col),
DLPackConvert::ToDLPack(coo.row), DLPackConvert::ToDLPack(lhs), DLPackConvert::ToDLPack(rhs),
DLPackConvert::ToDLPack(coo.col), DLPackConvert::ToDLPack(out));
DLPackConvert::ToDLPack(lhs), });
DLPackConvert::ToDLPack(rhs),
DLPackConvert::ToDLPack(out));
});
#endif // USE_TVM #endif // USE_TVM
}  // namespace aten
...