"vscode:/vscode.git/clone" did not exist on "6fe30cc908413dc1a4ed4918340f2a3b7af856d2"
Unverified Commit 8ac27dad authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] clang-format auto fix. (#4824)



* [Misc] clang-format auto fix.

* blabla

* ablabla

* blabla
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent bcd37684
......@@ -7,9 +7,11 @@
#define DGL_ARRAY_CUDA_ATOMIC_CUH_
#include <cuda_runtime.h>
#include <cassert>
#include "fp16.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#if __CUDA_ARCH__ >= 600
#include <cuda_fp16.h>
......@@ -20,22 +22,27 @@ namespace aten {
namespace cuda {
// Type trait for selecting code type
template <int Bytes> struct Code { };
template <int Bytes>
struct Code {};
template <> struct Code<2> {
template <>
struct Code<2> {
typedef unsigned short int Type; // NOLINT
};
template <> struct Code<4> {
template <>
struct Code<4> {
typedef unsigned int Type; // NOLINT
};
template <> struct Code<8> {
template <>
struct Code<8> {
typedef unsigned long long int Type; // NOLINT
};
// Helper class for converting to/from atomicCAS compatible types.
template <typename T> struct Cast {
template <typename T>
struct Cast {
typedef typename Code<sizeof(T)>::Type Type;
static __device__ __forceinline__ Type Encode(T val) {
return static_cast<Type>(val);
......@@ -45,7 +52,8 @@ template <typename T> struct Cast {
}
};
template <> struct Cast<half> {
template <>
struct Cast<half> {
typedef Code<sizeof(half)>::Type Type;
static __device__ __forceinline__ Type Encode(half val) {
return __half_as_ushort(val);
......@@ -56,13 +64,15 @@ template <> struct Cast<half> {
};
#if BF16_ENABLED
template <> struct Cast<__nv_bfloat16> {
template <>
struct Cast<__nv_bfloat16> {
typedef Code<sizeof(__nv_bfloat16)>::Type Type;
static __device__ __forceinline__ Type Encode(__nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __bfloat16_as_ushort(val);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
printf(
"Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<Type>(0);
......@@ -72,7 +82,8 @@ template <> struct Cast<__nv_bfloat16> {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __ushort_as_bfloat16(code);
#else
printf("Atomic operations are not supported for bfloat16 (BF16) "
printf(
"Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return static_cast<__nv_bfloat16>(0.0f);
......@@ -81,7 +92,8 @@ template <> struct Cast<__nv_bfloat16> {
};
#endif // BF16_ENABLED
template <> struct Cast<float> {
template <>
struct Cast<float> {
typedef Code<sizeof(float)>::Type Type;
static __device__ __forceinline__ Type Encode(float val) {
return __float_as_uint(val);
......@@ -91,7 +103,8 @@ template <> struct Cast<float> {
}
};
template <> struct Cast<double> {
template <>
struct Cast<double> {
typedef Code<sizeof(double)>::Type Type;
static __device__ __forceinline__ Type Encode(double val) {
return __double_as_longlong(val);
......@@ -102,7 +115,7 @@ template <> struct Cast<double> {
};
static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
unsigned short int *address, // NOLINT
unsigned short int* address, // NOLINT
unsigned short int compare, // NOLINT
unsigned short int val) { // NOLINT
static_assert(CUDART_VERSION >= 10000, "Requires at least CUDA 10");
......@@ -112,41 +125,45 @@ static __device__ __forceinline__ unsigned short int atomicCASshort( // NOLINT
(void)address;
(void)compare;
(void)val;
printf("Atomic operations are not supported for half precision (FP16) "
printf(
"Atomic operations are not supported for half precision (FP16) "
"on this GPU.\n");
__trap();
return val;
#endif // (defined(__CUDA_ARCH__) && (__CUDA_ARCH__) >= 700)
}
#define DEFINE_ATOMIC(NAME) \
template <typename T> \
__device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
typedef typename Cast<T>::Type CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCAS(addr_as_ui, assumed, \
Cast<T>::Encode(OP(val, Cast<T>::Decode(old)))); \
} while (assumed != old); \
return Cast<T>::Decode(old); \
#define DEFINE_ATOMIC(NAME) \
template <typename T> \
__device__ __forceinline__ T Atomic##NAME(T* addr, T val) { \
typedef typename Cast<T>::Type CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCAS( \
addr_as_ui, assumed, \
Cast<T>::Encode(OP(val, Cast<T>::Decode(old)))); \
} while (assumed != old); \
return Cast<T>::Decode(old); \
}
#define DEFINE_ATOMIC_16BIT(NAME, dtype) \
template <> \
__device__ __forceinline__ dtype Atomic##NAME<dtype>(dtype* addr, dtype val) { \
typedef uint16_t CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCASshort(addr_as_ui, assumed, \
#define DEFINE_ATOMIC_16BIT(NAME, dtype) \
template <> \
__device__ __forceinline__ dtype Atomic##NAME<dtype>( \
dtype * addr, dtype val) { \
typedef uint16_t CT; \
CT* addr_as_ui = reinterpret_cast<CT*>(addr); \
CT old = *addr_as_ui; \
CT assumed = old; \
do { \
assumed = old; \
old = atomicCASshort( \
addr_as_ui, assumed, \
Cast<dtype>::Encode(OP(val, Cast<dtype>::Decode(old)))); \
} while (assumed != old); \
return Cast<dtype>::Decode(old); \
} while (assumed != old); \
return Cast<dtype>::Decode(old); \
}
#define OP(a, b) max(a, b)
......@@ -169,84 +186,72 @@ DEFINE_ATOMIC_16BIT(Min, __nv_bfloat16)
DEFINE_ATOMIC(Add)
#undef OP
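For reference, the DEFINE_ATOMIC macros above expand to a standard compare-and-swap retry loop. A minimal hand-written sketch of the same technique, specialized for a hypothetical AtomicMaxFloat helper (the real macros generalize this over Cast<T> and OP), looks like this:

__device__ float AtomicMaxFloat(float* addr, float val) {
  // Reinterpret the float as the 4-byte CAS-compatible integer type,
  // mirroring what Cast<float>::Encode/Decode do.
  unsigned int* addr_as_ui = reinterpret_cast<unsigned int*>(addr);
  unsigned int old = *addr_as_ui;
  unsigned int assumed;
  do {
    assumed = old;
    // Try to install max(val, current); if another thread wrote in between,
    // atomicCAS returns the newer value and the loop retries.
    old = atomicCAS(
        addr_as_ui, assumed,
        __float_as_uint(fmaxf(val, __uint_as_float(assumed))));
  } while (assumed != old);
  return __uint_as_float(old);
}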
/**
* @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int64_t AtomicCAS(
int64_t * const address,
const int64_t compare,
const int64_t val) {
* @brief Performs an atomic compare-and-swap on 64 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int64_t
AtomicCAS(int64_t* const address, const int64_t compare, const int64_t val) {
// match the type of "::atomicCAS", so ignore lint warning
using Type = unsigned long long int; // NOLINT
using Type = unsigned long long int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicCAS(reinterpret_cast<Type*>(address),
static_cast<Type>(compare),
static_cast<Type>(val));
return atomicCAS(
reinterpret_cast<Type*>(address), static_cast<Type>(compare),
static_cast<Type>(val));
}
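A hypothetical usage sketch building on this overload (the -1 sentinel and the helper name are assumptions for illustration, not part of the codebase): claiming an empty slot in a device-side table is the typical pattern a compare-and-swap like this supports.

__device__ __forceinline__ bool TryInsertKey(int64_t* slot, int64_t key) {
  // Empty slots are assumed to hold -1. AtomicCAS returns the previous value,
  // so the insert succeeds if the slot was empty or already held this key.
  const int64_t prev = AtomicCAS(slot, /*compare=*/int64_t(-1), /*val=*/key);
  return prev == -1 || prev == key;
}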
/**
* @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int32_t AtomicCAS(
int32_t * const address,
const int32_t compare,
const int32_t val) {
* @brief Performs an atomic compare-and-swap on 32 bit integers. That is,
* it reads the word `old` at the memory location `address`, computes
* `(old == compare ? val : old)`, and stores the result back to memory at
* the same address.
*
* @param address The address to perform the atomic operation on.
* @param compare The value to compare to.
* @param val The new value to conditionally store.
*
* @return The old value at the address.
*/
inline __device__ int32_t
AtomicCAS(int32_t* const address, const int32_t compare, const int32_t val) {
// match the type of "::atomicCAS", so ignore lint warning
using Type = int; // NOLINT
using Type = int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicCAS(reinterpret_cast<Type*>(address),
static_cast<Type>(compare),
static_cast<Type>(val));
return atomicCAS(
reinterpret_cast<Type*>(address), static_cast<Type>(compare),
static_cast<Type>(val));
}
inline __device__ int64_t AtomicMax(
int64_t * const address,
const int64_t val) {
inline __device__ int64_t AtomicMax(int64_t* const address, const int64_t val) {
// match the type of "::atomicCAS", so ignore lint warning
using Type = unsigned long long int; // NOLINT
using Type = unsigned long long int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address),
static_cast<Type>(val));
return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
inline __device__ int32_t AtomicMax(
int32_t * const address,
const int32_t val) {
inline __device__ int32_t AtomicMax(int32_t* const address, const int32_t val) {
// match the type of "::atomicCAS", so ignore lint warning
using Type = int; // NOLINT
using Type = int; // NOLINT
static_assert(sizeof(Type) == sizeof(*address), "Type width must match");
return atomicMax(reinterpret_cast<Type*>(address),
static_cast<Type>(val));
return atomicMax(reinterpret_cast<Type*>(address), static_cast<Type>(val));
}
template <>
__device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
#if __CUDA_ARCH__ >= 200
......@@ -259,8 +264,8 @@ __device__ __forceinline__ float AtomicAdd<float>(float* addr, float val) {
CT assumed = old;
do {
assumed = old;
old = atomicCAS(addr_as_ui, assumed,
Cast<T>::Encode(Cast<T>::Decode(old) + val));
old = atomicCAS(
addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
} while (assumed != old);
return Cast<T>::Decode(old);
#endif // __CUDA_ARCH__
......@@ -278,8 +283,8 @@ __device__ __forceinline__ double AtomicAdd<double>(double* addr, double val) {
CT assumed = old;
do {
assumed = old;
old = atomicCAS(addr_as_ui, assumed,
Cast<T>::Encode(Cast<T>::Decode(old) + val));
old = atomicCAS(
addr_as_ui, assumed, Cast<T>::Encode(Cast<T>::Decode(old) + val));
} while (assumed != old);
return Cast<T>::Decode(old);
#endif
......@@ -294,7 +299,8 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#else
(void)addr;
(void)val;
printf("Atomic operations are not supported for half precision (FP16) "
printf(
"Atomic operations are not supported for half precision (FP16) "
"on this GPU.\n");
__trap();
return val;
......@@ -304,15 +310,16 @@ __device__ __forceinline__ half AtomicAdd<half>(half* addr, half val) {
#if BF16_ENABLED
template <>
__device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
__nv_bfloat16* addr, __nv_bfloat16 val) {
__device__ __forceinline__ __nv_bfloat16
AtomicAdd<__nv_bfloat16>(__nv_bfloat16* addr, __nv_bfloat16 val) {
// make sure we have bfloat16 support
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return atomicAdd(addr, val);
#else
(void)addr;
(void)val;
printf("Atomic operations are not supported for bfloat16 (BF16) "
printf(
"Atomic operations are not supported for bfloat16 (BF16) "
"on GPUs with compute capability less than 8.0.\n");
__trap();
return val;
......@@ -320,7 +327,6 @@ __device__ __forceinline__ __nv_bfloat16 AtomicAdd<__nv_bfloat16>(
}
#endif // BF16_ENABLED
} // namespace cuda
} // namespace aten
} // namespace dgl
......
......@@ -4,9 +4,11 @@
* @brief Retrieve entries of a CSR matrix
*/
#include <dgl/array.h>
#include <vector>
#include <unordered_set>
#include <numeric>
#include <unordered_set>
#include <vector>
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
......@@ -19,56 +21,57 @@ namespace impl {
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
CHECK((rowlen == collen) || (rowlen == 1) || (collen == 1))
<< "Invalid row and col id array.";
<< "Invalid row and col id array.";
const int64_t row_stride = (rowlen == 1 && collen != 1) ? 0 : 1;
const int64_t col_stride = (collen == 1 && rowlen != 1) ? 0 : 1;
const int64_t rstlen = std::max(rowlen, collen);
IdArray rst = NDArray::Empty({rstlen}, weights->dtype, rows->ctx);
if (rstlen == 0)
return rst;
if (rstlen == 0) return rst;
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
if (return_eids)
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype)
<< "DType does not match row's dtype.";
const IdType* indptr_data = csr.indptr.Ptr<IdType>();
const IdType* indices_data = csr.indices.Ptr<IdType>();
const IdType* data_data = CSRHasData(csr) ? csr.data.Ptr<IdType>() : nullptr;
if (csr.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer(
&indptr_data, csr.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(
&indices_data, csr.indices.Ptr<IdType>(), 0));
CUDA_CALL(
cudaHostGetDevicePointer(&indptr_data, csr.indptr.Ptr<IdType>(), 0));
CUDA_CALL(
cudaHostGetDevicePointer(&indices_data, csr.indices.Ptr<IdType>(), 0));
if (CSRHasData(csr)) {
CUDA_CALL(cudaHostGetDevicePointer(
&data_data, csr.data.Ptr<IdType>(), 0));
CUDA_CALL(
cudaHostGetDevicePointer(&data_data, csr.data.Ptr<IdType>(), 0));
}
}
// TODO(minjie): use binary search for sorted csr
CUDA_KERNEL_CALL(cuda::_LinearSearchKernel,
nb, nt, 0, stream,
indptr_data, indices_data, data_data,
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
row_stride, col_stride, rstlen,
return_eids ? nullptr : weights.Ptr<DType>(), filler, rst.Ptr<DType>());
CUDA_KERNEL_CALL(
cuda::_LinearSearchKernel, nb, nt, 0, stream, indptr_data, indices_data,
data_data, rows.Ptr<IdType>(), cols.Ptr<IdType>(), row_stride, col_stride,
rstlen, return_eids ? nullptr : weights.Ptr<DType>(), filler,
rst.Ptr<DType>());
return rst;
}
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __half filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, __half filler);
#if BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, __nv_bfloat16>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
......@@ -78,19 +81,25 @@ template NDArray CSRGetData<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray weights, __nv_bfloat16 filler);
#endif // BF16_ENABLED
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, float filler);
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, double filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, int32_t filler);
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids,
NDArray weights, int64_t filler);
} // namespace impl
} // namespace aten
......
......@@ -5,9 +5,10 @@
*/
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"
namespace dgl {
......@@ -16,7 +17,8 @@ using namespace dgl::runtime;
namespace aten {
namespace cusparse {
#if 0 // disabling CUDA 11.0+ implementation for now because of problems on bigger graphs
#if 0 // disabling CUDA 11.0+ implementation for now because of problems on
// bigger graphs
/** @brief Cusparse implementation of SpGEMM on Csr format for CUDA 11.0+ */
template <typename DType, typename IdType>
......@@ -125,14 +127,13 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
dC_weights};
}
#else // __CUDACC_VER_MAJOR__ != 11
#else // __CUDACC_VER_MAJOR__ != 11
/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA versions */
/** @brief Cusparse implementation of SpGEMM on Csr format for older CUDA
* versions */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseSpgemm(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
const NDArray B_weights_array) {
int nnzC;
csrgemm2Info_t info = nullptr;
......@@ -164,36 +165,30 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseCreateMatDescr(&matA));
CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
CUSPARSE_CALL(cusparseCreateMatDescr(&matD)); // needed even if D is null
CUSPARSE_CALL(cusparseCreateMatDescr(&matD)); // needed even if D is null
CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(thr_entry->cusparse_handle,
m, n, k, &alpha,
matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
nullptr,
matD, 0, nullptr, nullptr,
info,
&workspace_size));
CUSPARSE_CALL(CSRGEMM<DType>::bufferSizeExt(
thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA,
A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB,
B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
nullptr, nullptr, info, &workspace_size));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
void* workspace = device->AllocWorkspace(ctx, workspace_size);
IdArray C_indptr = IdArray::Empty({m + 1}, idtype, ctx);
CUSPARSE_CALL(CSRGEMM<DType>::nnz(thr_entry->cusparse_handle,
m, n, k,
matA, nnzA, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
matD, 0, nullptr, nullptr,
matC, C_indptr.Ptr<IdType>(), &nnzC, info, workspace));
CUSPARSE_CALL(CSRGEMM<DType>::nnz(
thr_entry->cusparse_handle, m, n, k, matA, nnzA, A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(), matD, 0, nullptr, nullptr, matC,
C_indptr.Ptr<IdType>(), &nnzC, info, workspace));
IdArray C_indices = IdArray::Empty({nnzC}, idtype, ctx);
NDArray C_weights = NDArray::Empty({nnzC}, dtype, ctx);
CUSPARSE_CALL(CSRGEMM<DType>::compute(thr_entry->cusparse_handle,
m, n, k, &alpha,
matA, nnzA, A_weights, A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(),
matB, nnzB, B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(),
nullptr,
matD, 0, nullptr, nullptr, nullptr,
matC, C_weights.Ptr<DType>(), C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(),
info, workspace));
CUSPARSE_CALL(CSRGEMM<DType>::compute(
thr_entry->cusparse_handle, m, n, k, &alpha, matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), matB, nnzB, B_weights,
B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), nullptr, matD, 0,
nullptr, nullptr, nullptr, matC, C_weights.Ptr<DType>(),
C_indptr.Ptr<IdType>(), C_indices.Ptr<IdType>(), info, workspace));
device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroyCsrgemm2Info(info));
......@@ -203,7 +198,8 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
CUSPARSE_CALL(cusparseDestroyMatDescr(matD));
return {
CSRMatrix(m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
CSRMatrix(
m, k, C_indptr, C_indices, NullArray(C_indptr->dtype, C_indptr->ctx)),
C_weights};
}
......@@ -212,9 +208,7 @@ std::pair<CSRMatrix, NDArray> CusparseSpgemm(
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRMM(
const CSRMatrix& A,
NDArray A_weights,
const CSRMatrix& B,
const CSRMatrix& A, NDArray A_weights, const CSRMatrix& B,
NDArray B_weights) {
auto ctx = A.indptr->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
......@@ -224,20 +218,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
// Cast 64 bit indices to 32 bit.
if (A.indptr->dtype.bits == 64) {
newA = CSRMatrix(
A.num_rows, A.num_cols,
AsNumBits(A.indptr, 32), AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
A.num_rows, A.num_cols, AsNumBits(A.indptr, 32),
AsNumBits(A.indices, 32), AsNumBits(A.data, 32));
newB = CSRMatrix(
B.num_rows, B.num_cols,
AsNumBits(B.indptr, 32), AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
B.num_rows, B.num_cols, AsNumBits(B.indptr, 32),
AsNumBits(B.indices, 32), AsNumBits(B.data, 32));
cast = true;
}
// Reorder weights if A or B has edge IDs
NDArray newA_weights, newB_weights;
if (CSRHasData(A))
newA_weights = IndexSelect(A_weights, A.data);
if (CSRHasData(B))
newB_weights = IndexSelect(B_weights, B.data);
if (CSRHasData(A)) newA_weights = IndexSelect(A_weights, A.data);
if (CSRHasData(B)) newB_weights = IndexSelect(B_weights, B.data);
auto result = cusparse::CusparseSpgemm<DType, int32_t>(
cast ? newA : A, CSRHasData(A) ? newA_weights : A_weights,
......@@ -247,9 +239,10 @@ std::pair<CSRMatrix, NDArray> CSRMM(
if (cast) {
CSRMatrix C = result.first;
return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
AsNumBits(C.data, 64)),
result.second};
CSRMatrix(
C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
AsNumBits(C.indices, 64), AsNumBits(C.data, 64)),
result.second};
} else {
return result;
}
......
......@@ -5,9 +5,10 @@
*/
#include <dgl/array.h>
#include <dgl/runtime/device_api.h>
#include "./functor.cuh"
#include "./cusparse_dispatcher.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./cusparse_dispatcher.cuh"
#include "./functor.cuh"
namespace dgl {
......@@ -19,9 +20,7 @@ namespace cusparse {
/** Cusparse implementation of SpSum on Csr format. */
template <typename DType, typename IdType>
std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
const CSRMatrix& A,
const NDArray A_weights_array,
const CSRMatrix& B,
const CSRMatrix& A, const NDArray A_weights_array, const CSRMatrix& B,
const NDArray B_weights_array) {
const int m = A.num_rows;
const int n = A.num_cols;
......@@ -46,7 +45,8 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
CUSPARSE_CALL(cusparseCreateMatDescr(&matB));
CUSPARSE_CALL(cusparseCreateMatDescr(&matC));
cusparseSetPointerMode(thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
cusparseSetPointerMode(
thr_entry->cusparse_handle, CUSPARSE_POINTER_MODE_HOST);
size_t workspace_size = 0;
/* prepare output C */
IdArray dC_csrOffsets = IdArray::Empty({m + 1}, A.indptr->dtype, ctx);
......@@ -57,25 +57,16 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
DType* dC_weights_data = dC_weights.Ptr<DType>();
/* prepare buffer */
CUSPARSE_CALL(CSRGEAM<DType>::bufferSizeExt(
thr_entry->cusparse_handle, m, n, &alpha,
matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
&beta, matB, nnzB, B_weights,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
&workspace_size));
void *workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(CSRGEAM<DType>::nnz(thr_entry->cusparse_handle,
m, n, matA, nnzA,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
matB, nnzB,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_csrOffsets_data, &nnzC, workspace));
thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
dC_weights_data, dC_csrOffsets_data, dC_columns_data, &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(CSRGEAM<DType>::nnz(
thr_entry->cusparse_handle, m, n, matA, nnzA, A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(), matB, nnzB, B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(), matC, dC_csrOffsets_data, &nnzC, workspace));
dC_columns = IdArray::Empty({nnzC}, A.indptr->dtype, ctx);
dC_weights = NDArray::Empty({nnzC}, A_weights_array->dtype, ctx);
......@@ -83,15 +74,10 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
dC_weights_data = dC_weights.Ptr<DType>();
CUSPARSE_CALL(CSRGEAM<DType>::compute(
thr_entry->cusparse_handle, m, n, &alpha,
matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(),
A.indices.Ptr<IdType>(),
&beta, matB, nnzB, B_weights,
B.indptr.Ptr<IdType>(),
B.indices.Ptr<IdType>(),
matC, dC_weights_data, dC_csrOffsets_data, dC_columns_data,
workspace));
thr_entry->cusparse_handle, m, n, &alpha, matA, nnzA, A_weights,
A.indptr.Ptr<IdType>(), A.indices.Ptr<IdType>(), &beta, matB, nnzB,
B_weights, B.indptr.Ptr<IdType>(), B.indices.Ptr<IdType>(), matC,
dC_weights_data, dC_csrOffsets_data, dC_columns_data, workspace));
device->FreeWorkspace(ctx, workspace);
// destroy matrix/vector descriptors
......@@ -99,16 +85,16 @@ std::pair<CSRMatrix, NDArray> CusparseCsrgeam2(
CUSPARSE_CALL(cusparseDestroyMatDescr(matB));
CUSPARSE_CALL(cusparseDestroyMatDescr(matC));
return {
CSRMatrix(A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
dC_weights};
CSRMatrix(
A.num_rows, A.num_cols, dC_csrOffsets, dC_columns,
NullArray(dC_csrOffsets->dtype, dC_csrOffsets->ctx), true),
dC_weights};
}
} // namespace cusparse
template <int XPU, typename IdType, typename DType>
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& As,
const std::vector<NDArray>& A_weights) {
const std::vector<CSRMatrix>& As, const std::vector<NDArray>& A_weights) {
const int64_t M = As[0].num_rows;
const int64_t N = As[0].num_cols;
const int64_t n = As.size();
......@@ -120,19 +106,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
if (As[0].indptr->dtype.bits == 64) {
for (int i = 0; i < n; ++i)
newAs.emplace_back(
As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
As[i].num_rows, As[i].num_cols, AsNumBits(As[i].indptr, 32),
AsNumBits(As[i].indices, 32), AsNumBits(As[i].data, 32));
cast = true;
} else {
for (int i = 0; i < n; ++i)
newAs.push_back(As[i]);
for (int i = 0; i < n; ++i) newAs.push_back(As[i]);
}
// cuSPARSE csrgeam2 requires the CSR to be sorted.
// TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure how to do it.
// TODO(BarclayII): ideally the sorted CSR should be cached but I'm not sure
// how to do it.
for (int i = 0; i < n; ++i) {
if (!newAs[i].sorted)
newAs[i] = CSRSort(newAs[i]);
if (!newAs[i].sorted) newAs[i] = CSRSort(newAs[i]);
}
// Reorder weights if A[i] has edge IDs
......@@ -147,10 +132,11 @@ std::pair<CSRMatrix, NDArray> CSRSum(
// Loop and sum
auto result = std::make_pair(
CSRMatrix(
newAs[0].num_rows, newAs[0].num_cols,
newAs[0].indptr, newAs[0].indices,
NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
A_weights_reordered[0]); // Weights already reordered so we don't need As[0].data
newAs[0].num_rows, newAs[0].num_cols, newAs[0].indptr,
newAs[0].indices,
NullArray(newAs[0].indptr->dtype, newAs[0].indptr->ctx)),
A_weights_reordered[0]); // Weights already reordered so we don't need
// As[0].data
for (int64_t i = 1; i < n; ++i)
result = cusparse::CusparseCsrgeam2<DType, int32_t>(
result.first, result.second, newAs[i], A_weights_reordered[i]);
......@@ -159,9 +145,10 @@ std::pair<CSRMatrix, NDArray> CSRSum(
if (cast) {
CSRMatrix C = result.first;
return {
CSRMatrix(C.num_rows, C.num_cols, AsNumBits(C.indptr, 64), AsNumBits(C.indices, 64),
AsNumBits(C.data, 64), true),
result.second};
CSRMatrix(
C.num_rows, C.num_cols, AsNumBits(C.indptr, 64),
AsNumBits(C.indices, 64), AsNumBits(C.data, 64), true),
result.second};
} else {
return result;
}
......
/**
* Copyright (c) 2020 by Contributors
* @file array/cuda/dispatcher.cuh
* @brief Templates to dispatch into different cuSPARSE routines based on the type
* argument.
* @brief Templates to dispatch into different cuSPARSE routines based on the
* type argument.
*/
#ifndef DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#define DGL_ARRAY_CUDA_CUSPARSE_DISPATCHER_CUH_
#include <cusparse.h>
#include <dgl/runtime/c_runtime_api.h>
#include "fp16.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
namespace dgl {
namespace aten {
......@@ -40,8 +41,8 @@ template <>
struct CSRGEMM<__half> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
// implementation would be required.
// TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
// different implementation would be required.
LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype half (FP16).";
return static_cast<cusparseStatus_t>(0);
}
......@@ -65,9 +66,10 @@ template <>
struct CSRGEMM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
// TODO(ndickson): There is no cusparseHcsrgemm2_bufferSizeExt, so a
// different implementation would be required.
LOG(FATAL)
<< "CSRGEMM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
......@@ -147,8 +149,8 @@ template <>
struct CSRGEAM<__half> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
// implementation would be required.
// TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
// different implementation would be required.
LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype half (FP16).";
return static_cast<cusparseStatus_t>(0);
}
......@@ -172,9 +174,10 @@ template <>
struct CSRGEAM<__nv_bfloat16> {
template <typename... Args>
static inline cusparseStatus_t bufferSizeExt(Args... args) {
// TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a different
// implementation would be required.
LOG(FATAL) << "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
// TODO(ndickson): There is no cusparseHcsrgeam2_bufferSizeExt, so a
// different implementation would be required.
LOG(FATAL)
<< "CSRGEAM::bufferSizeExt does not support dtype bfloat16 (BF16).";
return static_cast<cusparseStatus_t>(0);
}
......
......@@ -21,8 +21,8 @@
#ifndef DGL_ARRAY_CUDA_FP16_CUH_
#define DGL_ARRAY_CUDA_FP16_CUH_
#include <cuda_fp16.h>
#include <algorithm>
static __device__ __forceinline__ half max(half a, half b) {
......@@ -42,45 +42,64 @@ static __device__ __forceinline__ half min(half a, half b) {
}
#ifdef __CUDACC__
// Arithmetic FP16 operations for architecture >= 5.3 are already defined in cuda_fp16.h
// Arithmetic FP16 operations for architecture >= 5.3 are already defined in
// cuda_fp16.h
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 530)
__device__ __forceinline__ __half operator+(const __half& lh, const __half& rh) {
__device__ __forceinline__ __half
operator+(const __half& lh, const __half& rh) {
return __half(float(lh) + float(rh)); // NOLINT
}
__device__ __forceinline__ __half operator-(const __half& lh, const __half& rh) {
__device__ __forceinline__ __half
operator-(const __half& lh, const __half& rh) {
return __half(float(lh) - float(rh)); // NOLINT
}
__device__ __forceinline__ __half operator*(const __half& lh, const __half& rh) {
__device__ __forceinline__ __half
operator*(const __half& lh, const __half& rh) {
return __half(float(lh) * float(rh)); // NOLINT
}
__device__ __forceinline__ __half operator/(const __half& lh, const __half& rh) {
__device__ __forceinline__ __half
operator/(const __half& lh, const __half& rh) {
return __half(float(lh) / float(rh)); // NOLINT
}
__device__ __forceinline__ __half& operator+=(__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) + float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __half& operator-=(__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) - float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __half& operator*=(__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) * float(rh)); return lh; // NOLINT
}
__device__ __forceinline__ __half& operator/=(__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) / float(rh)); return lh; // NOLINT
__device__ __forceinline__ __half& operator+=(
__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) + float(rh)); // NOLINT
return lh;
}
__device__ __forceinline__ __half& operator-=(
__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) - float(rh)); // NOLINT
return lh;
}
__device__ __forceinline__ __half& operator*=(
__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) * float(rh)); // NOLINT
return lh;
}
__device__ __forceinline__ __half& operator/=(
__half& lh, const __half& rh) { // NOLINT
lh = __half(float(lh) / float(rh)); // NOLINT
return lh;
}
__device__ __forceinline__ __half& operator++(__half& h) { // NOLINT
h = __half(float(h) + 1.0f); return h; // NOLINT
h = __half(float(h) + 1.0f); // NOLINT
return h;
}
__device__ __forceinline__ __half& operator--(__half& h) { // NOLINT
h = __half(float(h) - 1.0f); return h; // NOLINT
h = __half(float(h) - 1.0f); // NOLINT
return h;
}
__device__ __forceinline__ __half operator++(__half& h, int) { // NOLINT
__half ret = h; h = __half(float(h) + 1.0f); return ret; // NOLINT
__device__ __forceinline__ __half operator++(__half& h, int) { // NOLINT
__half ret = h;
h = __half(float(h) + 1.0f); // NOLINT
return ret;
}
__device__ __forceinline__ __half operator--(__half& h, int) { // NOLINT
__half ret = h; h = __half(float(h) - 1.0f); return ret; // NOLINT
__device__ __forceinline__ __half operator--(__half& h, int) { // NOLINT
__half ret = h;
h = __half(float(h) - 1.0f); // NOLINT
return ret;
}
__device__ __forceinline__ __half operator+(const __half& h) { return h; }
......@@ -94,11 +113,11 @@ __device__ __forceinline__ bool operator==(const __half& lh, const __half& rh) {
__device__ __forceinline__ bool operator!=(const __half& lh, const __half& rh) {
return float(lh) != float(rh); // NOLINT
}
__device__ __forceinline__ bool operator> (const __half& lh, const __half& rh) {
return float(lh) > float(rh); // NOLINT
__device__ __forceinline__ bool operator>(const __half& lh, const __half& rh) {
return float(lh) > float(rh); // NOLINT
}
__device__ __forceinline__ bool operator< (const __half& lh, const __half& rh) {
return float(lh) < float(rh); // NOLINT
__device__ __forceinline__ bool operator<(const __half& lh, const __half& rh) {
return float(lh) < float(rh); // NOLINT
}
__device__ __forceinline__ bool operator>=(const __half& lh, const __half& rh) {
return float(lh) >= float(rh); // NOLINT
......
......@@ -8,6 +8,7 @@
#include <cmath>
#include <limits>
#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"
......@@ -16,99 +17,117 @@ namespace dgl {
namespace aten {
namespace cuda {
/////////////////////////////// CUDA binary operators ///////////////////////////////
/////////////////////////// CUDA binary operators //////////////////////////////
namespace binary {
template <typename DType>
struct Add {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] + rhs[0];
}
};
template <typename DType> constexpr bool Add<DType>::use_lhs;
template <typename DType> constexpr bool Add<DType>::use_rhs;
template <typename DType> constexpr bool Add<DType>::reduce_last_dim;
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;
template <typename DType>
struct Sub {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] - rhs[0];
}
};
template <typename DType> constexpr bool Sub<DType>::use_lhs;
template <typename DType> constexpr bool Sub<DType>::use_rhs;
template <typename DType> constexpr bool Sub<DType>::reduce_last_dim;
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;
template <typename DType>
struct Mul {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] * rhs[0];
}
};
template <typename DType> constexpr bool Mul<DType>::use_lhs;
template <typename DType> constexpr bool Mul<DType>::use_rhs;
template <typename DType> constexpr bool Mul<DType>::reduce_last_dim;
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;
template <typename DType>
struct Div {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0] / rhs[0];
}
};
template <typename DType> constexpr bool Div<DType>::use_lhs;
template <typename DType> constexpr bool Div<DType>::use_rhs;
template <typename DType> constexpr bool Div<DType>::reduce_last_dim;
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;
template <typename DType>
struct CopyLhs {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = false;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return lhs[0];
}
};
template <typename DType> constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType> constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType> constexpr bool CopyLhs<DType>::reduce_last_dim;
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;
template <typename DType>
struct CopyRhs {
static constexpr bool use_lhs = false;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = false;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
return rhs[0];
}
};
template <typename DType> constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType> constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType> constexpr bool CopyRhs<DType>::reduce_last_dim;
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;
template <typename DType>
struct Dot {
static constexpr bool use_lhs = true;
static constexpr bool use_rhs = true;
static constexpr bool reduce_last_dim = true;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
static __device__ __forceinline__ DType
Call(const DType *lhs, const DType *rhs, int64_t len = 1) {
DType rst = static_cast<DType>(0.0f);
for (int64_t i = 0; i < len; ++i) {
rst += lhs[i] * rhs[i];
......@@ -116,25 +135,26 @@ struct Dot {
return rst;
}
};
template <typename DType> constexpr bool Dot<DType>::use_lhs;
template <typename DType> constexpr bool Dot<DType>::use_rhs;
template <typename DType> constexpr bool Dot<DType>::reduce_last_dim;
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;
} // end of namespace binary
} // end of namespace binary
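A minimal sketch of how these binary functors are typically consumed (an assumed caller for illustration, not DGL's actual SpMM/SDDMM kernel): the kernel is templated on the functor and forwards per-edge feature pointers plus the reduced length to Call.

template <typename DType, typename BinaryOp>
__global__ void ApplyEdgeOp(
    const DType* lhs, const DType* rhs, DType* out, int64_t num_edges,
    int64_t len) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_edges) {
    // For reduce_last_dim functors (Dot), len > 1 reduces a feature vector;
    // element-wise functors only read element 0 of each segment.
    out[i] = BinaryOp::Call(lhs + i * len, rhs + i * len, len);
  }
}

Instantiating it as ApplyEdgeOp<float, binary::Dot<float>> with len equal to the feature dimension would compute one dot product per edge.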
/////////////////////////////// CUDA reduce operators ///////////////////////////////
/////////////////////////// CUDA reduce operators //////////////////////////////
namespace reduce {
template <typename Idx,
typename DType,
bool atomic>
template <typename Idx, typename DType, bool atomic>
struct _Sum {
static constexpr __host__ __device__ __forceinline__ DType zero() {
return 0.;
}
static constexpr bool require_arg = false;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
Idx eid) {
if (!atomic) {
*out_buf += val;
} else {
......@@ -142,26 +162,23 @@ struct _Sum {
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
DType *out_buf, Idx *arg_buf, DType val, Idx id) {
if (!atomic) {
*out_buf += val;
} else {
cuda::AtomicAdd(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {}
static __device__ __forceinline__ void CallArg(
Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
Idx uid, Idx eid) {}
};
template <typename Idx,
typename DType,
bool atomic = false>
struct Sum: _Sum<Idx, DType, atomic> { };
template <typename Idx, typename DType, bool atomic = false>
struct Sum : _Sum<Idx, DType, atomic> {};
template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
struct Sum<Idx, half, atomic> : _Sum<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(0.);
}
......@@ -169,24 +186,22 @@ struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic>: _Sum<Idx, __nv_bfloat16, atomic> {
struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(0.);
}
};
#endif // BF16_ENABLED
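In the same spirit, a minimal sketch of how a reduce functor such as Sum<int64_t, float, true> is consumed (an assumed caller, not DGL's real kernel); the output rows are expected to be pre-initialized with ReduceOp::zero(), and the atomic template flag routes Call through cuda::AtomicAdd when several edges target the same row.

template <typename Idx, typename DType, typename ReduceOp>
__global__ void ReduceEdgeValues(
    const DType* edge_vals, const Idx* dst, DType* out, int64_t num_edges) {
  const int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_edges) {
    // Sum ignores the argmax buffer, so nullptr is fine here; Max/Min need a
    // real buffer when atomic == false so they can record the winning id.
    ReduceOp::Call(out + dst[i], nullptr, edge_vals[i], static_cast<Idx>(i));
  }
}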
template <typename Idx,
typename DType,
bool atomic>
template <typename Idx, typename DType, bool atomic>
struct _Max {
static constexpr __host__ __device__ __forceinline__ DType zero() {
return -std::numeric_limits<DType>::infinity();
}
static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
Idx eid) {
if (!atomic) {
if (*out_buf < val) {
*out_buf = val;
......@@ -198,8 +213,7 @@ struct _Max {
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
DType *out_buf, Idx *arg_buf, DType val, Idx id) {
if (!atomic) {
if (*out_buf < val) {
*out_buf = val;
......@@ -209,27 +223,22 @@ struct _Max {
cuda::AtomicMax(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
static __device__ __forceinline__ void CallArg(
Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
Idx uid, Idx eid) {
if (atomic) {
if (val == val_ref) {
if (arg_u_buf)
arg_u_buf[fid] = uid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
if (arg_u_buf) arg_u_buf[fid] = uid;
if (arg_e_buf) arg_e_buf[fid] = eid;
}
}
}
};
template <typename Idx,
typename DType,
bool atomic = false>
struct Max : _Max<Idx, DType, atomic> { };
template <typename Idx, typename DType, bool atomic = false>
struct Max : _Max<Idx, DType, atomic> {};
template <typename Idx,
bool atomic>
template <typename Idx, bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(-6.550400e+04f);
......@@ -237,8 +246,7 @@ struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
};
#if BF16_ENABLED
template <typename Idx,
bool atomic>
template <typename Idx, bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
......@@ -246,17 +254,15 @@ struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
};
#endif // BF16_ENABLED
template <typename Idx,
typename DType,
bool atomic>
template <typename Idx, typename DType, bool atomic>
struct _Min {
static constexpr __host__ __device__ __forceinline__ DType zero() {
return std::numeric_limits<DType>::infinity();
}
static constexpr bool require_arg = true;
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
DType val, Idx uid, Idx eid) {
DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
Idx eid) {
if (!atomic) {
if (*out_buf > val) {
*out_buf = val;
......@@ -268,8 +274,7 @@ struct _Min {
}
}
static __device__ __forceinline__ void Call(
DType *out_buf, Idx *arg_buf,
DType val, Idx id) {
DType *out_buf, Idx *arg_buf, DType val, Idx id) {
if (!atomic) {
if (*out_buf > val) {
*out_buf = val;
......@@ -279,27 +284,22 @@ struct _Min {
cuda::AtomicMin(out_buf, val);
}
}
static __device__ __forceinline__ void CallArg(Idx fid,
Idx *arg_u_buf, Idx *arg_e_buf,
DType val, DType val_ref, Idx uid, Idx eid) {
static __device__ __forceinline__ void CallArg(
Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
Idx uid, Idx eid) {
if (atomic) {
if (val == val_ref) {
if (arg_u_buf)
arg_u_buf[fid] = uid;
if (arg_e_buf)
arg_e_buf[fid] = eid;
if (arg_u_buf) arg_u_buf[fid] = uid;
if (arg_e_buf) arg_e_buf[fid] = eid;
}
}
}
};
template <typename Idx,
typename DType,
bool atomic = false>
struct Min : _Min<Idx, DType, atomic> { };
template <typename Idx, typename DType, bool atomic = false>
struct Min : _Min<Idx, DType, atomic> {};
template <typename Idx,
bool atomic>
template <typename Idx, bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(6.550400e+04f);
......@@ -307,8 +307,7 @@ struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
};
#if BF16_ENABLED
template <typename Idx,
bool atomic>
template <typename Idx, bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
......
......@@ -4,10 +4,12 @@
* @brief GatherMM C APIs and definitions.
*/
#include <dgl/array.h>
#include <algorithm> // std::swap
#include "./utils.h"
#include "./functor.cuh"
#include "./atomic.cuh"
#include "./functor.cuh"
#include "./utils.h"
namespace dgl {
using namespace cuda;
......@@ -15,62 +17,58 @@ namespace aten {
namespace {
/** @brief Call cuBLAS GEMM API for dense matmul operation for float and double. */
/** @brief Call cuBLAS GEMM API for dense matmul operation for float and double.
*/
template <typename DType>
cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const DType* alpha, const DType* A, int lda,
const DType* B, int ldb, const DType* beta,
DType* C, int ldc) {
cublasStatus_t cublasGemm(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const DType* alpha, const DType* A, int lda,
const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
template <>
cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const __half* alpha, const __half* A, int lda,
const __half* B, int ldb, const __half* beta,
__half* C, int ldc) {
return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc);
cublasStatus_t cublasGemm<__half>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const __half* alpha, const __half* A, int lda,
const __half* B, int ldb, const __half* beta, __half* C, int ldc) {
return cublasHgemm(
handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
#if BF16_ENABLED
template <>
cublasStatus_t cublasGemm<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
cublasStatus_t cublasGemm<__nv_bfloat16>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const __nv_bfloat16* alpha, const __nv_bfloat16* A,
int lda, const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
__nv_bfloat16* C, int ldc) {
float alpha_float = __bfloat162float(*alpha);
float beta_float = __bfloat162float(*beta);
return cublasGemmEx(handle, transa, transb, m, n, k,
&alpha_float, A, CUDA_R_16BF, lda,
B, CUDA_R_16BF, ldb,
&beta_float, C, CUDA_R_16BF, ldc,
CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
return cublasGemmEx(
handle, transa, transb, m, n, k, &alpha_float, A, CUDA_R_16BF, lda, B,
CUDA_R_16BF, ldb, &beta_float, C, CUDA_R_16BF, ldc, CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
#endif // BF16_ENABLED
template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const float* alpha, const float* A, int lda,
const float* B, int ldb, const float* beta,
float* C, int ldc) {
return cublasSgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc);
cublasStatus_t cublasGemm<float>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const float* alpha, const float* A, int lda,
const float* B, int ldb, const float* beta, float* C, int ldc) {
return cublasSgemm(
handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
template <>
cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, int k,
const double* alpha, const double* A, int lda,
const double* B, int ldb, const double* beta,
double* C, int ldc) {
return cublasDgemm(handle, transa, transb, m, n, k, alpha, A, lda,
B, ldb, beta, C, ldc);
cublasStatus_t cublasGemm<double>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const double* alpha, const double* A, int lda,
const double* B, int ldb, const double* beta, double* C, int ldc) {
return cublasDgemm(
handle, transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
}
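A hypothetical call sketch (handle, dimensions, and device pointers are assumed to be in scope; not code from this change): because cuBLAS is column-major, a row-major C = A * B with A of shape m x k and B of shape k x n is commonly computed through these wrappers by swapping the operand order.

const float alpha = 1.0f;
const float beta = 0.0f;
// Column-major view: C^T (n x m) = B^T (n x k) * A^T (k x m), which is exactly
// the row-major C = A * B that the caller wants.
cublasStatus_t status = cublasGemm<float>(
    handle, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &alpha, B, n, A, k, &beta, C, n);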
} // namespace
......@@ -78,122 +76,113 @@ cublasStatus_t cublasGemm<double>(cublasHandle_t handle, cublasOperation_t trans
namespace cuda {
/**
* @note Each row of A multiplies a segment of matrix B of dimension in_len * out_len.
* One warp is assigned to process one row of A. Each WARP sequentially multiplies
* one element of A and a row of B to compute a partial result of the output. A
* is loaded in shared memory in a coalesced way. The output matrix is accumulated
* in registers. B should benefit from the L2 cache.
* @note Each row of A multiplies a segment of matrix B of dimension in_len *
* out_len. One warp is assigned to process one row of A. Each WARP sequentially
* multiplies one element of A and a row of B to compute a partial result of the
* output. A is loaded in shared memory in a coalesced way. The output matrix is
* accumulated in registers. B should benefit from the L2 cache.
*/
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel(
const DType* __restrict__ A,
const DType* __restrict__ B,
DType* __restrict__ C,
const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b,
const Idx* __restrict__ idx_c,
const int64_t num_rows,
const int64_t in_len,
const int64_t out_len) {
const DType* __restrict__ A, const DType* __restrict__ B,
DType* __restrict__ C, const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
const unsigned int local_row =
row & 3; // hardcoded for TB size 128 (4 warps)
const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
const Idx B_offset = cur_rowB * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
// Load A in shared mem in a coalesced way
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
__syncwarp();
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
const unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps)
const Idx cur_rowA = (idx_a) ? idx_a[row] : row;
const Idx cur_rowB = (idx_b) ? idx_b[row] : row;
const Idx cur_rowC = (idx_c) ? idx_c[row] : row;
const Idx B_offset = cur_rowB * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
// Load A in shared mem in a coalesced way
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[cur_rowA * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
// iterate over elements of a row of A
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
// iterate over elements of a row of B in parallel
out_reg += a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
}
if (idx_c) {
AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
} else {
C[cur_rowC * out_len + (outloop + l)] += out_reg;
}
}
}
for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
// iterate over elements of a row of A
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
// iterate over elements of a row of B in parallel
out_reg +=
a_val * B[B_offset + ((i + k_start) * out_len + (outloop + l))];
}
if (idx_c) {
AtomicAdd(C + cur_rowC * out_len + (outloop + l), out_reg);
} else {
C[cur_rowC * out_len + (outloop + l)] += out_reg;
}
}
}
}
}
}
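// Illustrative sketch (not part of this patch): a plain C++ reference of the
// computation performed by GatherMMScatterKernel, handy for checking the
// kernel on small inputs. Row idx_a[i] of A (1 x in_len) is multiplied with
// the idx_b[i]-th (in_len x out_len) slice of B and accumulated into row
// idx_c[i] of C. All names below are local to this sketch.
#include <cstdint>
#include <vector>

template <typename DType>
void GatherMMScatterReference(
    const std::vector<DType>& A, const std::vector<DType>& B,
    std::vector<DType>* C, const std::vector<int64_t>& idx_a,
    const std::vector<int64_t>& idx_b, const std::vector<int64_t>& idx_c,
    int64_t num_rows, int64_t in_len, int64_t out_len) {
  for (int64_t i = 0; i < num_rows; ++i) {
    const int64_t ra = idx_a.empty() ? i : idx_a[i];
    const int64_t rb = idx_b.empty() ? i : idx_b[i];
    const int64_t rc = idx_c.empty() ? i : idx_c[i];
    for (int64_t j = 0; j < out_len; ++j) {
      DType acc = 0;
      for (int64_t k = 0; k < in_len; ++k)
        acc += A[ra * in_len + k] * B[rb * in_len * out_len + k * out_len + j];
      // The CUDA kernel uses AtomicAdd for this accumulation when idx_c is
      // given, since several rows may scatter into the same output row.
      (*C)[rc * out_len + j] += acc;
    }
  }
}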
/**
* @note Output matrix is accumulated via atomic operations. Rest of the strategies
* are similar to GatherMMKernel. One warp is assigned to process one row of A. Each
* WARP sequentially multiplies one element of A and a row of B to compute partial
* result of the output. A is loaded in shared memory in a coalesced way. B should
* get benefit from L2 cache.
 * @note The output matrix is accumulated via atomic operations. The rest of
 * the strategy is similar to GatherMMKernel: one warp is assigned to process
 * one row of A, and each warp sequentially multiplies one element of A and a
 * row of B to compute a partial result of the output. A is loaded into shared
 * memory in a coalesced way; B should benefit from the L2 cache.
*/
template <typename Idx, typename DType>
__global__ void GatherMMScatterKernel2(
const DType* __restrict__ A,
const DType* __restrict__ B,
DType* __restrict__ C,
const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b,
const Idx* __restrict__ idx_c,
const int64_t num_rows,
const int64_t in_len,
const int64_t out_len) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
const unsigned int local_row = row & 3; // hardcoded for TB size 128 (4 warps)
const Idx row_a = (idx_a) ? idx_a[row] : row;
const Idx row_b = (idx_b) ? idx_b[row] : row;
const Idx row_c = (idx_c) ? idx_c[row] : row;
const Idx C_offset = row_c * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
/* Load A in shared mem in a coalesced way */
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
__syncwarp();
const DType* __restrict__ A, const DType* __restrict__ B,
DType* __restrict__ C, const Idx* __restrict__ idx_a,
const Idx* __restrict__ idx_b, const Idx* __restrict__ idx_c,
const int64_t num_rows, const int64_t in_len, const int64_t out_len) {
unsigned int tId = threadIdx.x;
unsigned int laneId = tId & 31;
unsigned int gId = (blockIdx.x * blockDim.x + threadIdx.x);
unsigned int warpId = gId >> 5;
unsigned int row = warpId;
if (row < num_rows) {
const unsigned int local_row =
row & 3; // hardcoded for TB size 128 (4 warps)
const Idx row_a = (idx_a) ? idx_a[row] : row;
const Idx row_b = (idx_b) ? idx_b[row] : row;
const Idx row_c = (idx_c) ? idx_c[row] : row;
const Idx C_offset = row_c * in_len * out_len;
const int sh_a_tile = 64;
__shared__ DType sh_A[4 * sh_a_tile];
int a_tile = sh_a_tile;
for (unsigned int k_start = 0; k_start < in_len; k_start += 64) {
if ((in_len - k_start) < a_tile) a_tile = in_len - k_start;
/* Load A in shared mem in a coalesced way */
for (unsigned int l = laneId; l < a_tile; l += 32)
sh_A[local_row * sh_a_tile + l] = A[row_a * in_len + (k_start + l)];
__syncwarp();
for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
const DType b_val = B[row_b * out_len + (outloop + l)];
/* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
const Idx C_idx = C_offset + ((i + k_start) * out_len + (outloop + l));
AtomicAdd(C + C_idx, a_val * b_val);
}
}
}
for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
const DType b_val = B[row_b * out_len + (outloop + l)];
/* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
const Idx C_idx =
C_offset + ((i + k_start) * out_len + (outloop + l));
AtomicAdd(C + C_idx, a_val * b_val);
}
}
}
}
}
}
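// Illustrative sketch (not part of this patch): reference semantics of
// GatherMMScatterKernel2. Each processed row contributes the outer product of
// A[row_a] (length in_len) and B[row_b] (length out_len), scatter-added into
// the (in_len x out_len) block C[row_c]; this is the
// W_grad[idx_c[i]] += H[i]^T * C.grad[i] pattern used by the backward pass.
// Names are local to this sketch.
#include <cstdint>
#include <vector>

template <typename DType>
void GatherMMScatterOuterReference(
    const std::vector<DType>& A, const std::vector<DType>& B,
    std::vector<DType>* C, const std::vector<int64_t>& idx_a,
    const std::vector<int64_t>& idx_b, const std::vector<int64_t>& idx_c,
    int64_t num_rows, int64_t in_len, int64_t out_len) {
  for (int64_t i = 0; i < num_rows; ++i) {
    const int64_t ra = idx_a.empty() ? i : idx_a[i];
    const int64_t rb = idx_b.empty() ? i : idx_b[i];
    const int64_t rc = idx_c.empty() ? i : idx_c[i];
    for (int64_t k = 0; k < in_len; ++k)
      for (int64_t j = 0; j < out_len; ++j)
        // The CUDA kernel performs this accumulation with AtomicAdd because
        // different rows may map to the same rc.
        (*C)[rc * in_len * out_len + k * out_len + j] +=
            A[ra * in_len + k] * B[rb * out_len + j];
  }
}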
} // namespace cuda
......@@ -210,103 +199,89 @@ __global__ void GatherMMScatterKernel2(
* @param b_trans Matrix B to be transposed
*/
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool a_trans, bool b_trans) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType *A_data = A.Ptr<DType>();
const DType *B_data = B.Ptr<DType>();
const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
DType *C_data = C.Ptr<DType>();
int64_t A_offset = 0, B_offset = 0, C_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen_A.NumElements();
DType alpha = 1., beta = 0.;
void SegmentMM(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* A_data = A.Ptr<DType>();
const DType* B_data = B.Ptr<DType>();
const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
DType* C_data = C.Ptr<DType>();
int64_t A_offset = 0, B_offset = 0, C_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen_A.NumElements();
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
IdType m_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) {
m = seglen_A_data[etype]; // rows of A
CHECK_LE(m_offset + m, A->shape[0]) << "Segment index out of bound of A->shape[0].";
n = B->shape[2]; // cols of B
k = B->shape[1]; // cols of A == rows of B
int ldb = n, lda = k, ldc = n;
cublasOperation_t transB = CUBLAS_OP_N;
cublasOperation_t transA = CUBLAS_OP_N;
if (b_trans) {
transB = CUBLAS_OP_T;
ldb = n, lda = n, ldc = k;
std::swap(n, k);
}
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle,
transB,
transA,
n, m, k,
&alpha,
B_data + B_offset, ldb,
A_data + A_offset, lda,
&beta,
C_data + C_offset, ldc));
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
m_offset += m;
IdType m_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) {
m = seglen_A_data[etype]; // rows of A
CHECK_LE(m_offset + m, A->shape[0])
<< "Segment index out of bound of A->shape[0].";
n = B->shape[2]; // cols of B
k = B->shape[1]; // cols of A == rows of B
int ldb = n, lda = k, ldc = n;
cublasOperation_t transB = CUBLAS_OP_N;
cublasOperation_t transA = CUBLAS_OP_N;
if (b_trans) {
transB = CUBLAS_OP_T;
ldb = n, lda = n, ldc = k;
std::swap(n, k);
}
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle, transB, transA, n, m, k, &alpha,
B_data + B_offset, ldb, A_data + A_offset, lda, &beta,
C_data + C_offset, ldc));
A_offset += m * k;
B_offset += k * n;
C_offset += m * n;
m_offset += m;
}
}
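// Illustrative sketch (not part of this patch): a CPU reference of the
// segmented GEMM above for the non-transposed case. Segment r takes seglen[r]
// rows of A (each of width k = B->shape[1]) and multiplies them with the r-th
// (k x n) slice of B. In the cuBLAS call above, B is passed before A with
// sizes (n, m, k) because cuBLAS is column-major: computing C^T = B^T * A^T
// produces the row-major C without an explicit transpose.
#include <cstdint>
#include <vector>

template <typename DType>
void SegmentMMReference(
    const std::vector<DType>& A, const std::vector<DType>& B,
    std::vector<DType>* C, const std::vector<int64_t>& seglen, int64_t k,
    int64_t n) {
  int64_t a_off = 0, b_off = 0, c_off = 0;
  for (size_t r = 0; r < seglen.size(); ++r) {
    const int64_t m = seglen[r];  // rows of A in this segment
    for (int64_t i = 0; i < m; ++i)
      for (int64_t j = 0; j < n; ++j) {
        DType acc = 0;
        for (int64_t p = 0; p < k; ++p)
          acc += A[a_off + i * k + p] * B[b_off + p * n + j];
        (*C)[c_off + i * n + j] = acc;
      }
    a_off += m * k;
    b_off += k * n;
    c_off += m * n;
  }
}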
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
const NDArray seglen) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType *A_data = A.Ptr<DType>();
const DType *dC_data = dC.Ptr<DType>();
const IdType* seglen_data = seglen.Ptr<IdType>();
DType *dB_data = dB.Ptr<DType>();
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen.NumElements();
DType alpha = 1., beta = 1.;
void SegmentMMBackwardB(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* A_data = A.Ptr<DType>();
const DType* dC_data = dC.Ptr<DType>();
const IdType* seglen_data = seglen.Ptr<IdType>();
DType* dB_data = dB.Ptr<DType>();
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
int64_t m, n, k;
int64_t num_rel = seglen.NumElements();
DType alpha = 1., beta = 1.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
IdType k_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) {
m = dC->shape[1];
n = A->shape[1];
k = seglen_data[etype];
CHECK_LE(k_offset + k, A->shape[0]) << "Segement index out of bound of A->shape[0].";
int lddC = m, ldA = n, lddB = m;
cublasOperation_t trans_dC = CUBLAS_OP_N;
cublasOperation_t trans_A = CUBLAS_OP_T;
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle,
trans_dC,
trans_A,
m, n, k,
&alpha,
dC_data + dC_offset, lddC,
A_data + A_offset, ldA,
&beta,
dB_data + dB_offset, lddB));
dC_offset += m * k;
A_offset += n * k;
dB_offset += m * n;
k_offset += k;
}
IdType k_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) {
m = dC->shape[1];
n = A->shape[1];
k = seglen_data[etype];
CHECK_LE(k_offset + k, A->shape[0])
<< "Segement index out of bound of A->shape[0].";
int lddC = m, ldA = n, lddB = m;
cublasOperation_t trans_dC = CUBLAS_OP_N;
cublasOperation_t trans_A = CUBLAS_OP_T;
CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle, trans_dC, trans_A, m, n, k, &alpha,
dC_data + dC_offset, lddC, A_data + A_offset, ldA, &beta,
dB_data + dB_offset, lddB));
dC_offset += m * k;
A_offset += n * k;
dB_offset += m * n;
k_offset += k;
}
}
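// Illustrative sketch (not part of this patch): a CPU reference of the
// backward pass above. For each segment r, the (n x m) block dB_r is
// accumulated as A_seg^T * dC_seg, where A_seg holds seglen[r] rows of width
// n (= A->shape[1]) and dC_seg holds the matching rows of width
// m (= dC->shape[1]). The cuBLAS call uses beta = 1, so results are added
// into dB rather than overwriting it.
#include <cstdint>
#include <vector>

template <typename DType>
void SegmentMMBackwardBReference(
    const std::vector<DType>& A, const std::vector<DType>& dC,
    std::vector<DType>* dB, const std::vector<int64_t>& seglen, int64_t n,
    int64_t m) {
  int64_t a_off = 0, dc_off = 0, db_off = 0;
  for (size_t r = 0; r < seglen.size(); ++r) {
    const int64_t k = seglen[r];  // rows in this segment
    for (int64_t i = 0; i < n; ++i)
      for (int64_t j = 0; j < m; ++j) {
        DType acc = 0;
        for (int64_t p = 0; p < k; ++p)
          acc += A[a_off + p * n + i] * dC[dc_off + p * m + j];
        (*dB)[db_off + i * m + j] += acc;
      }
    a_off += k * n;
    dc_off += k * m;
    db_off += n * m;
  }
}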
/**
......@@ -320,30 +295,23 @@ void SegmentMMBackwardB(const NDArray A,
*/
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t out_len = B->shape[2]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
const int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
nblks, nthrs, 0, stream,
A.Ptr<DType>(),
B.Ptr<DType>(),
C.Ptr<DType>(),
idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(),
nullptr,
tot_num_rows, in_len, out_len);
void GatherMM(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t out_len = B->shape[2]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
const int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
CUDA_KERNEL_CALL(
(cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(), nullptr, tot_num_rows, in_len, out_len);
}
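// Illustrative sketch (not part of this patch): the launch math above assigns
// one warp (32 threads) to each gathered row. With ntx = 128 threads per
// block, that is 4 warps per block, so the helper below is equivalent to
// nbx = ((tot_num_rows * warp_size) + ntx - 1) / ntx used in these launches.
#include <cstdint>

inline int64_t NumBlocksForWarpPerRow(
    int64_t num_rows, int threads_per_block = 128, int warp_size = 32) {
  const int warps_per_block = threads_per_block / warp_size;  // 4 here
  return (num_rows + warps_per_block - 1) / warps_per_block;
}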
/**
......@@ -360,128 +328,118 @@ void GatherMM(const NDArray A,
* @param b_trans Matrix B to be transposed
*/
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType *idx_c_data = idx_c.Ptr<IdType>();
int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
if (B->ndim == 3) {
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
nblks, nthrs, 0, stream,
A.Ptr<DType>(),
B.Ptr<DType>(),
C.Ptr<DType>(),
idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(),
idx_c.Ptr<IdType>(),
tot_num_rows, in_len, out_len);
} else {
// Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
// This kernel accesses rows of A in a transposed way w/o explicitly converting A
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
nblks, nthrs, 0, stream,
A.Ptr<DType>(),
B.Ptr<DType>(),
C.Ptr<DType>(),
idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(),
idx_c.Ptr<IdType>(),
tot_num_rows, in_len, out_len);
}
void GatherMMScatter(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c) {
auto device = runtime::DeviceAPI::Get(A->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const IdType* idx_c_data = idx_c.Ptr<IdType>();
int64_t out_len = (B->ndim == 2) ? B->shape[1] : B->shape[2]; // cols of B
int64_t in_len = A->shape[1]; // cols of A
int64_t tot_num_rows = A->shape[0];
const int ntx = 128;
const int warp_size = 32;
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const dim3 nblks(nbx);
const dim3 nthrs(ntx);
if (B->ndim == 3) {
CUDA_KERNEL_CALL(
(cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
out_len);
} else {
// Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
// This kernel accesses rows of A in a transposed way w/o explicitly
// converting A
CUDA_KERNEL_CALL(
(cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
A.Ptr<DType>(), B.Ptr<DType>(), C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(), tot_num_rows, in_len,
out_len);
}
}
template void GatherMM<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
#if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
#endif // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
#endif // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
#if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
#endif // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
......
......@@ -6,18 +6,20 @@
* sampling code rowwise_sampling.cu.
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
#include <curand_kernel.h>
#include <dgl/random.h>
#include <dgl/runtime/device_api.h>
#include <curand_kernel.h>
#include <numeric>
#include "./dgl_cub.cuh"
#include "./utils.h"
#include "../../array/cuda/atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
// require CUB 1.17 to use DeviceSegmentedSort
static_assert(CUB_VERSION >= 101700, "Require CUB >= 1.17 to use DeviceSegmentedSort");
static_assert(
CUB_VERSION >= 101700, "Require CUB >= 1.17 to use DeviceSegmentedSort");
using namespace dgl::aten::cuda;
......@@ -30,26 +32,26 @@ namespace {
constexpr int BLOCK_SIZE = 128;
/**
* @brief Compute the size of each row in the sampled CSR, without replacement.
* temp_deg is calculated for rows with deg > num_picks.
* For these rows, we will calculate their A-Res values and sort them to get top-num_picks.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output).
*/
template<typename IdType>
* @brief Compute the size of each row in the sampled CSR, without replacement.
* temp_deg is calculated for rows with deg > num_picks.
* For these rows, we will calculate their A-Res values and sort them to get
* top-num_picks.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by
* `in_rows` (output).
*/
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t num_picks, const int64_t num_rows,
const IdType* const in_rows, const IdType* const in_ptr,
IdType* const out_deg, IdType* const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) {
......@@ -69,25 +71,24 @@ __global__ void _CSRRowWiseSampleDegreeKernel(
}
/**
* @brief Compute the size of each row in the sampled CSR, with replacement.
* We need the actual in degree of each row to store CDF values.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by `in_rows` (output).
*/
template<typename IdType>
* @brief Compute the size of each row in the sampled CSR, with replacement.
* We need the actual in degree of each row to store CDF values.
*
* @tparam IdType The type of node and edge indexes.
* @param num_picks The number of non-zero entries to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The index where each row's edges start.
* @param out_deg The size of each row in the sampled matrix, as indexed by
* `in_rows` (output).
* @param temp_deg The size of each row in the input matrix, as indexed by
* `in_rows` (output).
*/
template <typename IdType>
__global__ void _CSRRowWiseSampleDegreeReplaceKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
IdType * const out_deg,
IdType * const temp_deg) {
const int64_t num_picks, const int64_t num_rows,
const IdType* const in_rows, const IdType* const in_ptr,
IdType* const out_deg, IdType* const temp_deg) {
const int64_t tIdx = threadIdx.x + blockIdx.x * blockDim.x;
if (tIdx < num_rows) {
......@@ -106,23 +107,20 @@ __global__ void _CSRRowWiseSampleDegreeReplaceKernel(
}
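// Illustrative sketch (not part of this patch; the kernel bodies are elided
// from this hunk): a plausible host-side reference of the two degree kernels
// above, reconstructed from their documentation comments and the host-side
// notes further below. Without replacement, only rows with deg > num_picks
// need temporary A-Res storage; with replacement, every picked row needs deg
// slots to hold its CDF.
#include <algorithm>
#include <cstdint>
#include <vector>

template <typename IdType>
void SampleDegreeReference(
    int64_t num_picks, const std::vector<IdType>& rows,
    const std::vector<IdType>& indptr, bool replace,
    std::vector<IdType>* out_deg, std::vector<IdType>* temp_deg) {
  for (size_t i = 0; i < rows.size(); ++i) {
    const IdType deg = indptr[rows[i] + 1] - indptr[rows[i]];
    if (replace) {
      (*out_deg)[i] = deg == 0 ? 0 : static_cast<IdType>(num_picks);
      (*temp_deg)[i] = deg;  // room for the per-row CDF
    } else {
      (*out_deg)[i] = std::min(deg, static_cast<IdType>(num_picks));
      (*temp_deg)[i] = deg > num_picks ? deg : 0;  // rows needing A-Res values
    }
  }
}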
/**
* @brief Equivalent to numpy expression: array[idx[off:off + len]]
*
* @tparam IdType The ID type used for indices.
* @tparam FloatType The float type used for array values.
* @param array The array to be selected.
* @param idx_data The index mapping array.
* @param index The index of value to be selected.
* @param offset The offset to start.
* @param out The selected value (output).
*/
template<typename IdType, typename FloatType>
* @brief Equivalent to numpy expression: array[idx[off:off + len]]
*
* @tparam IdType The ID type used for indices.
* @tparam FloatType The float type used for array values.
* @param array The array to be selected.
* @param idx_data The index mapping array.
* @param index The index of value to be selected.
* @param offset The offset to start.
* @param out The selected value (output).
*/
template <typename IdType, typename FloatType>
__device__ void _DoubleSlice(
const FloatType * const array,
const IdType * const idx_data,
const IdType idx,
const IdType offset,
FloatType* const out) {
const FloatType* const array, const IdType* const idx_data,
const IdType idx, const IdType offset, FloatType* const out) {
if (idx_data) {
*out = array[idx_data[offset + idx]];
} else {
......@@ -131,39 +129,35 @@ __device__ void _DoubleSlice(
}
/**
* @brief Compute A-Res value. A-Res value needs to be calculated only if deg
* is greater than num_picks in weighted rowwise sampling without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param ares_ptr The offset to write each row to in the A-res array.
* @param ares_idxs The A-Res value corresponding index array, the index of input CSR (output).
* @param ares The A-Res value array (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
* @brief Compute A-Res value. A-Res value needs to be calculated only if deg
* is greater than num_picks in weighted rowwise sampling without replacement.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param ares_ptr The offset to write each row to in the A-res array.
* @param ares_idxs The A-Res value corresponding index array, the index of
* input CSR (output).
* @param ares The A-Res value array (output).
* @author pengqirong (OPPO)
*/
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRAResValueKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const data,
const FloatType * const prob,
const IdType * const ares_ptr,
IdType * const ares_idxs,
FloatType * const ares) {
const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
const IdType* const in_rows, const IdType* const in_ptr,
const IdType* const data, const FloatType* const prob,
const IdType* const ares_ptr, IdType* const ares_idxs,
FloatType* const ares) {
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
......@@ -181,9 +175,11 @@ __global__ void _CSRAResValueKernel(
const int64_t in_idx = in_row_start + idx;
const int64_t ares_idx = ares_row_start + idx;
FloatType item_prob;
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &item_prob);
_DoubleSlice<IdType, FloatType>(
prob, data, idx, in_row_start, &item_prob);
// compute A-Res value
ares[ares_idx] = static_cast<FloatType>(__powf(curand_uniform(&rng), 1.0f / item_prob));
ares[ares_idx] = static_cast<FloatType>(
__powf(curand_uniform(&rng), 1.0f / item_prob));
ares_idxs[ares_idx] = static_cast<IdType>(in_idx);
}
}
......@@ -191,47 +187,42 @@ __global__ void _CSRAResValueKernel(
}
}
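// Illustrative sketch (not part of this patch): A-Res weighted sampling
// without replacement on the CPU, using the same key as the kernel above
// (key_i = u_i^(1 / w_i), u_i ~ U(0, 1)), followed by a descending sort and
// selection of the top num_picks items. Names are local to this sketch.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

inline std::vector<int64_t> AResSampleReference(
    const std::vector<float>& weights, int64_t num_picks, uint64_t seed) {
  std::mt19937_64 rng(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
  std::vector<float> keys(weights.size());
  for (size_t i = 0; i < weights.size(); ++i)
    keys[i] = std::pow(uniform(rng), 1.0f / weights[i]);
  std::vector<int64_t> order(weights.size());
  std::iota(order.begin(), order.end(), 0);
  // Descending by key, matching DeviceSegmentedSort::SortPairsDescending used
  // on the GPU side further below.
  std::sort(order.begin(), order.end(), [&](int64_t a, int64_t b) {
    return keys[a] > keys[b];
  });
  const int64_t keep =
      std::min<int64_t>(num_picks, static_cast<int64_t>(order.size()));
  order.resize(keep);
  return order;
}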
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix,
* without replacement. After sorting, we select top-num_picks items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param ares_ptr The offset to write each row to in the ares array.
* @param sort_ares_idxs The sorted A-Res value corresponding index array, the index of input CSR.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO
* matrix, without replacement. After sorting, we select top-num_picks items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param ares_ptr The offset to write each row to in the ares array.
* @param sort_ares_idxs The sorted A-Res value corresponding index array, the
* index of input CSR.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleKernel(
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_cols,
const IdType * const data,
const IdType * const out_ptr,
const IdType * const ares_ptr,
const IdType * const sort_ares_idxs,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs) {
const int64_t num_picks, const int64_t num_rows,
const IdType* const in_rows, const IdType* const in_ptr,
const IdType* const in_cols, const IdType* const data,
const IdType* const out_ptr, const IdType* const ares_ptr,
const IdType* const sort_ares_idxs, IdType* const out_rows,
IdType* const out_cols, IdType* const out_idxs) {
// we assign one warp per row
assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
while (out_row < last_row) {
const int64_t row = in_rows[out_row];
......@@ -267,70 +258,64 @@ __global__ void _CSRRowWiseSampleKernel(
}
}
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
template<typename FloatType>
template <typename FloatType>
struct BlockPrefixCallbackOp {
// Running prefix
FloatType running_total;
// Constructor
__device__ BlockPrefixCallbackOp(FloatType running_total) : running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide scan.
__device__ FloatType operator()(FloatType block_aggregate) {
FloatType old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
// Running prefix
FloatType running_total;
// Constructor
__device__ BlockPrefixCallbackOp(FloatType running_total)
: running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
// Thread-0 is responsible for returning a value for seeding the block-wide
// scan.
__device__ FloatType operator()(FloatType block_aggregate) {
FloatType old_prefix = running_total;
running_total += block_aggregate;
return old_prefix;
}
};
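// Illustrative sketch (not part of this patch): how a running-prefix functor
// like the one above cooperates with cub::BlockScan when a segment spans
// several block-wide tiles. Each InclusiveSum call is seeded with the prefix
// accumulated so far, so successive tiles form one continuous inclusive
// prefix sum (the CDF layout consumed by the sampling kernel below). cub is
// already available here through "./dgl_cub.cuh"; launch with a single block
// of BLOCK threads for one segment.
template <typename FloatType, int BLOCK>
__global__ void RunningPrefixSumExample(
    const FloatType* in, FloatType* out, int64_t n) {
  typedef cub::BlockScan<FloatType, BLOCK> BlockScan;
  __shared__ typename BlockScan::TempStorage temp_storage;
  BlockPrefixCallbackOp<FloatType> prefix_op(0);
  for (int64_t start = 0; start < n; start += BLOCK) {
    const int64_t idx = start + threadIdx.x;
    FloatType val = (idx < n) ? in[idx] : static_cast<FloatType>(0);
    BlockScan(temp_storage).InclusiveSum(val, val, prefix_op);
    __syncthreads();  // temp_storage is reused by the next tile
    if (idx < n) out[idx] = val;
  }
}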
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix,
* with replacement. We store the CDF (unnormalized) of all neighbors of a row
* in global memory and use binary search to find inverse indices as selected items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param cdf_ptr The offset of each cdf segment.
* @param cdf The global buffer to store cdf segments.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template<typename IdType, typename FloatType, int TILE_SIZE>
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO
* matrix, with replacement. We store the CDF (unnormalized) of all neighbors of
* a row in global memory and use binary search to find inverse indices as
* selected items.
*
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @tparam TILE_SIZE The number of rows covered by each threadblock.
* @param rand_seed The random seed to use.
* @param num_picks The number of non-zeros to pick per row.
* @param num_rows The number of rows to pick.
* @param in_rows The set of rows to pick.
* @param in_ptr The indptr array of the input CSR.
* @param in_cols The columns array of the input CSR.
* @param data The data array of the input CSR.
* @param prob The probability array of the input CSR.
* @param out_ptr The offset to write each row to in the output COO.
* @param cdf_ptr The offset of each cdf segment.
* @param cdf The global buffer to store cdf segments.
* @param out_rows The rows of the output COO (output).
* @param out_cols The columns of the output COO (output).
* @param out_idxs The data array of the output COO (output).
* @author pengqirong (OPPO)
*/
template <typename IdType, typename FloatType, int TILE_SIZE>
__global__ void _CSRRowWiseSampleReplaceKernel(
const uint64_t rand_seed,
const int64_t num_picks,
const int64_t num_rows,
const IdType * const in_rows,
const IdType * const in_ptr,
const IdType * const in_cols,
const IdType * const data,
const FloatType * const prob,
const IdType * const out_ptr,
const IdType * const cdf_ptr,
FloatType * const cdf,
IdType * const out_rows,
IdType * const out_cols,
IdType * const out_idxs
) {
const uint64_t rand_seed, const int64_t num_picks, const int64_t num_rows,
const IdType* const in_rows, const IdType* const in_ptr,
const IdType* const in_cols, const IdType* const data,
const FloatType* const prob, const IdType* const out_ptr,
const IdType* const cdf_ptr, FloatType* const cdf, IdType* const out_rows,
IdType* const out_cols, IdType* const out_idxs) {
// we assign one warp per row
assert(blockDim.x == BLOCK_SIZE);
int64_t out_row = blockIdx.x * TILE_SIZE;
const int64_t last_row = min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
const int64_t last_row =
min(static_cast<int64_t>(blockIdx.x + 1) * TILE_SIZE, num_rows);
curandStatePhilox4_32_10_t rng;
curand_init(rand_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
......@@ -357,12 +342,14 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
// Load a segment of consecutive items that are blocked across threads
FloatType thread_data;
if (idx < deg)
_DoubleSlice<IdType, FloatType>(prob, data, idx, in_row_start, &thread_data);
_DoubleSlice<IdType, FloatType>(
prob, data, idx, in_row_start, &thread_data);
else
thread_data = MIN_THREAD_DATA;
thread_data = max(thread_data, MIN_THREAD_DATA);
// Collectively compute the block-wide inclusive prefix sum
BlockScan(temp_storage).InclusiveSum(thread_data, thread_data, prefix_op);
BlockScan(temp_storage)
.InclusiveSum(thread_data, thread_data, prefix_op);
__syncthreads();
// Store scanned items to cdf array
......@@ -376,7 +363,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
// get random value
FloatType sum = cdf[cdf_row_start + deg - 1];
FloatType rand = static_cast<FloatType>(curand_uniform(&rng) * sum);
// get the offset of the first value within cdf array which is greater than random value.
// get the offset of the first value within cdf array which is greater
// than random value.
int64_t item = cub::UpperBound<FloatType*, int64_t, FloatType>(
&cdf[cdf_row_start], deg, rand);
item = min(item, deg - 1);
......@@ -395,7 +383,8 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
template <typename IdType, typename DType, typename BoolType>
__global__ void _GenerateFlagsKernel(
int64_t n, const IdType* idx, const DType* values, DType criteria, BoolType* output) {
int64_t n, const IdType* idx, const DType* values, DType criteria,
BoolType* output) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < n) {
......@@ -413,8 +402,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
const int64_t nnz = coo.row->shape[0];
const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdArray& eid = COOHasData(coo) ? coo.data : Range(
0, nnz, sizeof(IdType) * 8, ctx);
const IdArray& eid =
COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);
const IdType* data = coo.data.Ptr<IdType>();
IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
......@@ -431,7 +420,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
maskgen(nb, nt, stream, nnz, data, flags);
int64_t* rst = static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
int64_t* rst =
static_cast<int64_t*>(device->AllocWorkspace(ctx, sizeof(int64_t)));
MaskSelect(device, ctx, row, flags, new_row_data, nnz, rst, stream);
MaskSelect(device, ctx, col, flags, new_col_data, nnz, rst, stream);
MaskSelect(device, ctx, data, flags, new_eid_data, nnz, rst, stream);
......@@ -441,24 +431,24 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
device->FreeWorkspace(ctx, flags);
device->FreeWorkspace(ctx, rst);
return COOMatrix(
coo.num_rows,
coo.num_cols,
new_row.CreateView({new_len}, idtype, 0),
coo.num_rows, coo.num_cols, new_row.CreateView({new_len}, idtype, 0),
new_col.CreateView({new_len}, idtype, 0),
new_eid.CreateView({new_len}, idtype, 0));
}
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criteria) {
COOMatrix _COORemoveIf(
const COOMatrix& coo, const NDArray& values, DType criteria) {
const DType* val = values.Ptr<DType>();
auto maskgen = [val, criteria] (
int nb, int nt, cudaStream_t stream, int64_t nnz, const IdType* data,
int8_t* flags) {
CUDA_KERNEL_CALL((_GenerateFlagsKernel<IdType, DType, int8_t>),
nb, nt, 0, stream,
nnz, data, val, criteria, flags);
auto maskgen = [val, criteria](
int nb, int nt, cudaStream_t stream, int64_t nnz,
const IdType* data, int8_t* flags) {
CUDA_KERNEL_CALL(
(_GenerateFlagsKernel<IdType, DType, int8_t>), nb, nt, 0, stream, nnz,
data, val, criteria, flags);
};
return COOGeneralRemoveIf<XPU, IdType, DType, decltype(maskgen)>(coo, maskgen);
return COOGeneralRemoveIf<XPU, IdType, DType, decltype(maskgen)>(
coo, maskgen);
}
} // namespace
......@@ -466,42 +456,42 @@ COOMatrix _COORemoveIf(const COOMatrix& coo, const NDArray& values, DType criter
/////////////////////////////// CSR ///////////////////////////////
/**
* @brief Perform weighted row-wise sampling on a CSR matrix, and generate a COO matrix.
* Use CDF sampling algorithm for with replacement:
* 1) Calculate the CDF of all neighbor's prob.
* 2) For each [0, num_picks), generate a rand ~ U(0, 1).
* Use binary search to find its index in the CDF array as a chosen item.
* Use A-Res sampling algorithm for without replacement:
* 1) For rows with deg > num_picks, calculate A-Res values for all neighbors.
* 2) Sort the A-Res array and select top-num_picks as chosen items.
*
* @tparam XPU The device type used for matrices.
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @param mat The CSR matrix.
* @param rows The set of rows to pick.
* @param num_picks The number of non-zeros to pick per row.
* @param prob The probability array of the input CSR.
* @param replace Is replacement sampling?
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
 * @brief Perform weighted row-wise sampling on a CSR matrix and generate a COO
 * matrix. With replacement, a CDF sampling algorithm is used:
 * 1) Calculate the (unnormalized) CDF of each row's neighbor probabilities.
 * 2) For each of the num_picks draws, generate rand ~ U(0, 1) and use binary
 * search to find its index in the CDF array as the chosen item.
 * Without replacement, the A-Res sampling algorithm is used:
 * 1) For rows with deg > num_picks, calculate A-Res values for all neighbors.
 * 2) Sort the A-Res array and select the top num_picks as chosen items.
*
* @tparam XPU The device type used for matrices.
* @tparam IdType The ID type used for matrices.
* @tparam FloatType The Float type used for matrices.
* @param mat The CSR matrix.
* @param rows The set of rows to pick.
* @param num_picks The number of non-zeros to pick per row.
* @param prob The probability array of the input CSR.
 * @param replace If true, sample with replacement.
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix _CSRRowWiseSampling(
const CSRMatrix& mat,
const IdArray& rows,
int64_t num_picks,
const FloatArray& prob,
bool replace) {
const CSRMatrix& mat, const IdArray& rows, int64_t num_picks,
const FloatArray& prob, bool replace) {
const auto& ctx = rows->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t num_rows = rows->shape[0];
const IdType * const slice_rows = static_cast<const IdType*>(rows->data);
IdArray picked_row = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_col = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx = NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
const IdType* const slice_rows = static_cast<const IdType*>(rows->data);
IdArray picked_row =
NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_col =
NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdArray picked_idx =
NewIdArray(num_rows * num_picks, ctx, sizeof(IdType) * 8);
IdType* const out_rows = static_cast<IdType*>(picked_row->data);
IdType* const out_cols = static_cast<IdType*>(picked_col->data);
IdType* const out_idxs = static_cast<IdType*>(picked_idx->data);
......@@ -511,16 +501,12 @@ COOMatrix _CSRRowWiseSampling(
const IdType* data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
const FloatType* prob_data = prob.Ptr<FloatType>();
if (mat.is_pinned) {
CUDA_CALL(cudaHostGetDevicePointer(
&in_ptr, mat.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(
&in_cols, mat.indices.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(&in_ptr, mat.indptr.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(&in_cols, mat.indices.Ptr<IdType>(), 0));
if (CSRHasData(mat)) {
CUDA_CALL(cudaHostGetDevicePointer(
&data, mat.data.Ptr<IdType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(&data, mat.data.Ptr<IdType>(), 0));
}
CUDA_CALL(cudaHostGetDevicePointer(
&prob_data, prob.Ptr<FloatType>(), 0));
CUDA_CALL(cudaHostGetDevicePointer(&prob_data, prob.Ptr<FloatType>(), 0));
}
// compute degree
......@@ -528,41 +514,33 @@ COOMatrix _CSRRowWiseSampling(
// temp_deg: the size of each row we will manipulate in sampling
// 1) for w/o replacement: in degree if it's greater than num_picks else 0
// 2) for w/ replacement: in degree
IdType * out_deg = static_cast<IdType*>(
IdType* out_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
IdType * temp_deg = static_cast<IdType*>(
IdType* temp_deg = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
if (replace) {
const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeReplaceKernel,
grid, block, 0, stream,
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
_CSRRowWiseSampleDegreeReplaceKernel, grid, block, 0, stream, num_picks,
num_rows, slice_rows, in_ptr, out_deg, temp_deg);
} else {
const dim3 block(512);
const dim3 grid((num_rows + block.x - 1) / block.x);
CUDA_KERNEL_CALL(
_CSRRowWiseSampleDegreeKernel,
grid, block, 0, stream,
num_picks, num_rows, slice_rows, in_ptr, out_deg, temp_deg);
_CSRRowWiseSampleDegreeKernel, grid, block, 0, stream, num_picks,
num_rows, slice_rows, in_ptr, out_deg, temp_deg);
}
// fill temp_ptr
IdType * temp_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1)*sizeof(IdType)));
IdType* temp_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
size_t prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
temp_deg,
temp_ptr,
num_rows + 1,
stream));
void * prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
temp_deg,
temp_ptr,
num_rows + 1,
stream));
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));
void* prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_temp, prefix_temp_size, temp_deg, temp_ptr, num_rows + 1, stream));
device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, temp_deg);
......@@ -570,49 +548,40 @@ COOMatrix _CSRRowWiseSampling(
// cuda events cannot be ignored. Just use synchronized copy.
IdType temp_len;
// copy using the internal current stream.
device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
sizeof(temp_len),
ctx,
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
device->CopyDataFromTo(
temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0, sizeof(temp_len),
ctx, DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
device->StreamSync(ctx, stream);
// fill out_ptr
IdType * out_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows+1)*sizeof(IdType)));
IdType* out_ptr = static_cast<IdType*>(
device->AllocWorkspace(ctx, (num_rows + 1) * sizeof(IdType)));
prefix_temp_size = 0;
CUDA_CALL(cub::DeviceScan::ExclusiveSum(nullptr, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
nullptr, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
prefix_temp = device->AllocWorkspace(ctx, prefix_temp_size);
CUDA_CALL(cub::DeviceScan::ExclusiveSum(prefix_temp, prefix_temp_size,
out_deg,
out_ptr,
num_rows+1,
stream));
CUDA_CALL(cub::DeviceScan::ExclusiveSum(
prefix_temp, prefix_temp_size, out_deg, out_ptr, num_rows + 1, stream));
device->FreeWorkspace(ctx, prefix_temp);
device->FreeWorkspace(ctx, out_deg);
cudaEvent_t copyEvent;
CUDA_CALL(cudaEventCreate(&copyEvent));
// TODO(dlasalle): use pinned memory to overlap with the actual sampling, and wait on
// a cudaevent
// TODO(dlasalle): use pinned memory to overlap with the actual sampling, and
// wait on a cudaevent
IdType new_len;
// copy using the internal current stream.
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
sizeof(new_len),
ctx,
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
device->CopyDataFromTo(
out_ptr, num_rows * sizeof(new_len), &new_len, 0, sizeof(new_len), ctx,
DGLContext{kDGLCPU, 0}, mat.indptr->dtype);
CUDA_CALL(cudaEventRecord(copyEvent, stream));
// allocate workspace
// 1) for w/ replacement, it's a global buffer to store cdf segments (one segment for each row).
// 1) for w/ replacement, it's a global buffer to store cdf segments (one
// segment for each row).
// 2) for w/o replacement, it's used to store a-res segments (one segment for
// each row with degree > num_picks)
FloatType * temp = static_cast<FloatType*>(
// each row with degree > num_picks)
FloatType* temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
const uint64_t rand_seed = RandomEngine::ThreadLocal()->RandInt(1000000000);
......@@ -624,21 +593,9 @@ COOMatrix _CSRRowWiseSampling(
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL(
(_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
rand_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
prob_data,
out_ptr,
temp_ptr,
temp,
out_rows,
out_cols,
(_CSRRowWiseSampleReplaceKernel<IdType, FloatType, TILE_SIZE>), grid,
block, 0, stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr,
in_cols, data, prob_data, out_ptr, temp_ptr, temp, out_rows, out_cols,
out_idxs);
device->FreeWorkspace(ctx, temp);
} else { // without replacement
......@@ -646,53 +603,33 @@ COOMatrix _CSRRowWiseSampling(
device->AllocWorkspace(ctx, (temp_len) * sizeof(IdType)));
// Compute A-Res value. A-Res value needs to be calculated only if deg
// is greater than num_picks in weighted rowwise sampling without replacement.
// is greater than num_picks in weighted rowwise sampling without
// replacement.
const dim3 block(BLOCK_SIZE);
const dim3 grid((num_rows + TILE_SIZE - 1) / TILE_SIZE);
CUDA_KERNEL_CALL(
(_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
rand_seed,
num_picks,
num_rows,
slice_rows,
in_ptr,
data,
prob_data,
temp_ptr,
temp_idxs,
temp);
(_CSRAResValueKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,
stream, rand_seed, num_picks, num_rows, slice_rows, in_ptr, data,
prob_data, temp_ptr, temp_idxs, temp);
// sort A-Res value array.
FloatType* sort_temp = static_cast<FloatType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
device->AllocWorkspace(ctx, temp_len * sizeof(FloatType)));
IdType* sort_temp_idxs = static_cast<IdType*>(
device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));
device->AllocWorkspace(ctx, temp_len * sizeof(IdType)));
cub::DoubleBuffer<FloatType> sort_keys(temp, sort_temp);
cub::DoubleBuffer<IdType> sort_values(temp_idxs, sort_temp_idxs);
void *d_temp_storage = nullptr;
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage,
temp_storage_bytes,
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1, stream));
d_temp_storage, temp_storage_bytes, sort_keys, sort_values, temp_len,
num_rows, temp_ptr, temp_ptr + 1, stream));
d_temp_storage = device->AllocWorkspace(ctx, temp_storage_bytes);
CUDA_CALL(cub::DeviceSegmentedSort::SortPairsDescending(
d_temp_storage,
temp_storage_bytes,
sort_keys,
sort_values,
temp_len,
num_rows,
temp_ptr,
temp_ptr + 1, stream));
d_temp_storage, temp_storage_bytes, sort_keys, sort_values, temp_len,
num_rows, temp_ptr, temp_ptr + 1, stream));
device->FreeWorkspace(ctx, d_temp_storage);
device->FreeWorkspace(ctx, temp);
device->FreeWorkspace(ctx, temp_idxs);
......@@ -701,20 +638,9 @@ COOMatrix _CSRRowWiseSampling(
    // select the top num_picks as results
CUDA_KERNEL_CALL(
(_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>),
grid, block, 0, stream,
num_picks,
num_rows,
slice_rows,
in_ptr,
in_cols,
data,
out_ptr,
temp_ptr,
sort_values.Current(),
out_rows,
out_cols,
out_idxs);
(_CSRRowWiseSampleKernel<IdType, FloatType, TILE_SIZE>), grid, block, 0,
stream, num_picks, num_rows, slice_rows, in_ptr, in_cols, data, out_ptr,
temp_ptr, sort_values.Current(), out_rows, out_cols, out_idxs);
}
device->FreeWorkspace(ctx, temp_ptr);
......@@ -728,44 +654,48 @@ COOMatrix _CSRRowWiseSampling(
picked_col = picked_col.CreateView({new_len}, picked_col->dtype);
picked_idx = picked_idx.CreateView({new_len}, picked_idx->dtype);
return COOMatrix(mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
return COOMatrix(
mat.num_rows, mat.num_cols, picked_row, picked_col, picked_idx);
}
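// Illustrative sketch (not part of this patch): the two-phase CUB calling
// convention used repeatedly above for DeviceScan (and likewise for
// DeviceSegmentedSort). The first call, with a null workspace pointer, only
// reports the required temporary-storage size; the second call performs the
// actual scan. Plain cudaMalloc/cudaFree stand in for the
// device->AllocWorkspace/FreeWorkspace pair used in the code above.
template <typename IdType>
void ExclusiveSumSketch(
    const IdType* d_in, IdType* d_out, int num_items, cudaStream_t stream) {
  size_t temp_bytes = 0;
  // Pass 1: query the workspace size.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr, temp_bytes, d_in, d_out, num_items, stream));
  void* d_temp = nullptr;
  CUDA_CALL(cudaMalloc(&d_temp, temp_bytes));
  // Pass 2: the same call with a workspace runs the scan.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      d_temp, temp_bytes, d_in, d_out, num_items, stream));
  CUDA_CALL(cudaFree(d_temp));
}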
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob, bool replace) {
CSRMatrix mat, IdArray rows, int64_t num_picks, FloatArray prob,
bool replace) {
COOMatrix result;
if (num_picks == -1) {
// Basically this is UnitGraph::InEdges().
COOMatrix coo = CSRToCOO(CSRSliceRows(mat, rows), false);
IdArray sliced_rows = IndexSelect(rows, coo.row);
result = COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
result =
COOMatrix(mat.num_rows, mat.num_cols, sliced_rows, coo.col, coo.data);
} else {
result = _CSRRowWiseSampling<XPU, IdType, DType>(mat, rows, num_picks, prob, replace);
result = _CSRRowWiseSampling<XPU, IdType, DType>(
mat, rows, num_picks, prob, replace);
}
// NOTE(BarclayII): I'm removing the entries with zero probability after sampling.
// Is there a better way?
// NOTE(BarclayII): I'm removing the entries with zero probability after
// sampling. Is there a better way?
return _COORemoveIf<XPU, IdType, DType>(result, prob, static_cast<DType>(0));
}
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
// These are not being called, but we instantiate them anyway to prevent missing
// symbols in Debug build
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, int8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, uint8_t>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl
} // namespace aten
......
......@@ -4,8 +4,9 @@
* @brief SDDMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./sddmm.cuh"
#include "./functor.cuh"
#include "./sddmm.cuh"
namespace dgl {
namespace aten {
......@@ -14,110 +15,85 @@ namespace aten {
* @brief CUDA implementation of g-SDDMM on Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCsr(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
});
});
}
/**
* @brief CUDA implementation of g-SDDMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMMCoo(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
});
});
}
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#endif // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
#endif // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
} // namespace aten
} // namespace dgl
......@@ -7,15 +7,16 @@
#define DGL_ARRAY_CUDA_SDDMM_CUH_
#include <dgl/bcast.h>
#include "macro.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "../selector.h"
#include "./functor.cuh"
#include "./utils.h"
#include "atomic.cuh"
#include "functor.cuh"
#include "fp16.cuh"
#include "bf16.cuh"
#include "./utils.h"
#include "./functor.cuh"
#include "../selector.h"
#include "../../runtime/cuda/cuda_common.h"
#include "fp16.cuh"
#include "functor.cuh"
#include "macro.cuh"
namespace dgl {
......@@ -24,64 +25,64 @@ using namespace cuda;
namespace aten {
namespace cuda {
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
#define SWITCH_OP(op, Op, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "dot") { \
typedef cuda::binary::Dot<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
LOG(FATAL) << "Unsupported SpMM/SDDMM binary operator: " << op; \
} \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...)\
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
#define SWITCH_RHS(rhs_target, RhsTarget, ...) \
do { \
if ((rhs_target) == 0) { \
constexpr int RhsTarget = 0; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 1) { \
constexpr int RhsTarget = 1; \
{ __VA_ARGS__ } \
} else if ((rhs_target) == 2) { \
constexpr int RhsTarget = 2; \
{ __VA_ARGS__ } \
} else { \
LOG(INFO) << "Invalid rhs target: " << (rhs_target); \
} \
} while (0)
#define SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, ...) \
do { \
if ((lhs_target) == 0) { \
constexpr int LhsTarget = 0; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 1) { \
constexpr int LhsTarget = 1; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else if ((lhs_target) == 2) { \
constexpr int LhsTarget = 2; \
SWITCH_RHS(rhs_target, RhsTarget, __VA_ARGS__); \
} else { \
LOG(INFO) << "Invalid lhs target: " << (lhs_target); \
} \
} while (0)
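A hypothetical call site can make the intent of these macros concrete: SWITCH_OP maps the runtime op string to a concrete binary functor type, and SWITCH_TARGET maps the two runtime target codes to constexpr ints, so the braced body can instantiate a fully specialized kernel. The sketch below assumes it sits inside the dgl::aten namespace in a translation unit that already includes this header (and therefore functor.cuh and the logging macros); DispatchSketch is an invented name.

// Illustrative only: shows how SWITCH_OP and SWITCH_TARGET compose.
template <typename IdType, typename DType>
void DispatchSketch(const std::string& op, int lhs_target, int rhs_target) {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      // Here `Op` is e.g. cuda::binary::Add<DType>, and LhsTarget /
      // RhsTarget are compile-time constants in {0, 1, 2}; both can be
      // passed straight into kernel templates such as SDDMMCooKernel below.
    });
  });
}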
constexpr unsigned int full_mask = 0xffffffff;
......@@ -90,23 +91,19 @@ constexpr unsigned int full_mask = 0xffffffff;
* @brief CUDA kernel of g-SDDMM on Coo format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis)
* is responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in feature dimension.
* on the x-axis are responsible for the computation on different
* positions in feature dimension.
*/
template <typename Idx, typename DType, typename BinaryOp,
bool UseBcast = false, bool UseIdx = false,
int LhsTarget = 0, int RhsTarget = 2>
template <
typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooKernel(
const DType* __restrict__ lhs,
const DType* __restrict__ rhs,
DType* __restrict__ out,
const Idx* __restrict__ row,
const Idx* __restrict__ col,
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
const DType* __restrict__ lhs, const DType* __restrict__ rhs,
DType* __restrict__ out, const Idx* __restrict__ row,
const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
// SDDMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
......@@ -114,10 +111,14 @@ __global__ void SDDMMCooKernel(
const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
const DType* lhsoff = BinaryOp::use_lhs ?
(lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
const DType* rhsoff = BinaryOp::use_rhs ?
(rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
const DType* lhsoff =
BinaryOp::use_lhs
? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
: nullptr;
const DType* rhsoff =
BinaryOp::use_rhs
? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
: nullptr;
DType* outoff = out + eid * out_len;
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = blockDim.x * gridDim.x;
......@@ -125,8 +126,7 @@ __global__ void SDDMMCooKernel(
const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
reduce_size);
outoff[tx] = val;
tx += stride_x;
......@@ -136,56 +136,58 @@ __global__ void SDDMMCooKernel(
}
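As a rough, standalone illustration of the edge-parallel indexing scheme the kernel above uses (blockIdx.y / threadIdx.y grid-stride over edges, blockIdx.x / threadIdx.x grid-stride over the feature dimension), here is a stripped-down kernel with no broadcasting, no edge-id mapping, and a hard-coded add operator; everything below is hypothetical and only meant to expose the loop structure:

#include <cstdint>

__global__ void EdgeParallelAddSketch(
    const float* __restrict__ src_feat, const float* __restrict__ dst_feat,
    float* __restrict__ out, const int* __restrict__ row,
    const int* __restrict__ col, int64_t num_edges, int64_t feat_len) {
  int64_t e = blockIdx.y * blockDim.y + threadIdx.y;
  const int64_t stride_e = blockDim.y * gridDim.y;
  while (e < num_edges) {            // grid-stride loop over edges (y-axis)
    const int u = row[e], v = col[e];
    int64_t f = blockIdx.x * blockDim.x + threadIdx.x;
    const int64_t stride_f = blockDim.x * gridDim.x;
    while (f < feat_len) {           // grid-stride loop over features (x-axis)
      out[e * feat_len + f] =
          src_feat[u * feat_len + f] + dst_feat[v * feat_len + f];
      f += stride_f;
    }
    e += stride_e;
  }
}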
/**
* @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree reduction.
* @brief CUDA kernel of SDDMM-dot on Coo format, accelerated with tree
* reduction.
* @note it uses edge parallel strategy, different threadblocks (on y-axis)
* is responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in feature dimension.
* on the x-axis are responsible for the computation on different
* positions in feature dimension.
*/
template <typename Idx, typename DType,
bool UseBcast = false, bool UseIdx = false,
int LhsTarget = 0, int RhsTarget = 2>
template <
typename Idx, typename DType, bool UseBcast = false, bool UseIdx = false,
int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCooTreeReduceKernel(
const DType* __restrict__ lhs,
const DType* __restrict__ rhs,
DType* __restrict__ out,
const Idx* __restrict__ row,
const Idx* __restrict__ col,
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
const DType* __restrict__ lhs, const DType* __restrict__ rhs,
DType* __restrict__ out, const Idx* __restrict__ row,
const Idx* __restrict__ col, const Idx* __restrict__ edge_map, int64_t N,
int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
Idx ty = blockIdx.x * blockDim.y + threadIdx.y;
if (ty < E) {
const Idx src = _ldg(row + ty);
const Idx dst = _ldg(col + ty);
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
const DType* lhsoff = lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
const DType* rhsoff = rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
const DType* lhsoff =
lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len;
const DType* rhsoff =
rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len;
DType* outoff = out + eid * out_len;
int tx = threadIdx.x; // tx < 32
for (int i = blockIdx.y; i < out_len; i += gridDim.y) { // over output feature dimension
for (int i = blockIdx.y; i < out_len;
i += gridDim.y) { // over output feature dimension
const Idx lhs_add = UseBcast ? __ldg(lhs_off + i) : i;
const Idx rhs_add = UseBcast ? __ldg(rhs_off + i) : i;
DType val = reduce::Sum<Idx, DType>::zero();;
DType val = reduce::Sum<Idx, DType>::zero();
for (int j = tx; j < reduce_size; j += 64) {
val += lhsoff[lhs_add * reduce_size + j] * rhsoff[rhs_add * reduce_size + j];
val += lhsoff[lhs_add * reduce_size + j] *
rhsoff[rhs_add * reduce_size + j];
if (j + 32 < reduce_size)
val += lhsoff[lhs_add * reduce_size + j + 32] * rhsoff[rhs_add * reduce_size + j + 32];
val += lhsoff[lhs_add * reduce_size + j + 32] *
rhsoff[rhs_add * reduce_size + j + 32];
}
#pragma unroll
for (int offset = 16; offset > 0; offset /= 2)
val += __shfl_down_sync(full_mask, val, offset);
if (tx == 0)
outoff[i] = val;
if (tx == 0) outoff[i] = val;
}
}
}
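The inner loop of the tree-reduce kernel above gives each of the 32 lanes of a warp a partial dot product (lane tx accumulates elements tx, tx + 32, tx + 64, ... of the length-reduce_size dot product, since the j-loop strides by 64 and adds two terms per iteration), and the __shfl_down_sync fold then combines the 32 partials in five steps (offsets 16, 8, 4, 2, 1) so that lane 0 holds the full sum. A self-contained sketch of just that warp-level reduction, assuming a full 32-lane warp participates:

// Returns the warp-wide sum on lane 0; other lanes hold partial values.
__device__ __forceinline__ float WarpReduceSumSketch(float val) {
#pragma unroll
  for (int offset = 16; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}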
// Binary search the row_offsets to find the source node of the edge id.
template <typename Idx>
__device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx eid) {
__device__ __forceinline__ Idx
BinarySearchSrc(const Idx* array, Idx length, Idx eid) {
Idx lo = 0, hi = length - 1;
while (lo < hi) {
Idx mid = (lo + hi) >> 1;
......@@ -207,25 +209,21 @@ __device__ __forceinline__ Idx BinarySearchSrc(const Idx *array, Idx length, Idx
* @brief CUDA kernel of g-SDDMM on Csr format.
* @note it uses edge parallel strategy, different threadblocks (on y-axis)
* is responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in feature dimension.
* To efficiently find the source node idx and destination node index of an
* given edge on Csr format, it uses binary search (time complexity O(log N)).
* on the x-axis are responsible for the computation on different
* positions in feature dimension. To efficiently find the source node idx and
 * destination node index of a given edge on Csr format, it uses binary search
* (time complexity O(log N)).
*/
template <typename Idx, typename DType, typename BinaryOp,
bool UseBcast = false, bool UseIdx = false,
int LhsTarget = 0, int RhsTarget = 2>
template <
typename Idx, typename DType, typename BinaryOp, bool UseBcast = false,
bool UseIdx = false, int LhsTarget = 0, int RhsTarget = 2>
__global__ void SDDMMCsrKernel(
const DType* __restrict__ lhs,
const DType* __restrict__ rhs,
DType* __restrict__ out,
const Idx* __restrict__ indptr,
const Idx* __restrict__ indices,
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off,
const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
const DType* __restrict__ lhs, const DType* __restrict__ rhs,
DType* __restrict__ out, const Idx* __restrict__ indptr,
const Idx* __restrict__ indices, const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E, int64_t reduce_size,
const int64_t* __restrict__ lhs_off, const int64_t* __restrict__ rhs_off,
int64_t lhs_len, int64_t rhs_len, int64_t out_len) {
// SDDMM with Csr.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
......@@ -235,17 +233,20 @@ __global__ void SDDMMCsrKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* lhsoff = BinaryOp::use_lhs ?
(lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len): nullptr;
const DType* rhsoff = BinaryOp::use_rhs ?
(rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len): nullptr;
const DType* lhsoff =
BinaryOp::use_lhs
? (lhs + Selector<LhsTarget>::Call(src, eid, dst) * lhs_len)
: nullptr;
const DType* rhsoff =
BinaryOp::use_rhs
? (rhs + Selector<RhsTarget>::Call(src, eid, dst) * rhs_len)
: nullptr;
DType* outoff = out + eid * out_len;
while (tx < out_len) {
const Idx lhs_add = UseBcast ? lhs_off[tx] : tx;
const Idx rhs_add = UseBcast ? rhs_off[tx] : tx;
DType val = BinaryOp::Call(
lhsoff + lhs_add * reduce_size,
rhsoff + rhs_add * reduce_size,
lhsoff + lhs_add * reduce_size, rhsoff + rhs_add * reduce_size,
reduce_size);
outoff[tx] = val;
tx += stride_x;
......@@ -262,26 +263,22 @@ __global__ void SDDMMCsrKernel(
* @param rhs The right hand size operand feature.
* @param out The result feature on edges.
*/
template <typename Idx, typename DType, typename Op,
int LhsTarget = 0, int RhsTarget = 2>
template <
typename Idx, typename DType, typename Op, int LhsTarget = 0,
int RhsTarget = 2>
void SDDMMCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray lhs,
NDArray rhs,
const BcastOff& bcast, const COOMatrix& coo, NDArray lhs, NDArray rhs,
NDArray out) {
const Idx *row = coo.row.Ptr<Idx>();
const Idx *col = coo.col.Ptr<Idx>();
const Idx *edge_map = coo.data.Ptr<Idx>();
const DType *lhs_data = lhs.Ptr<DType>();
const DType *rhs_data = rhs.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
const Idx* row = coo.row.Ptr<Idx>();
const Idx* col = coo.col.Ptr<Idx>();
const Idx* edge_map = coo.data.Ptr<Idx>();
const DType* lhs_data = lhs.Ptr<DType>();
const DType* rhs_data = rhs.Ptr<DType>();
DType* out_data = out.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *lhs_off = nullptr, *rhs_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size;
const int64_t nnz = coo.row->shape[0];
......@@ -296,13 +293,11 @@ void SDDMMCoo(
const dim3 nthrs(ntx, nty);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL(
(SDDMMCooTreeReduceKernel<Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream,
lhs_data, rhs_data, out_data,
row, col, edge_map,
coo.num_rows, coo.num_cols, nnz, reduce_dim,
lhs_off, rhs_off,
lhs_len, rhs_len, len);
(SDDMMCooTreeReduceKernel<
Idx, DType, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
rhs_off, lhs_len, rhs_len, len);
});
} else {
const int ntx = FindNumThreads(len);
......@@ -312,13 +307,12 @@ void SDDMMCoo(
const dim3 nblks(nbx, nby);
const dim3 nthrs(ntx, nty);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL((SDDMMCooKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream,
lhs_data, rhs_data, out_data,
row, col, edge_map,
coo.num_rows, coo.num_cols, nnz, reduce_dim,
lhs_off, rhs_off,
lhs_len, rhs_len, len);
CUDA_KERNEL_CALL(
(SDDMMCooKernel<
Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, row, col,
edge_map, coo.num_rows, coo.num_cols, nnz, reduce_dim, lhs_off,
rhs_off, lhs_len, rhs_len, len);
});
}
}
......@@ -331,27 +325,23 @@ void SDDMMCoo(
* @param rhs The right hand size operand feature.
* @param out The result feature on edges.
*/
template <typename Idx, typename DType, typename Op,
int LhsTarget = 0, int RhsTarget = 2>
template <
typename Idx, typename DType, typename Op, int LhsTarget = 0,
int RhsTarget = 2>
void SDDMMCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray lhs,
NDArray rhs,
const BcastOff& bcast, const CSRMatrix& csr, NDArray lhs, NDArray rhs,
NDArray out) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *lhs_data = lhs.Ptr<DType>();
const DType *rhs_data = rhs.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
const Idx* indptr = csr.indptr.Ptr<Idx>();
const Idx* indices = csr.indices.Ptr<Idx>();
const Idx* edge_map = csr.data.Ptr<Idx>();
const DType* lhs_data = lhs.Ptr<DType>();
const DType* rhs_data = rhs.Ptr<DType>();
DType* out_data = out.Ptr<DType>();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t N = csr.num_rows, M = csr.num_cols, E = csr.indices->shape[0];
int64_t *lhs_off = nullptr, *rhs_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
int64_t reduce_dim = bcast.reduce_size;
const int ntx = FindNumThreads(len);
......@@ -363,17 +353,14 @@ void SDDMMCsr(
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, out->ctx, lhs_off, rhs_off, {
CUDA_KERNEL_CALL((SDDMMCsrKernel<Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream,
lhs_data, rhs_data, out_data,
indptr, indices, edge_map,
N, M, E, reduce_dim,
lhs_off, rhs_off,
lhs_len, rhs_len, len);
CUDA_KERNEL_CALL(
(SDDMMCsrKernel<
Idx, DType, Op, UseBcast, UseIdx, LhsTarget, RhsTarget>),
nblks, nthrs, 0, stream, lhs_data, rhs_data, out_data, indptr, indices,
edge_map, N, M, E, reduce_dim, lhs_off, rhs_off, lhs_len, rhs_len, len);
});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
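SDDMMCsrKernel above recovers the source row of each edge by binary searching csr.indptr (see BinarySearchSrc): the answer is the row r with indptr[r] <= eid < indptr[r + 1]. A self-contained variant of that search, with an invented name and a num_rows parameter in place of the indptr length, for illustration only:

template <typename Idx>
__device__ __forceinline__ Idx SearchSrcSketch(
    const Idx* indptr, Idx num_rows, Idx eid) {
  Idx lo = 0, hi = num_rows - 1;       // candidate rows are [0, num_rows)
  while (lo < hi) {
    Idx mid = (lo + hi) >> 1;
    if (indptr[mid + 1] <= eid)        // eid lies past the end of row `mid`
      lo = mid + 1;
    else                               // eid is in row `mid` or earlier
      hi = mid;
  }
  return lo;                           // indptr[lo] <= eid < indptr[lo + 1]
}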
......
......@@ -4,6 +4,7 @@
* @brief SDDMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./sddmm.cuh"
namespace dgl {
......@@ -14,16 +15,12 @@ namespace aten {
Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
void SDDMMCooHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
......@@ -33,7 +30,7 @@ void SDDMMCooHetero(const std::string& op,
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
bcast, coo, lhs, rhs, out);
}
});
});
......@@ -41,61 +38,53 @@ void SDDMMCooHetero(const std::string& op,
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#endif // BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten
......
......@@ -4,26 +4,22 @@
* @brief SDDMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./sddmm.cuh"
namespace dgl {
namespace aten {
/**
* @brief CUDA implementation of g-SDDMM on heterograph using
Csr format.
* @brief CUDA implementation of g-SDDMM on heterograph using Csr format.
*/
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs,
std::vector<NDArray> vec_out,
int lhs_target,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
void SDDMMCsrHetero(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& vec_lhs,
const std::vector<NDArray>& vec_rhs, std::vector<NDArray> vec_out,
int lhs_target, int rhs_target, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
......@@ -33,7 +29,7 @@ void SDDMMCsrHetero(const std::string& op,
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
bcast, csr, lhs, rhs, out);
}
});
});
......@@ -41,61 +37,53 @@ void SDDMMCsrHetero(const std::string& op,
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#endif // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
} // namespace aten
......
......@@ -5,24 +5,21 @@
*/
#include <dgl/array.h>
#include <dgl/base_heterograph.h>
#include "./segment_reduce.cuh"
#include "./functor.cuh"
#include "./segment_reduce.cuh"
#include "./utils.h"
namespace dgl {
using namespace cuda;
namespace aten {
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
void SegmentReduce(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg) {
if (op == "sum") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
feat, offsets, out, arg);
......@@ -37,119 +34,70 @@ void SegmentReduce(const std::string& op,
}
}
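For reference, a CPU sketch (hypothetical helper, not part of DGL) of the reduction dispatched above for op == "sum": offsets has num_segments + 1 entries, segment s covers rows [offsets[s], offsets[s + 1]) of feat, and each row holds dim contiguous values.

#include <cstdint>
#include <vector>

void SegmentSumReference(
    const std::vector<float>& feat, const std::vector<int64_t>& offsets,
    int64_t dim, std::vector<float>* out) {
  const int64_t num_segments = static_cast<int64_t>(offsets.size()) - 1;
  out->assign(num_segments * dim, 0.f);
  for (int64_t s = 0; s < num_segments; ++s)
    for (int64_t r = offsets[s]; r < offsets[s + 1]; ++r)
      for (int64_t d = 0; d < dim; ++d)
        (*out)[s * dim + d] += feat[r * dim + d];
}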
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out) {
void ScatterAdd(NDArray feat, NDArray idx, NDArray out) {
cuda::ScatterAdd<IdType, DType>(feat, idx, out);
}
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
void UpdateGradMinMax_hetero(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
cuda::UpdateGradMinMax_hetero<IdType, DType>(
g, op, feat, idx, idx_etype, out);
}
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out) {
void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
template void SegmentReduce<kDGLCUDA, int32_t, __half>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __half>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
#if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
#endif // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, __half>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __half>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
#if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
#endif // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
NDArray feat, NDArray idx, NDArray out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
const HeteroGraphPtr& g, const std::string& op,
......@@ -187,39 +135,23 @@ template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
#if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
#endif // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
NDArray feat, NDArray arg, NDArray out);
} // namespace aten
} // namespace dgl
......@@ -4,10 +4,11 @@
* @brief SPMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"
namespace dgl {
......@@ -21,13 +22,10 @@ namespace aten {
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCsr(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
bool is_scalar_efeat = efeat.NumElements() == csr.indices->shape[0];
bool use_efeat = op != "copy_lhs";
......@@ -36,27 +34,22 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
nullptr,
static_cast<DType*>(out->data),
x_length);
} else if (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(more_nnz)) {
ufeat->ctx, csr, static_cast<DType*>(ufeat->data), nullptr,
static_cast<DType*>(out->data), x_length);
} else if (
op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
for (int i = 1; i < ufeat->ndim; ++i) x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) {
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
}
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(out->data),
ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
x_length);
} else { // general kernel
SWITCH_OP(op, Op, {
......@@ -79,31 +72,27 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
}
}
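Both cuSPARSE fast paths above compute the same result as the general kernel, namely a CSR SpMM with sum reduction. A small CPU reference (hypothetical helper, only to pin down the semantics): for every CSR row i, out[i, :] = sum over k in [indptr[i], indptr[i + 1]) of w_k * ufeat[indices[k], :], where w_k is 1 for op == "copy_lhs" and the per-edge scalar efeat[k] (after the csr.data permutation shown above) for op == "mul".

#include <cstdint>
#include <vector>

void SpmmCsrSumReference(
    const std::vector<int64_t>& indptr, const std::vector<int64_t>& indices,
    const std::vector<float>& efeat,  // empty => op == "copy_lhs"
    const std::vector<float>& ufeat, int64_t feat_len,
    std::vector<float>* out) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
  out->assign(num_rows * feat_len, 0.f);
  for (int64_t i = 0; i < num_rows; ++i) {
    for (int64_t k = indptr[i]; k < indptr[i + 1]; ++k) {
      const float w = efeat.empty() ? 1.f : efeat[k];
      for (int64_t f = 0; f < feat_len; ++f)
        (*out)[i * feat_len + f] += w * ufeat[indices[k] * feat_len + f];
    }
  }
}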
/**
* @brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
void SpMMCoo(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
});
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> >(
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
......@@ -112,74 +101,74 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
#endif // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
#endif // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
......@@ -7,13 +7,15 @@
#define DGL_ARRAY_CUDA_SPMM_CUH_
#include <dgl/bcast.h>
#include <limits>
#include "macro.cuh"
#include "fp16.cuh"
#include "bf16.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
#include "atomic.cuh"
#include "bf16.cuh"
#include "fp16.cuh"
#include "macro.cuh"
namespace dgl {
......@@ -32,9 +34,11 @@ inline bool cusparse_available(bool more_nnz_than_matrix_size) {
return true;
return false;
#else
if (std::is_same<DType, __half>::value || std::is_same<DType, __nv_bfloat16>::value)
if (std::is_same<DType, __half>::value ||
std::is_same<DType, __nv_bfloat16>::value)
return false;  // cusparse's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
// If the CSR matrix has more NNZ than matrix size, we should not use
// cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
......@@ -43,21 +47,19 @@ namespace {
/** @brief Call cuBLAS geam API for transpose operation for float and double. */
template <typename DType>
cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb,
DType* C, int ldc) {
cublasStatus_t Xgeam(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, const DType* alpha, const DType* A, int lda,
const DType* beta, const DType* B, int ldb, DType* C, int ldc) {
LOG(FATAL) << "Not supported dtype";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
template <>
cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const __half* alpha, const __half* A, int lda,
const __half* beta, const __half* B, int ldb,
__half* C, int ldc) {
cublasStatus_t Xgeam<__half>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, const __half* alpha, const __half* A, int lda,
const __half* beta, const __half* B, int ldb, __half* C, int ldc) {
// TODO(ndickson): There is no cublasHgeam, so a different
// implementation would be required.
LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
......@@ -66,9 +68,9 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
#if BF16_ENABLED
template <>
cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
cublasStatus_t Xgeam<__nv_bfloat16>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb,
__nv_bfloat16* C, int ldc) {
// TODO(ndickson): There is no cublasHgeam, so a different
......@@ -79,23 +81,21 @@ cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t tra
#endif // BF16_ENABLED
template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const float* alpha, const float* A, int lda,
const float* beta, const float* B, int ldb,
float* C, int ldc) {
return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
cublasStatus_t Xgeam<float>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, const float* alpha, const float* A, int lda,
const float* beta, const float* B, int ldb, float* C, int ldc) {
return cublasSgeam(
handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}
template <>
cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const double* alpha, const double* A, int lda,
const double* beta, const double* B, int ldb,
double* C, int ldc) {
return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda,
beta, B, ldb, C, ldc);
cublasStatus_t Xgeam<double>(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, const double* alpha, const double* A, int lda,
const double* beta, const double* B, int ldb, double* C, int ldc) {
return cublasDgeam(
handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc);
}
/**
......@@ -104,10 +104,8 @@ cublasStatus_t Xgeam<double>(cublasHandle_t handle, cublasOperation_t transa,
*/
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(
const DType* __restrict__ in,
const IdType* __restrict__ idx,
DType* __restrict__ out,
int n, int m) {
const DType* __restrict__ in, const IdType* __restrict__ idx,
DType* __restrict__ out, int n, int m) {
int i = blockIdx.x;
for (int j = threadIdx.x; j < m; j += blockDim.x)
out[i * m + j] = in[idx[i] * m + j];
......@@ -119,9 +117,7 @@ __global__ void _IndexSelectKernel(
*/
template <typename DType>
__global__ void _TransposeKernel(
const DType* __restrict__ in,
DType* __restrict__ out,
int n, int m) {
const DType* __restrict__ in, DType* __restrict__ out, int n, int m) {
int i = blockIdx.x;
for (int j = threadIdx.x; j < m; j += blockDim.x)
out[i * m + j] = in[j * n + i];
......@@ -133,8 +129,7 @@ __global__ void _TransposeKernel(
* @param col number of columns of input matrix.
*/
template <typename DType>
void _Transpose(const DType* in, DType* out,
int row, int col) {
void _Transpose(const DType* in, DType* out, int row, int col) {
DType alpha = 1., beta = 0.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
......@@ -142,13 +137,8 @@ void _Transpose(const DType* in, DType* out,
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
CUBLAS_CALL(Xgeam<DType>(
thr_entry->cublas_handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
row, col,
&alpha, in, col,
&beta, nullptr, row,
out, row));
thr_entry->cublas_handle, CUBLAS_OP_T, CUBLAS_OP_N, row, col, &alpha, in,
col, &beta, nullptr, row, out, row));
}
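// Illustrative sketch (not from the DGL sources): the generic _Transpose above reduces to
// cublas<t>geam with alpha = 1, beta = 0 and op(A) = A^T, i.e. C = A^T. Below is a minimal
// standalone use of cublasSgeam in the same call shape; sizes and names are made up for the
// example and error checking is omitted. The patch passes nullptr for B because beta is
// zero; here B aliases C (the documented in-place form) to stay clearly inside the API.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int m = 2, n = 3;  // A is m x n in column-major; C = A^T is n x m.
  const float hA[m * n] = {1, 2, 3, 4, 5, 6};  // A = [[1, 3, 5], [2, 4, 6]]
  float *dA = nullptr, *dC = nullptr;
  cudaMalloc(&dA, sizeof(hA));
  cudaMalloc(&dC, sizeof(hA));
  cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.f, beta = 0.f;
  // C (n x m, ldc = n) = 1 * A^T + 0 * C -- same argument order as Xgeam<float> above.
  cublasSgeam(
      handle, CUBLAS_OP_T, CUBLAS_OP_N, n, m, &alpha, dA, m, &beta, dC, n, dC, n);
  float hC[m * n];
  cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost);
  for (int i = 0; i < m * n; ++i) printf("%g ", hC[i]);  // 1 3 5 2 4 6, i.e. A^T
  printf("\n");
  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dC);
  return 0;
}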
/**
......@@ -156,8 +146,7 @@ void _Transpose(const DType* in, DType* out,
* @note cuBLAS has no geam API for half data type, fallback to our kernel.
*/
template <>
void _Transpose<half>(const half* in, half* out,
int row, int col) {
void _Transpose<half>(const half* in, half* out, int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
......@@ -170,8 +159,8 @@ void _Transpose<half>(const half* in, half* out,
* @note cuBLAS has no geam API for bf16 data type, fallback to our kernel.
*/
template <>
void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out,
int row, int col) {
void _Transpose<__nv_bfloat16>(
const __nv_bfloat16* in, __nv_bfloat16* out, int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
......@@ -183,8 +172,8 @@ void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out,
* @brief
*/
template <typename DType, typename IdType>
__global__ void _IndexSelectKernel(const DType* array, const IdType* index,
int64_t length, DType* out) {
__global__ void _IndexSelectKernel(
const DType* array, const IdType* index, int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -197,7 +186,7 @@ __global__ void _IndexSelectKernel(const DType* array, const IdType* index,
* @note duplicate of IndexSelect defined in array_op.h, but it cannot
* be applied to the float16 dtype.
*/
template<typename DType, typename IdType>
template <typename DType, typename IdType>
NDArray _IndexSelect(NDArray array, NDArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
......@@ -205,65 +194,65 @@ NDArray _IndexSelect(NDArray array, NDArray index) {
const int64_t arr_len = array->shape[0];
const int64_t len = index->shape[0];
NDArray ret = NDArray::Empty({len}, array->dtype, array->ctx);
if (len == 0)
return ret;
if (len == 0) return ret;
DType* ret_data = static_cast<DType*>(ret->data);
const int nt = FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(_IndexSelectKernel, nb, nt, 0, stream,
array_data, idx_data, len, ret_data);
CUDA_KERNEL_CALL(
_IndexSelectKernel, nb, nt, 0, stream, array_data, idx_data, len,
ret_data);
return ret;
}
#if CUDART_VERSION < 11000
template <typename DType>
cusparseStatus_t Xcsrmm2(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseStatus_t Xcsrmm2(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const DType* alpha, const cusparseMatDescr_t descrA,
const DType* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const DType* B, int ldb, const DType* beta, DType* C, int ldc) {
const DType* alpha, const cusparseMatDescr_t descrA, const DType* csrValA,
const int* csrRowPtrA, const int* csrColIndA, const DType* B, int ldb,
const DType* beta, DType* C, int ldc) {
LOG(INFO) << "Not supported dtype";
return CUSPARSE_STATUS_EXECUTION_FAILED;
}
template <>
cusparseStatus_t Xcsrmm2<float>(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseStatus_t Xcsrmm2<float>(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const float* alpha, const cusparseMatDescr_t descrA,
const float* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const float* B, int ldb, const float* beta, float* C, int ldc) {
return cusparseScsrmm2(handle, transA, transB, m, n, k, nnz,
alpha, descrA, csrValA, csrRowPtrA, csrColIndA,
B, ldb, beta, C, ldc);
const float* alpha, const cusparseMatDescr_t descrA, const float* csrValA,
const int* csrRowPtrA, const int* csrColIndA, const float* B, int ldb,
const float* beta, float* C, int ldc) {
return cusparseScsrmm2(
handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
csrColIndA, B, ldb, beta, C, ldc);
}
template <>
cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t transA,
cusparseStatus_t Xcsrmm2<double>(
cusparseHandle_t handle, cusparseOperation_t transA,
cusparseOperation_t transB, int m, int n, int k, int nnz,
const double* alpha, const cusparseMatDescr_t descrA,
const double* csrValA, const int* csrRowPtrA, const int* csrColIndA,
const double* B, int ldb, const double* beta, double* C, int ldc) {
return cusparseDcsrmm2(handle, transA, transB, m, n, k, nnz,
alpha, descrA, csrValA, csrRowPtrA, csrColIndA,
B, ldb, beta, C, ldc);
const double* alpha, const cusparseMatDescr_t descrA, const double* csrValA,
const int* csrRowPtrA, const int* csrColIndA, const double* B, int ldb,
const double* beta, double* C, int ldc) {
return cusparseDcsrmm2(
handle, transA, transB, m, n, k, nnz, alpha, descrA, csrValA, csrRowPtrA,
csrColIndA, B, ldb, beta, C, ldc);
}
#endif
/** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2(
const DGLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,
int x_length) {
const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
const DType* A_data, DType* C_data, int x_length) {
// We use csrmm2 to perform the following operation:
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node
// feature tensor. However, since cusparse only supports column-major, while our tensor
// is stored in row-major, the actual computation is:
// C = trans(A x trans(B)).
// Currently, we use cublasXgeam to implement transposition and allocate intermediate
// workspace memory for this.
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
// for node feature tensor. However, since cusparse only supports
// column-major, while our tensor is stored in row-major, the actual
// computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
// implement transposition and allocate intermediate workspace memory for
// this.
const int m = csr.num_rows;
const int n = x_length;
const int k = csr.num_cols;
......@@ -282,7 +271,8 @@ void CusparseCsrmm2(
// all one data array
DType* valptr = nullptr;
if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
valptr =
static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.));
}
#if CUDART_VERSION >= 11000
......@@ -290,34 +280,26 @@ void CusparseCsrmm2(
cusparseDnMatDescr_t matB, matC;
constexpr auto dtype = cuda_dtype<DType>::value;
constexpr auto idtype = cusparse_idtype<IdType>::value;
CUSPARSE_CALL(cusparseCreateCsr(&matA,
m, k, nnz,
static_cast<IdType*>(csr.indptr->data),
CUSPARSE_CALL(cusparseCreateCsr(
&matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data),
idtype, idtype,
const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB,
k, n, n,
const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC,
m, n, n,
C_data, dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(
&matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(
cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroySpMat(matA));
......@@ -325,46 +307,40 @@ void CusparseCsrmm2(
CUSPARSE_CALL(cusparseDestroyDnMat(matC));
#else
// allocate matrix for temporary transposed output
DType* trans_out = static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
DType* trans_out =
static_cast<DType*>(device->AllocWorkspace(ctx, m * n * sizeof(DType)));
cusparseMatDescr_t descr;
CUSPARSE_CALL(cusparseCreateMatDescr(&descr));
CUSPARSE_CALL(cusparseSetMatType(descr, CUSPARSE_MATRIX_TYPE_GENERAL));
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha,
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, trans_out, m));
thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
(valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, trans_out,
m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
// transpose the output matrix
_Transpose(trans_out, C_data, n, m);
device->FreeWorkspace(ctx, trans_out);
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
if (valptr) device->FreeWorkspace(ctx, valptr);
}
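// Illustrative sketch (not from the DGL sources): the comment in CusparseCsrmm2 above
// rests on one layout identity -- a row-major r x c buffer read as column-major with
// leading dimension c is the transpose. A host-only check with made-up sizes and names:
#include <cstdio>

int main() {
  const int r = 2, c = 3;
  const float buf[r * c] = {1, 2, 3, 4, 5, 6};  // row-major B = [[1, 2, 3], [4, 5, 6]]
  auto row_major = [&](int i, int j) { return buf[i * c + j]; };  // B[i][j]
  auto col_major = [&](int p, int q) { return buf[q * c + p]; };  // same bytes, c x r view
  for (int i = 0; i < r; ++i)
    for (int j = 0; j < c; ++j)
      if (row_major(i, j) != col_major(j, i)) return 1;  // never taken: the views agree
  // This is why the legacy (#else) branch above computes into trans_out, which holds C^T
  // in row-major terms, and then calls _Transpose once to recover the row-major C.
  printf("row-major B viewed as column-major equals B^T\n");
  return 0;
}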
/** Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
const DGLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,
int64_t x_length,
const DGLContext& ctx, const CSRMatrix& csr, const DType* B_data,
const DType* A_data, DType* C_data, int64_t x_length,
cudaStream_t strm_id) {
// We use csrmm2 to perform the following operation:
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix for node
// feature tensor. However, since cusparse only supports column-major, while our tensor
// is stored in row-major, the actual computation is:
// C = trans(A x trans(B)).
// Currently, we use cublasXgeam to implement transposition and allocate intermediate
// workspace memory for this.
// C = A x B, where A is a sparse matrix in csr format, B is the dense matrix
// for node feature tensor. However, since cusparse only supports
// column-major, while our tensor is stored in row-major, the actual
// computation is: C = trans(A x trans(B)). Currently, we use cublasXgeam to
// implement transposition and allocate intermediate workspace memory for
// this.
int int_maxlimit = std::numeric_limits<int>::max();
CHECK_GE(int_maxlimit, (csr.num_rows));
CHECK_GE(int_maxlimit, csr.num_cols);
......@@ -386,7 +362,8 @@ void CusparseCsrmm2Hetero(
// all one data array
DType* valptr = nullptr;
if (!A_data) {
valptr = static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
valptr =
static_cast<DType*>(device->AllocWorkspace(ctx, nnz * sizeof(DType)));
_Fill(valptr, nnz, static_cast<DType>(1.));
}
#if CUDART_VERSION >= 11000
......@@ -394,34 +371,26 @@ void CusparseCsrmm2Hetero(
cusparseDnMatDescr_t matB, matC;
constexpr auto dtype = cuda_dtype<DType>::value;
constexpr auto idtype = cusparse_idtype<IdType>::value;
CUSPARSE_CALL(cusparseCreateCsr(&matA,
m, k, nnz,
static_cast<IdType*>(csr.indptr->data),
CUSPARSE_CALL(cusparseCreateCsr(
&matA, m, k, nnz, static_cast<IdType*>(csr.indptr->data),
static_cast<IdType*>(csr.indices->data),
const_cast<DType*>(valptr? valptr : A_data),
idtype, idtype,
const_cast<DType*>(valptr ? valptr : A_data), idtype, idtype,
CUSPARSE_INDEX_BASE_ZERO, dtype));
CUSPARSE_CALL(cusparseCreateDnMat(&matB,
k, n, n,
const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(&matC,
m, n, n,
C_data, dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(cusparseCreateDnMat(
&matB, k, n, n, const_cast<DType*>(B_data), dtype, CUSPARSE_ORDER_ROW));
CUSPARSE_CALL(
cusparseCreateDnMat(&matC, m, n, n, C_data, dtype, CUSPARSE_ORDER_ROW));
auto transA = CUSPARSE_OPERATION_NON_TRANSPOSE;
auto transB = CUSPARSE_OPERATION_NON_TRANSPOSE;
size_t workspace_size;
CUSPARSE_CALL(cusparseSpMM_bufferSize(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
&workspace_size));
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPMM_CSR_ALG2, &workspace_size));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUSPARSE_CALL(cusparseSpMM(
thr_entry->cusparse_handle, transA, transB,
&alpha, matA, matB, &beta, matC,
dtype, CUSPARSE_SPMM_CSR_ALG2,
workspace));
thr_entry->cusparse_handle, transA, transB, &alpha, matA, matB, &beta,
matC, dtype, CUSPARSE_SPMM_CSR_ALG2, workspace));
device->FreeWorkspace(ctx, workspace);
CUSPARSE_CALL(cusparseDestroySpMat(matA));
......@@ -434,74 +403,63 @@ void CusparseCsrmm2Hetero(
CUSPARSE_CALL(cusparseSetMatIndexBase(descr, CUSPARSE_INDEX_BASE_ZERO));
CHECK_EQ(sizeof(IdType), sizeof(int32_t));
CUSPARSE_CALL(Xcsrmm2<DType>(
thr_entry->cusparse_handle,
CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE,
m, n, k, nnz, &alpha,
descr, (valptr)? valptr : A_data,
static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data),
B_data, n, &beta, C_data, m));
thr_entry->cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
CUSPARSE_OPERATION_TRANSPOSE, m, n, k, nnz, &alpha, descr,
(valptr) ? valptr : A_data, static_cast<int32_t*>(csr.indptr->data),
static_cast<int32_t*>(csr.indices->data), B_data, n, &beta, C_data, m));
CUSPARSE_CALL(cusparseDestroyMatDescr(descr));
#endif
if (valptr)
device->FreeWorkspace(ctx, valptr);
if (valptr) device->FreeWorkspace(ctx, valptr);
}
} // namespace
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
} \
#define SWITCH_OP(op, Op, ...) \
do { \
if ((op) == "add") { \
typedef cuda::binary::Add<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "sub") { \
typedef cuda::binary::Sub<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "mul") { \
typedef cuda::binary::Mul<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "div") { \
typedef cuda::binary::Div<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_lhs") { \
typedef cuda::binary::CopyLhs<DType> Op; \
{ __VA_ARGS__ } \
} else if ((op) == "copy_rhs") { \
typedef cuda::binary::CopyRhs<DType> Op; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \
} \
} while (0)
namespace cuda {
/**
* @brief CUDA kernel of g-SpMM on Coo format.
* @note it uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in feature dimension.
* To avoid possible data hazards, it uses atomic operators for reduction.
* on the x-axis are responsible for the computation on different
* positions in feature dimension. To avoid possible data hazards, it uses
* atomic operators for reduction.
*/
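// Illustrative sketch (not from the DGL sources): the note above describes an
// edge-parallel layout in which each thread handles (part of) one edge and accumulates
// into its destination row, so writes to a shared destination must be atomic. Below is a
// stripped-down version with scalar features and sum reduction; the toy graph, values and
// names are made up. The actual kernel template follows after the sketch.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void EdgeParallelSumKernel(
    const int* __restrict__ row, const int* __restrict__ col,
    const float* __restrict__ ufeat, float* __restrict__ out, int64_t num_edges) {
  int64_t e = blockIdx.x * blockDim.x + threadIdx.x;
  const int64_t stride = blockDim.x * gridDim.x;
  while (e < num_edges) {
    // Several edges may share a destination, so the accumulation must be atomic.
    atomicAdd(out + col[e], ufeat[row[e]]);
    e += stride;
  }
}

int main() {
  // Toy graph with 4 edges: 0->2, 1->2, 1->0, 2->0 and node features {1, 10, 100}.
  const int h_row[] = {0, 1, 1, 2}, h_col[] = {2, 2, 0, 0};
  const float h_ufeat[] = {1.f, 10.f, 100.f};
  float h_out[3] = {0.f, 0.f, 0.f};
  int *d_row, *d_col;
  float *d_ufeat, *d_out;
  cudaMalloc(&d_row, sizeof(h_row));
  cudaMalloc(&d_col, sizeof(h_col));
  cudaMalloc(&d_ufeat, sizeof(h_ufeat));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_row, h_row, sizeof(h_row), cudaMemcpyHostToDevice);
  cudaMemcpy(d_col, h_col, sizeof(h_col), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ufeat, h_ufeat, sizeof(h_ufeat), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, h_out, sizeof(h_out), cudaMemcpyHostToDevice);
  EdgeParallelSumKernel<<<1, 32>>>(d_row, d_col, d_ufeat, d_out, 4);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("%g %g %g\n", h_out[0], h_out[1], h_out[2]);  // expected: 110 0 11
  cudaFree(d_row); cudaFree(d_col); cudaFree(d_ufeat); cudaFree(d_out);
  return 0;
}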
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
template <
typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCooKernel(
const DType* __restrict__ ufeat,
const DType* __restrict__ efeat,
DType* __restrict__ out,
Idx* __restrict__ arg_u,
Idx* __restrict__ arg_e,
const Idx* __restrict__ row,
const Idx* __restrict__ col,
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
const Idx* __restrict__ row, const Idx* __restrict__ col,
const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
int64_t efeat_len, int64_t out_len) {
// SPMM with COO.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
......@@ -511,8 +469,8 @@ __global__ void SpMMCooKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType* outoff = out + dst * out_len;
while (tx < out_len) {
const int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
......@@ -531,25 +489,20 @@ __global__ void SpMMCooKernel(
* @brief CUDA kernel to compute argu and arge in g-SpMM on Coo format.
* @note it uses an edge-parallel strategy: different threadblocks (on the y-axis)
* are responsible for the computation on different edges. Threadblocks
* on the x-axis are responsible for the computation on different positions
* in feature dimension.
* on the x-axis are responsible for the computation on different
* positions in feature dimension.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
template <
typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void ArgSpMMCooKernel(
const DType* __restrict__ ufeat,
const DType* __restrict__ efeat,
DType* __restrict__ out,
Idx* __restrict__ arg_u,
Idx* __restrict__ arg_e,
const Idx* __restrict__ row,
const Idx* __restrict__ col,
const Idx* __restrict__ edge_map,
int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
const Idx* __restrict__ row, const Idx* __restrict__ col,
const Idx* __restrict__ edge_map, int64_t N, int64_t M, int64_t E,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
int64_t efeat_len, int64_t out_len) {
// SPMM with COO arg max/min.
Idx ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
......@@ -559,11 +512,11 @@ __global__ void ArgSpMMCooKernel(
const Idx eid = UseIdx ? _ldg(edge_map + ty) : ty;
int64_t tx = blockIdx.x * blockDim.x + threadIdx.x;
const int64_t stride_x = blockDim.x * gridDim.x;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType* uoff = BinaryOp::use_lhs ? (ufeat + src * ufeat_len) : nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
const DType* outoff = out + dst * out_len;
Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len): nullptr;
Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len): nullptr;
Idx* arguoff = BinaryOp::use_lhs ? (arg_u + dst * out_len) : nullptr;
Idx* argeoff = BinaryOp::use_rhs ? (arg_e + dst * out_len) : nullptr;
while (tx < out_len) {
int64_t lhs_add = UseBcast ? ubcast_off[tx] : tx;
int64_t rhs_add = UseBcast ? ebcast_off[tx] : tx;
......@@ -582,22 +535,17 @@ __global__ void ArgSpMMCooKernel(
* Threadblocks on the x-axis are responsible for the computation on
* different positions in feature dimension.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
template <
typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCsrKernel(
const DType* __restrict__ ufeat,
const DType* __restrict__ efeat,
DType* __restrict__ out,
Idx* __restrict__ arg_u,
Idx* __restrict__ arg_e,
const Idx* __restrict__ indptr,
const Idx* __restrict__ indices,
const Idx* __restrict__ edge_map,
int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len) {
const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
int64_t efeat_len, int64_t out_len) {
// SPMM with CSR.
int ty = blockIdx.x * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.x;
......@@ -612,16 +560,19 @@ __global__ void SpMMCsrKernel(
for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
const Idx cid = _ldg(indices + i);
const DType* uoff = BinaryOp::use_lhs ? (ufeat + cid * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType* uoff =
BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
const DType* eoff =
BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&local_accum, &local_argu, &local_arge, out, cid, eid);
}
// The use of += is to compute cross-type reduction on heterogeneous graphs
// when the reduce op is `sum`.
// C = SpMM(SpA, B) + C
// Separate kernel `SpMMCmpCsrHeteroKernel` is used for max- and min-reducer. It
// does not affect the output on homogeneous graph as `out` is initialized to zero.
// Separate kernel `SpMMCmpCsrHeteroKernel` is used for max- and
// min-reducer. It does not affect the output on homogeneous graph as
// `out` is initialized to zero.
out[ty * out_len + tx] += local_accum;
if (ReduceOp::require_arg && BinaryOp::use_lhs)
arg_u[ty * out_len + tx] = local_argu;
......@@ -640,23 +591,18 @@ __global__ void SpMMCsrKernel(
* Threadblocks on the x-axis are responsible for the computation on
* different positions in feature dimension.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
template <
typename Idx, typename DType, typename BinaryOp, typename ReduceOp,
bool UseBcast = false, bool UseIdx = false>
__global__ void SpMMCmpCsrHeteroKernel(
const DType* __restrict__ ufeat,
const DType* __restrict__ efeat,
DType* __restrict__ out,
Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,
const Idx* __restrict__ indptr,
const Idx* __restrict__ indices,
const Idx* __restrict__ edge_map,
int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off,
int64_t ufeat_len, int64_t efeat_len, int64_t out_len,
const int src_type, const int etype) {
const DType* __restrict__ ufeat, const DType* __restrict__ efeat,
DType* __restrict__ out, Idx* __restrict__ arg_u, Idx* __restrict__ arg_e,
Idx* __restrict__ arg_u_ntype, Idx* __restrict__ arg_e_etype,
const Idx* __restrict__ indptr, const Idx* __restrict__ indices,
const Idx* __restrict__ edge_map, int64_t num_rows, int64_t num_cols,
const int64_t* __restrict__ ubcast_off,
const int64_t* __restrict__ ebcast_off, int64_t ufeat_len,
int64_t efeat_len, int64_t out_len, const int src_type, const int etype) {
// SPMM with CSR.
int ty = blockIdx.y * blockDim.y + threadIdx.y;
const Idx stride_y = blockDim.y * gridDim.y;
......@@ -671,12 +617,15 @@ __global__ void SpMMCmpCsrHeteroKernel(
for (Idx i = indptr[ty]; i < indptr[ty + 1]; ++i) {
const Idx eid = UseIdx ? _ldg(edge_map + i) : i;
const Idx cid = _ldg(indices + i);
const DType* uoff = BinaryOp::use_lhs ? (ufeat + cid * ufeat_len): nullptr;
const DType* eoff = BinaryOp::use_rhs ? (efeat + eid * efeat_len): nullptr;
const DType* uoff =
BinaryOp::use_lhs ? (ufeat + cid * ufeat_len) : nullptr;
const DType* eoff =
BinaryOp::use_rhs ? (efeat + eid * efeat_len) : nullptr;
DType tmp_out = BinaryOp::Call(uoff + lhs_add, eoff + rhs_add);
ReduceOp::Call(&new_out, &local_argu, &local_arge, tmp_out, cid, eid);
}
// Update the output only when the max/min value differs from the original output
// Update the output only when the max/min value differs from the original
// output
if (out[ty * out_len + tx] != new_out) {
out[ty * out_len + tx] = new_out;
if (ReduceOp::require_arg && BinaryOp::use_lhs) {
......@@ -703,17 +652,16 @@ __global__ void SpMMCmpCsrHeteroKernel(
* @param out The result feature on destination nodes.
* @param argu Arg-Min/Max on source nodes, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
* @param arge Arg-Min/Max on edges, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCoo(
const BcastOff& bcast,
const COOMatrix& coo,
NDArray ufeat, NDArray efeat,
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
#if defined(CUDART_VERSION) && CUDART_VERSION <= 10000
if (std::is_same<DType, half>::value)
......@@ -721,27 +669,23 @@ void SpMMCoo(
<< "for float16 in CUDA 10.0. Please upgrade your CUDA "
<< "to later versions.";
#endif
const Idx *row = coo.row.Ptr<Idx>(),
*col = coo.col.Ptr<Idx>(),
const Idx *row = coo.row.Ptr<Idx>(), *col = coo.col.Ptr<Idx>(),
*edge_map = coo.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>(),
*efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
Idx *argu_data = argu.Ptr<Idx>(),
*arge_data = arge.Ptr<Idx>();
DType* out_data = out.Ptr<DType>();
Idx *argu_data = argu.Ptr<Idx>(), *arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t N = coo.num_rows, M = coo.num_cols, E = coo.row->shape[0];
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
int64_t out_size = out.NumElements();
const int nt = FindNumThreads(out_size);
const int nb = (out_size + nt - 1) / nt;
CUDA_KERNEL_CALL(_FillKernel, nb, nt, 0, stream,
out_data, out_size, ReduceOp::zero());
CUDA_KERNEL_CALL(
_FillKernel, nb, nt, 0, stream, out_data, out_size, ReduceOp::zero());
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
......@@ -752,20 +696,16 @@ void SpMMCoo(
const bool use_idx = !IsNullArray(coo.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len);
CUDA_KERNEL_CALL(
(SpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off, lhs_len,
rhs_len, len);
if (ReduceOp::require_arg) {
CUDA_KERNEL_CALL((ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
row, col, edge_map,
N, M, E,
ubcast_off, ebcast_off,
CUDA_KERNEL_CALL(
(ArgSpMMCooKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
arge_data, row, col, edge_map, N, M, E, ubcast_off, ebcast_off,
lhs_len, rhs_len, len);
}
});
......@@ -780,33 +720,30 @@ void SpMMCoo(
* @param out The result feature on destination nodes.
* @param argu Arg-Min/Max on source nodes, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
* @param arge Arg-Min/Max on edges, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCsr(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
const Idx* indptr = csr.indptr.Ptr<Idx>();
const Idx* indices = csr.indices.Ptr<Idx>();
const Idx* edge_map = csr.data.Ptr<Idx>();
const DType* ufeat_data = ufeat.Ptr<DType>();
const DType* efeat_data = efeat.Ptr<DType>();
DType* out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nby = (len + ntx - 1) / ntx;
......@@ -815,15 +752,13 @@ void SpMMCsr(
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
indptr, indices, edge_map,
csr.num_rows, csr.num_cols,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len)
});
BCAST_IDX_CTX_SWITCH(
bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
{CUDA_KERNEL_CALL(
(SpMMCsrKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
arge_data, indptr, indices, edge_map, csr.num_rows, csr.num_cols,
ubcast_off, ebcast_off, lhs_len, rhs_len, len)});
}
/**
......@@ -835,43 +770,41 @@ void SpMMCsr(
* @param out The result feature on destination nodes.
* @param argu Arg-Min/Max on source nodes, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
* @param arge Arg-Min/Max on edges, which refers to the source node indices
* corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers to the
* source node types corresponding to the minimum/maximum values of the reduction result
* on destination nodes. It's useful in computing gradients of Min/Max reducer.
* @param arge_etype Edge-type of the arg-Min/Max on edges, which refers to the source
* node indices corresponding to the minimum/maximum values of the reduction result on
* destination nodes. It's useful in computing gradients of Min/Max reducer.
* destination nodes. It's useful in computing gradients of Min/Max
* reducer.
* @param argu_ntype Node type of the arg-Min/Max on source nodes, which refers
* to the source node types corresponding to the minimum/maximum values of the
* reduction result on destination nodes. It's useful in computing gradients of
* Min/Max reducer.
* @param arge_etype Edge-type of the arg-Min/Max on edges, which refers to the
* source node indices corresponding to the minimum/maximum values of the
* reduction result on destination nodes. It's useful in computing gradients of
* Min/Max reducer.
* @param src_type Node type of the source nodes of an etype
* @param etype Edge type
*/
template <typename Idx, typename DType,
typename BinaryOp, typename ReduceOp>
template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
void SpMMCmpCsrHetero(
const BcastOff& bcast,
const CSRMatrix& csr,
NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge,
NDArray argu_ntype, NDArray arge_etype,
const int src_type, const int etype) {
const Idx *indptr = csr.indptr.Ptr<Idx>();
const Idx *indices = csr.indices.Ptr<Idx>();
const Idx *edge_map = csr.data.Ptr<Idx>();
const DType *ufeat_data = ufeat.Ptr<DType>();
const DType *efeat_data = efeat.Ptr<DType>();
DType *out_data = out.Ptr<DType>();
const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat, NDArray efeat,
NDArray out, NDArray argu, NDArray arge, NDArray argu_ntype,
NDArray arge_etype, const int src_type, const int etype) {
const Idx* indptr = csr.indptr.Ptr<Idx>();
const Idx* indices = csr.indices.Ptr<Idx>();
const Idx* edge_map = csr.data.Ptr<Idx>();
const DType* ufeat_data = ufeat.Ptr<DType>();
const DType* efeat_data = efeat.Ptr<DType>();
DType* out_data = out.Ptr<DType>();
Idx* argu_data = argu.Ptr<Idx>();
Idx* arge_data = arge.Ptr<Idx>();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int64_t *ubcast_off = nullptr, *ebcast_off = nullptr;
int64_t len = bcast.out_len,
lhs_len = bcast.lhs_len,
rhs_len = bcast.rhs_len;
int64_t len = bcast.out_len, lhs_len = bcast.lhs_len, rhs_len = bcast.rhs_len;
const int ntx = FindNumThreads(len);
const int nty = CUDA_MAX_NUM_THREADS / ntx;
const int nbx = (len + ntx - 1) / ntx;
......@@ -880,20 +813,18 @@ void SpMMCmpCsrHetero(
const dim3 nthrs(ntx, nty);
const bool use_idx = !IsNullArray(csr.data);
BCAST_IDX_CTX_SWITCH(bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off, {
CUDA_KERNEL_CALL((SpMMCmpCsrHeteroKernel<Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream,
ufeat_data, efeat_data, out_data, argu_data, arge_data,
static_cast<Idx*>(argu_ntype->data),
static_cast<Idx*>(arge_etype->data),
indptr, indices, edge_map,
csr.num_rows, csr.num_cols,
ubcast_off, ebcast_off,
lhs_len, rhs_len, len, src_type, etype)
});
BCAST_IDX_CTX_SWITCH(
bcast, use_idx, ufeat->ctx, ubcast_off, ebcast_off,
{CUDA_KERNEL_CALL(
(SpMMCmpCsrHeteroKernel<
Idx, DType, BinaryOp, ReduceOp, UseBcast, UseIdx>),
nblks, nthrs, 0, stream, ufeat_data, efeat_data, out_data, argu_data,
arge_data, static_cast<Idx*>(argu_ntype->data),
static_cast<Idx*>(arge_etype->data), indptr, indices, edge_map,
csr.num_rows, csr.num_cols, ubcast_off, ebcast_off, lhs_len, rhs_len,
len, src_type, etype)});
}
} // namespace cuda
} // namespace aten
} // namespace dgl
......
......@@ -4,10 +4,11 @@
* @brief SPMM C APIs and definitions.
*/
#include <dgl/array.h>
#include "./spmm.cuh"
#include "./ge_spmm.cuh"
#include "./functor.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./functor.cuh"
#include "./ge_spmm.cuh"
#include "./spmm.cuh"
namespace dgl {
......@@ -21,16 +22,16 @@ namespace aten {
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat,
std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
void SpMMCsrHetero(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& vec_ufeat,
const std::vector<NDArray>& vec_efeat, std::vector<NDArray>* vec_out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, // ufeat node type id
const std::vector<dgl_type_t>& out_ntids) { // output node type id
bool is_scalar_efeat =
vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs";
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
std::vector<DType*> trans_out((*vec_out).size(), NULL);
......@@ -39,15 +40,16 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
(CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(false)));
(op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(false)));
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType)));
DType* out = static_cast<DType*>(device->AllocWorkspace(
vec_csr[0].indptr->ctx, m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out;
}
......@@ -57,43 +59,53 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
CHECK_EQ(ufeat->ndim, next_ufeat->ndim)
<< "Input features have different shapes";
for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] != next_ufeat->shape[i]) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
LOG(FATAL) <<
"Homogenized message passing on heterogeneous graphs does not support " <<
"automatic broadcasting. Please manually broadcast it before calling " <<
"message passing functions.";
LOG(FATAL) << "Homogenized message passing on heterogeneous graphs "
"does not support "
<< "automatic broadcasting. Please manually broadcast it "
"before calling "
<< "message passing functions.";
else
LOG(FATAL) << "Input features have different shapes.";
return;
}
if (etype == 0)
x_length *= ufeat->shape[i];
if (etype == 0) x_length *= ufeat->shape[i];
}
}
// TODO(Israt): Can python do the following initializations while creating the tensors?
if (reduce == "max" || reduce == "min") {
// TODO(Israt): Can python do the following initializations while creating the
// tensors?
if (reduce == "max" || reduce == "min") {
const int64_t dim = bcast.out_len;
std::vector<bool> updated((*vec_out).size(), false);
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
DType* out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
if (reduce == "max")
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
_Fill(
out_off, vec_csr[etype].num_rows * dim,
cuda::reduce::Max<IdType, DType>::zero());
else // min
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
_Fill(
out_off, vec_csr[etype].num_rows * dim,
cuda::reduce::Min<IdType, DType>::zero());
const dgl_type_t dst_id = out_ntids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (op == "copy_lhs") {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
IdType* argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(
argu_ntype, vec_csr[etype].num_rows * dim,
static_cast<IdType>(-1));
}
if (op == "copy_rhs") {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
IdType* arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(
arge_etype, vec_csr[etype].num_rows * dim,
static_cast<IdType>(-1));
}
}
}
......@@ -106,60 +118,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>((*vec_out)[dst_id]->data);
/* Call SpMM for each relation type */
if (op == "copy_lhs" &&
cusparse_available<DType, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later
* transposition */
DType* out = (CUDART_VERSION < 11000)
? trans_out[dst_id]
: static_cast<DType*>((*vec_out)[dst_id]->data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
out,
x_length, stream);
} else if (op == "mul" && is_scalar_efeat &&
csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr, out, x_length, stream);
} else if (
op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
csr.indptr->ctx, csr, static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, stream);
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA
// version < 11
static_cast<DType*>((*vec_out)[dst_id]->data), x_length, stream);
} else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType>>(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(),
NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
SWITCH_OP(op, Op, {
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<
IdType, DType, Op, cuda::reduce::Max<IdType, DType>>(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
SWITCH_OP(op, Op, {
NDArray ufeat =
(vec_ufeat.size() == 0) ? NullArray() : vec_ufeat[src_id];
NDArray efeat =
(vec_efeat.size() == 0) ? NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<
IdType, DType, Op, cuda::reduce::Min<IdType, DType>>(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else {
LOG(FATAL) << "Not implemented";
......@@ -172,7 +187,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
DType* C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
......@@ -180,55 +195,63 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
#endif // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids,
const std::vector<dgl_type_t>& out_ntids);
} // namespace aten
} // namespace dgl
......@@ -4,9 +4,9 @@
* @brief Utilities for CUDA kernels.
*/
#include "./utils.h"
#include "./dgl_cub.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./dgl_cub.cuh"
#include "./utils.h"
namespace dgl {
namespace cuda {
......@@ -17,9 +17,11 @@ bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {
// Call CUB's reduction
size_t workspace_size = 0;
cudaStream_t stream = runtime::getCurrentCUDAStream();
CUDA_CALL(cub::DeviceReduce::Min(nullptr, workspace_size, flags, rst, length, stream));
CUDA_CALL(cub::DeviceReduce::Min(
nullptr, workspace_size, flags, rst, length, stream));
void* workspace = device->AllocWorkspace(ctx, workspace_size);
CUDA_CALL(cub::DeviceReduce::Min(workspace, workspace_size, flags, rst, length, stream));
CUDA_CALL(cub::DeviceReduce::Min(
workspace, workspace_size, flags, rst, length, stream));
int8_t cpu_rst = GetCUDAScalar(device, ctx, rst);
device->FreeWorkspace(ctx, workspace);
device->FreeWorkspace(ctx, rst);
......
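// Illustrative sketch (not from the DGL sources): AllTrue above uses the standard two-pass
// CUB pattern -- call the algorithm once with a null workspace pointer to query the
// required temporary-storage size, allocate that much, then call it again to run. The same
// pattern with plain cudaMalloc instead of DGL's workspace allocator; sizes, values and
// names are made up for the example.
#include <cuda_runtime.h>
#include <cub/cub.cuh>
#include <cstdint>
#include <cstdio>

int main() {
  const int n = 4;
  const int8_t h_flags[n] = {1, 1, 0, 1};
  int8_t *d_flags, *d_min;
  cudaMalloc(&d_flags, n * sizeof(int8_t));
  cudaMalloc(&d_min, sizeof(int8_t));
  cudaMemcpy(d_flags, h_flags, n * sizeof(int8_t), cudaMemcpyHostToDevice);
  // Pass 1: null workspace pointer, CUB only reports how many bytes it needs.
  size_t workspace_size = 0;
  cub::DeviceReduce::Min(nullptr, workspace_size, d_flags, d_min, n);
  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size);
  // Pass 2: the actual reduction; the minimum is 0 iff some flag is false.
  cub::DeviceReduce::Min(workspace, workspace_size, d_flags, d_min, n);
  int8_t h_min = 1;
  cudaMemcpy(&h_min, d_min, sizeof(int8_t), cudaMemcpyDeviceToHost);
  printf("all true: %d\n", h_min == 1);  // expected: all true: 0
  cudaFree(workspace); cudaFree(d_flags); cudaFree(d_min);
  return 0;
}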
......@@ -6,10 +6,11 @@
#ifndef DGL_ARRAY_CUDA_UTILS_H_
#define DGL_ARRAY_CUDA_UTILS_H_
#include <dmlc/logging.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dmlc/logging.h>
#include "../../runtime/cuda/cuda_common.h"
#include "dgl_cub.cuh"
......@@ -22,7 +23,6 @@ namespace cuda {
// The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256
/** @brief Calculate the number of threads needed given the dimension length.
*
* It finds the biggest number that is smaller than min(dim, max_nthrs)
......@@ -30,8 +30,7 @@ namespace cuda {
*/
inline int FindNumThreads(int dim, int max_nthrs = CUDA_MAX_NUM_THREADS) {
CHECK_GE(dim, 0);
if (dim == 0)
return 1;
if (dim == 0) return 1;
int ret = max_nthrs;
while (ret > dim) {
ret = ret >> 1;
......@@ -60,11 +59,9 @@ inline int FindNumBlocks(int nblks, int max_nblks = -1) {
LOG(FATAL) << "Axis " << axis << " not recognized";
break;
}
if (max_nblks == -1)
max_nblks = default_max_nblks;
if (max_nblks == -1) max_nblks = default_max_nblks;
CHECK_NE(nblks, 0);
if (nblks < max_nblks)
return nblks;
if (nblks < max_nblks) return nblks;
return max_nblks;
}
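// Illustrative sketch (not from the DGL sources): FindNumThreads above starts at the
// 256-thread cap and halves until the value no longer exceeds dim, i.e. it returns the
// largest power of two <= min(dim, max_nthrs), with dim == 0 mapped to 1. A host-only
// replica with a few spot checks; the replica name and sample values are made up.
#include <cassert>

static int FindNumThreadsReplica(int dim, int max_nthrs = 256) {
  if (dim == 0) return 1;
  int ret = max_nthrs;
  while (ret > dim) ret >>= 1;
  return ret;
}

int main() {
  assert(FindNumThreadsReplica(0) == 1);      // degenerate dimension
  assert(FindNumThreadsReplica(100) == 64);   // 256 -> 128 -> 64
  assert(FindNumThreadsReplica(512) == 256);  // capped at CUDA_MAX_NUM_THREADS
  return 0;
}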
......@@ -108,7 +105,8 @@ template <typename DType>
void _Fill(DType* ptr, size_t length, DType val) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(length);
int nb = (length + nt - 1) / nt; // on x-axis, no need to worry about upperbound.
int nb =
(length + nt - 1) / nt; // on x-axis, no need to worry about upperbound.
CUDA_KERNEL_CALL(cuda::_FillKernel, nb, nt, 0, stream, ptr, length, val);
}
......@@ -123,9 +121,9 @@ void _Fill(DType* ptr, size_t length, DType val) {
template <typename IdType, typename DType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride,
int64_t length, const DType* weights, DType filler, DType* out) {
const IdType* row, const IdType* col, int64_t row_stride,
int64_t col_stride, int64_t length, const DType* weights, DType filler,
DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -148,7 +146,7 @@ __global__ void _LinearSearchKernel(
// constructor for __half.
// The using statement is to avoid a linter error about using
// long or long long.
using LongLong = long long; // NOLINT
using LongLong = long long; // NOLINT
out[tx] = weights ? weights[v] : DType(LongLong(v));
}
tx += stride_x;
......@@ -163,9 +161,9 @@ __global__ void _LinearSearchKernel(
template <typename IdType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride, int64_t length,
const __nv_bfloat16* weights, __nv_bfloat16 filler, __nv_bfloat16* out) {
const IdType* row, const IdType* col, int64_t row_stride,
int64_t col_stride, int64_t length, const __nv_bfloat16* weights,
__nv_bfloat16 filler, __nv_bfloat16* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
......@@ -181,7 +179,8 @@ __global__ void _LinearSearchKernel(
if (v == -1) {
out[tx] = filler;
} else {
// If the result is saved in bf16, it should be fine to convert it to float first
// If the result is saved in bf16, it should be fine to convert it to
// float first
out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));
}
tx += stride_x;
......@@ -191,16 +190,10 @@ __global__ void _LinearSearchKernel(
template <typename DType>
inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api,
DGLContext ctx,
const DType* cuda_ptr) {
runtime::DeviceAPI* device_api, DGLContext ctx, const DType* cuda_ptr) {
DType result;
device_api->CopyDataFromTo(
cuda_ptr, 0,
&result, 0,
sizeof(result),
ctx,
DGLContext{kDGLCPU, 0},
cuda_ptr, 0, &result, 0, sizeof(result), ctx, DGLContext{kDGLCPU, 0},
DGLDataTypeTraits<DType>::dtype);
return result;
}
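GetCUDAScalar is a thin wrapper over DeviceAPI::CopyDataFromTo for reading one value back to the host; in raw CUDA terms it is just a one-element device-to-host copy. A hedged sketch of that equivalent (not DGL code):
#include <cuda_runtime.h>
// Copy a single value of type T from device memory back to the host.
// Plain-CUDA counterpart of the DeviceAPI-based helper above.
template <typename T>
T GetScalarFromDevice(const T* d_ptr) {
  T result{};
  // Synchronous one-element device-to-host copy.
  cudaMemcpy(&result, d_ptr, sizeof(T), cudaMemcpyDeviceToHost);
  return result;
}
AllTrue above relies on exactly this kind of copy to fetch the reduced int8_t flag.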
......@@ -217,12 +210,12 @@ inline DType GetCUDAScalar(
* if x<A[0] then it returns 0.
*/
template <typename IdType>
__device__ IdType _UpperBound(const IdType *A, int64_t n, IdType x) {
__device__ IdType _UpperBound(const IdType* A, int64_t n, IdType x) {
IdType l = 0, r = n, m = 0;
while (l < r) {
m = l + (r-l)/2;
m = l + (r - l) / 2;
if (x >= A[m]) {
l = m+1;
l = m + 1;
} else {
r = m;
}
......@@ -241,17 +234,17 @@ __device__ IdType _UpperBound(const IdType *A, int64_t n, IdType x) {
 * @return index i s.t. A[i]==x. If no such index exists, returns 'n'.
*/
template <typename IdType>
__device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) {
IdType l = 0, r = n-1, m = 0;
__device__ IdType _BinarySearch(const IdType* A, int64_t n, IdType x) {
IdType l = 0, r = n - 1, m = 0;
while (l <= r) {
m = l + (r-l)/2;
m = l + (r - l) / 2;
if (A[m] == x) {
return m;
}
if (A[m] < x) {
l = m+1;
l = m + 1;
} else {
r = m-1;
r = m - 1;
}
}
return n; // not found
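Host-side mirrors of the two device helpers above make their contracts easy to sanity-check: _UpperBound returns the index of the first element strictly greater than x, while _BinarySearch returns the index of x, or n when x is absent. A small self-contained check (plain C++, not part of DGL):
#include <cassert>
#include <cstdint>
// First index i in the sorted array A[0..n) with A[i] > x (n if none).
template <typename IdType>
IdType UpperBoundRef(const IdType* A, int64_t n, IdType x) {
  IdType l = 0, r = n;
  while (l < r) {
    IdType m = l + (r - l) / 2;
    if (x >= A[m]) l = m + 1; else r = m;
  }
  return l;
}
// Index i with A[i] == x, or n if x is not present.
template <typename IdType>
IdType BinarySearchRef(const IdType* A, int64_t n, IdType x) {
  IdType l = 0, r = n - 1;
  while (l <= r) {
    IdType m = l + (r - l) / 2;
    if (A[m] == x) return m;
    if (A[m] < x) l = m + 1; else r = m - 1;
  }
  return n;
}
int main() {
  const int64_t A[] = {1, 3, 5, 7};
  assert(UpperBoundRef<int64_t>(A, 4, 4) == 2);    // first element > 4 is A[2] = 5
  assert(UpperBoundRef<int64_t>(A, 4, 0) == 0);    // x < A[0] gives 0
  assert(BinarySearchRef<int64_t>(A, 4, 5) == 2);  // found at index 2
  assert(BinarySearchRef<int64_t>(A, 4, 4) == 4);  // absent: returns n
  return 0;
}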
......@@ -259,9 +252,9 @@ __device__ IdType _BinarySearch(const IdType *A, int64_t n, IdType x) {
template <typename DType, typename BoolType>
void MaskSelect(
runtime::DeviceAPI* device, const DGLContext& ctx,
const DType* input, const BoolType* mask, DType* output, int64_t n,
int64_t* rst, cudaStream_t stream) {
runtime::DeviceAPI* device, const DGLContext& ctx, const DType* input,
const BoolType* mask, DType* output, int64_t n, int64_t* rst,
cudaStream_t stream) {
size_t workspace_size = 0;
CUDA_CALL(cub::DeviceSelect::Flagged(
nullptr, workspace_size, input, mask, output, rst, n, stream));
......
......@@ -3,33 +3,28 @@
* @file array/kernel.cc
* @brief New kernels
*/
#include <dgl/packed_func_ext.h>
#include <dgl/base_heterograph.h>
#include <dgl/packed_func_ext.h>
#ifdef USE_TVM
#include <featgraph.h>
#include <dgl/runtime/dlpack_convert.h>
#include <featgraph.h>
#endif // USE_TVM
#include "kernel_decl.h"
#include "../c_api_common.h"
#include "./check.h"
#include "kernel_decl.h"
using namespace dgl::runtime;
namespace dgl {
namespace aten {
namespace {
} // namespace
namespace {} // namespace
/** @brief Generalized Sparse Matrix-Matrix Multiplication. */
void SpMM(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out,
std::vector<NDArray> out_aux) {
void SpMM(
const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux) {
// TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
......@@ -39,12 +34,12 @@ void SpMM(const std::string& op, const std::string& reduce,
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsr<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCSCMatrix(0),
ufeat, efeat, out, out_aux);
op, reduce, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out,
out_aux);
} else if (format == SparseFormat::kCOO) {
SpMMCoo<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux);
op, reduce, bcast, graph->GetCOOMatrix(0), ufeat, efeat, out,
out_aux);
} else {
LOG(FATAL) << "SpMM only supports CSC and COO formats";
}
......@@ -53,27 +48,27 @@ void SpMM(const std::string& op, const std::string& reduce,
});
}
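SpMM, like most entry points in kernel.cc, is only a dispatcher: the nested ATEN_XPU_SWITCH_CUDA / ATEN_ID_TYPE_SWITCH / ATEN_FLOAT_TYPE_SWITCH_16BITS macros turn runtime device and dtype codes into one concrete template instantiation. A stripped-down sketch of that switch-to-template idiom, with made-up names rather than the real DGL macros:
#include <cstdint>
#include <stdexcept>
// Hypothetical runtime type code, standing in for DGLDataType.
enum class DTypeCode { kFloat32, kFloat64 };
// The typed kernel the dispatcher ultimately resolves to (placeholder body).
template <typename DType>
void SpMMCsrImpl(const DType* ufeat, DType* out, int64_t n) {
  for (int64_t i = 0; i < n; ++i) out[i] = ufeat[i];
}
// Runtime switch -> compile-time template parameter, the same shape as the
// nested ATEN_* macro switches used above.
void SpMMDispatch(DTypeCode code, const void* ufeat, void* out, int64_t n) {
  switch (code) {
    case DTypeCode::kFloat32:
      SpMMCsrImpl<float>(
          static_cast<const float*>(ufeat), static_cast<float*>(out), n);
      break;
    case DTypeCode::kFloat64:
      SpMMCsrImpl<double>(
          static_cast<const double*>(ufeat), static_cast<double*>(out), n);
      break;
    default:
      throw std::invalid_argument("unsupported dtype");
  }
}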
/** @brief Generalized segmented dense Matrix-Matrix Multiplication. */
void SegmentMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray seglen_A,
bool A_trans, bool B_trans) {
void SegmentMM(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool A_trans, bool B_trans) {
CHECK_EQ(A->ndim, 2) << "segment_mm expects a 2D tensor for the first input.";
CHECK_EQ(B->ndim, 3) << "segment_mm expects a 3D tensor for the second input.";
CHECK_EQ(B->ndim, 3)
<< "segment_mm expects a 3D tensor for the second input.";
CHECK(!A_trans);
if (B_trans) {
CHECK_EQ(A->shape[1], B->shape[2])
<< "segment_mm expects A.shape[1] == B.shape[2] when B_trans=True";
<< "segment_mm expects A.shape[1] == B.shape[2] when B_trans=True";
} else {
CHECK_EQ(A->shape[1], B->shape[1]) << "segment_mm expects A.shape[1] == B.shape[1]";
CHECK_EQ(A->shape[1], B->shape[1])
<< "segment_mm expects A.shape[1] == B.shape[1]";
}
CHECK_EQ(B->shape[0], seglen_A.NumElements())
<< "segment_mm expects len(seglen_A) == B.shape[0]";
<< "segment_mm expects len(seglen_A) == B.shape[0]";
CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen_A to be on CPU.";
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device";
<< "segment_mm expects seglen_A to be on CPU.";
CHECK(A->ctx == B->ctx)
<< "segment_mm expects A and B to be of the same device";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
......@@ -83,15 +78,14 @@ void SegmentMM(const NDArray A,
});
}
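The checks above pin down segment_mm's contract: A is (N, d1), B stacks one (d1, d2) weight matrix per segment, seglen_A (kept on CPU) says how many consecutive rows of A use each weight, and C is (N, d2). A naive CPU reference of that computation for the B_trans = false case, offered as an illustrative sketch rather than DGL's implementation:
#include <cstdint>
#include <vector>
// C[r, :] = A[r, :] @ B[s, :, :] where s is the segment that row r falls in.
// A is (N, d1) row-major, B is (R, d1, d2) row-major, C must be pre-sized to N * d2.
void SegmentMMRef(
    const std::vector<double>& A, const std::vector<double>& B,
    std::vector<double>* C, const std::vector<int64_t>& seglen,
    int64_t d1, int64_t d2) {
  int64_t row = 0;
  for (int64_t s = 0; s < static_cast<int64_t>(seglen.size()); ++s) {
    for (int64_t i = 0; i < seglen[s]; ++i, ++row) {
      for (int64_t j = 0; j < d2; ++j) {
        double acc = 0.0;
        for (int64_t k = 0; k < d1; ++k)
          acc += A[row * d1 + k] * B[(s * d1 + k) * d2 + j];
        (*C)[row * d2 + j] = acc;
      }
    }
  }
}
With seglen = {3, 4}, for instance, rows 0-2 of A are multiplied by B[0] and rows 3-6 by B[1].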
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
const NDArray seglen) {
CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor for the first input.";
CHECK_EQ(dC->ndim, 2)
<< "segment_mm_backward operator expects a 2D tensor for the second input.";
void SegmentMMBackwardB(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen) {
CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
"for the first input.";
CHECK_EQ(dC->ndim, 2) << "segment_mm_backward operator expects a 2D tensor "
"for the second input.";
CHECK_EQ(seglen->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen to be on CPU.";
<< "segment_mm expects seglen to be on CPU.";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
......@@ -101,34 +95,35 @@ void SegmentMMBackwardB(const NDArray A,
});
}
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMM(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b) {
CHECK_EQ(A->ndim, 2) << "gather_mm operator expects a 2D tensor for the first input.";
CHECK_EQ(B->ndim, 3) << "gather_mm operator expects a 3D tensor for the second input.";
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types. */
void GatherMM(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b) {
CHECK_EQ(A->ndim, 2)
<< "gather_mm operator expects a 2D tensor for the first input.";
CHECK_EQ(B->ndim, 3)
<< "gather_mm operator expects a 3D tensor for the second input.";
CHECK(A->ctx == B->ctx)
<< "gather_mm expects all arguments to be on the same device.";
<< "gather_mm expects all arguments to be on the same device.";
if (aten::IsNullArray(idx_a)) {
CHECK_EQ(A->shape[0], idx_b->shape[0])
<< "gather_mm expects len(idx_b) == A.shape[0] when idx_a is None.";
<< "gather_mm expects len(idx_b) == A.shape[0] when idx_a is None.";
CHECK(A->ctx == idx_b->ctx)
<< "gather_mm expects all arguments to be on the same device.";
<< "gather_mm expects all arguments to be on the same device.";
} else if (aten::IsNullArray(idx_b)) {
CHECK_EQ(B->shape[0], idx_a->shape[0])
<< "gather_mm expects len(idx_a) == B.shape[0] when idx_b is None.";
<< "gather_mm expects len(idx_a) == B.shape[0] when idx_b is None.";
CHECK(A->ctx == idx_a->ctx)
<< "gather_mm expects all arguments to be on the same device.";
<< "gather_mm expects all arguments to be on the same device.";
} else {
CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
<< "gather_mm expects len(idx_a) == len(idx_b) when both idx_a and idx_b are given.";
<< "gather_mm expects len(idx_a) == len(idx_b) when both idx_a and "
"idx_b are given.";
CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
<< "gather_mm expects all arguments to be on the same device.";
<< "gather_mm expects all arguments to be on the same device.";
}
const auto idtype = aten::IsNullArray(idx_a)? idx_b->dtype : idx_a->dtype;
const auto idtype = aten::IsNullArray(idx_a) ? idx_b->dtype : idx_a->dtype;
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
......@@ -138,36 +133,36 @@ void GatherMM(const NDArray A,
});
}
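For the first configuration checked above (idx_a is None and idx_b has one entry per row of A), gather_mm multiplies each row of A by the weight slice selected by idx_b. A naive CPU sketch of just that case, with the other index combinations omitted and no claim to match DGL's kernel exactly:
#include <cstdint>
#include <vector>
// C[i, :] = A[i, :] @ B[idx_b[i], :, :]  (idx_a == None case only).
// A is (N, d1) row-major, B is (R, d1, d2) row-major, C must be pre-sized to N * d2.
void GatherMMRef(
    const std::vector<double>& A, const std::vector<double>& B,
    std::vector<double>* C, const std::vector<int64_t>& idx_b,
    int64_t d1, int64_t d2) {
  const int64_t n = static_cast<int64_t>(idx_b.size());  // == A.shape[0]
  for (int64_t i = 0; i < n; ++i) {
    const int64_t s = idx_b[i];  // which (d1, d2) slice of B row i uses
    for (int64_t j = 0; j < d2; ++j) {
      double acc = 0.0;
      for (int64_t k = 0; k < d1; ++k)
        acc += A[i * d1 + k] * B[(s * d1 + k) * d2 + j];
      (*C)[i * d2 + j] = acc;
    }
  }
}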
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation types. */
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray C,
const NDArray idx_a,
const NDArray idx_b,
const NDArray idx_c) {
CHECK_EQ(A->ndim, 2) << "gather_mm_scatter expects a 2D tensor for the first input.";
/** @brief Generalized Dense Matrix-Matrix Multiplication according to relation
* types. */
void GatherMMScatter(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c) {
CHECK_EQ(A->ndim, 2)
<< "gather_mm_scatter expects a 2D tensor for the first input.";
CHECK(A->ctx == B->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device.";
<< "gather_mm_scatter expects all arguments to be on the same device.";
if (!aten::IsNullArray(idx_c))
CHECK(A->ctx == idx_c->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device.";
<< "gather_mm_scatter expects all arguments to be on the same device.";
if (aten::IsNullArray(idx_a) && !aten::IsNullArray(idx_b)) {
CHECK_EQ(A->shape[0], idx_b->shape[0])
<< "gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is None.";
<< "gather_mm_scatter expects len(idx_b) == A.shape[0] when idx_a is "
"None.";
CHECK(A->ctx == idx_b->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device.";
<< "gather_mm_scatter expects all arguments to be on the same device.";
} else if (aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
CHECK_EQ(B->shape[0], idx_a->shape[0])
<< "gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is None.";
<< "gather_mm_scatter expects len(idx_a) == B.shape[0] when idx_b is "
"None.";
CHECK(A->ctx == idx_a->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device.";
<< "gather_mm_scatter expects all arguments to be on the same device.";
} else if (!aten::IsNullArray(idx_b) && !aten::IsNullArray(idx_a)) {
CHECK_EQ(idx_a->shape[0], idx_b->shape[0])
<< "gather_mm_scatter expects len(idx_a) == len(idx_b) "
<< "when both idx_a and idx_b are given.";
<< "gather_mm_scatter expects len(idx_a) == len(idx_b) "
<< "when both idx_a and idx_b are given.";
CHECK(A->ctx == idx_a->ctx && A->ctx == idx_b->ctx)
<< "gather_mm_scatter expects all arguments to be on the same device.";
<< "gather_mm_scatter expects all arguments to be on the same device.";
}
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {
......@@ -178,14 +173,13 @@ void GatherMMScatter(const NDArray A,
});
}
/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph support. */
void SpMMHetero(const std::string& op, const std::string& reduce,
HeteroGraphPtr graph,
const std::vector<NDArray>& ufeat_vec,
const std::vector<NDArray>& efeat_vec,
std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux) {
/** @brief Generalized Sparse Matrix-Matrix Multiplication with hetero-graph
* support. */
void SpMMHetero(
const std::string& op, const std::string& reduce, HeteroGraphPtr graph,
const std::vector<NDArray>& ufeat_vec,
const std::vector<NDArray>& efeat_vec, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
std::vector<CSRMatrix> vec_graph;
......@@ -193,7 +187,8 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
std::vector<dgl_type_t> efeat_eid;
std::vector<dgl_type_t> out_eid;
auto pair = graph->meta_graph()->FindEdge(0); // first etype
NDArray ufeat_etype0 = (ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
NDArray ufeat_etype0 =
(ufeat_vec.size() == 0) ? NullArray() : ufeat_vec[pair.first];
NDArray efeat_etype0 = (efeat_vec.size() == 0) ? NullArray() : efeat_vec[0];
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_graph.push_back(graph->GetCSCMatrix(etype));
......@@ -202,54 +197,53 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
efeat_eid.push_back(etype);
out_eid.push_back(pair.second);
if (ufeat_etype0->shape[1] != ufeat_vec[pair.first]->shape[1])
LOG(FATAL) << "Column width of the input node features of all etypes must be same.";
LOG(FATAL) << "Column width of the input node features of all etypes "
"must be same.";
if (efeat_etype0->shape[1] != efeat_vec[etype]->shape[1])
LOG(FATAL) << "Column width of the input edge features of all etypes must be same.";
LOG(FATAL) << "Column width of the input edge features of all etypes "
"must be same.";
}
const auto& bcast = CalcBcastOff(op, ufeat_etype0, efeat_etype0);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS((*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, Dtype>(
op, reduce, bcast, vec_graph,
ufeat_vec, efeat_vec, out, out_aux,
ufeat_eid, out_eid);
} else {
// TODO(Israt): Add support for COO format
LOG(FATAL) << "SpMM only supports CSC format for graphs with number "
<< "of relation types > 1";
}
});
});
ATEN_ID_TYPE_SWITCH(
graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(
(*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, Dtype>(
op, reduce, bcast, vec_graph, ufeat_vec, efeat_vec, out,
out_aux, ufeat_eid, out_eid);
} else {
// TODO(Israt): Add support for COO format
LOG(FATAL)
<< "SpMM only supports CSC format for graphs with number "
<< "of relation types > 1";
}
});
});
});
}
/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMM(const std::string& op,
HeteroGraphPtr graph,
NDArray lhs,
NDArray rhs,
NDArray out,
int lhs_target,
int rhs_target) {
void SDDMM(
const std::string& op, HeteroGraphPtr graph, NDArray lhs, NDArray rhs,
NDArray out, int lhs_target, int rhs_target) {
// TODO(zihao): format tuning
SparseFormat format = graph->SelectFormat(0, COO_CODE);
const auto &bcast = CalcBcastOff(op, lhs, rhs);
const auto& bcast = CalcBcastOff(op, lhs, rhs);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) {
SDDMMCsr<XPU, IdType, Dtype>(
op, bcast, graph->GetCSRMatrix(0),
lhs, rhs, out, lhs_target, rhs_target);
op, bcast, graph->GetCSRMatrix(0), lhs, rhs, out, lhs_target,
rhs_target);
} else if (format == SparseFormat::kCOO) {
SDDMMCoo<XPU, IdType, Dtype>(
op, bcast, graph->GetCOOMatrix(0),
lhs, rhs, out, lhs_target, rhs_target);
op, bcast, graph->GetCOOMatrix(0), lhs, rhs, out, lhs_target,
rhs_target);
} else {
LOG(FATAL) << "SDDMM only supports CSR and COO formats";
}
......@@ -267,21 +261,16 @@ void SDDMM(const std::string& op,
*/
int get_typeid_by_target(HeteroGraphPtr graph, int target, dgl_type_t etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
if (target == 0)
return pair.first;
if (target == 2)
return pair.second;
if (target == 0) return pair.first;
if (target == 2) return pair.second;
return etype;
}
/** @brief Generalized Sampled Dense-Dense Matrix Multiplication. */
void SDDMMHetero(const std::string& op,
HeteroGraphPtr graph,
std::vector<NDArray> lhs,
std::vector<NDArray> rhs,
std::vector<NDArray> out,
int lhs_target,
int rhs_target) {
void SDDMMHetero(
const std::string& op, HeteroGraphPtr graph, std::vector<NDArray> lhs,
std::vector<NDArray> rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target) {
SparseFormat format = graph->SelectFormat(0, COO_CODE);
std::vector<dgl_type_t> lhs_eid;
......@@ -290,79 +279,74 @@ void SDDMMHetero(const std::string& op,
lhs_eid.push_back(get_typeid_by_target(graph, lhs_target, etype));
rhs_eid.push_back(get_typeid_by_target(graph, rhs_target, etype));
}
const auto &bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
const auto& bcast = CalcBcastOff(op, lhs[lhs_eid[0]], rhs[rhs_eid[0]]);
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) {
std::vector<CSRMatrix> vec_csr;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
}
SDDMMCsrHetero<XPU, IdType, Dtype>(
op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else if (format == SparseFormat::kCOO) {
std::vector<COOMatrix> vec_coo;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_coo.push_back(graph->GetCOOMatrix(etype));
}
SDDMMCooHetero<XPU, IdType, Dtype>(
op, bcast, vec_coo,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else {
LOG(FATAL) << "SDDMM only supports CSR and COO formats";
}
});
ATEN_FLOAT_TYPE_SWITCH_16BITS(
out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) {
std::vector<CSRMatrix> vec_csr;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
}
SDDMMCsrHetero<XPU, IdType, Dtype>(
op, bcast, vec_csr, lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else if (format == SparseFormat::kCOO) {
std::vector<COOMatrix> vec_coo;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes();
++etype) {
vec_coo.push_back(graph->GetCOOMatrix(etype));
}
SDDMMCooHetero<XPU, IdType, Dtype>(
op, bcast, vec_coo, lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
} else {
LOG(FATAL) << "SDDMM only supports CSR and COO formats";
}
});
});
});
}
/** @brief Generalized Edge_softmax op for forward */
void Edge_softmax_forward(const std::string& op,
HeteroGraphPtr graph,
NDArray ufeat,
NDArray efeat,
NDArray out) {
void Edge_softmax_forward(
const std::string& op, HeteroGraphPtr graph, NDArray ufeat, NDArray efeat,
NDArray out) {
// TODO(zhejiang): add gpu op for edge_softmax
const auto& bcast = CalcBcastOff(op, ufeat, efeat);
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data", {
Edge_softmax_csr_forward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
});
ATEN_FLOAT_TYPE_SWITCH_16BITS(
out->dtype, Dtype, XPU, "edge_softmax out data", {
Edge_softmax_csr_forward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
});
});
});
}
/** @brief Generalized Edge_softmax op for backward */
void Edge_softmax_backward(const std::string& op,
HeteroGraphPtr graph,
NDArray out,
NDArray sds,
NDArray back_out,
NDArray ufeat) {
void Edge_softmax_backward(
const std::string& op, HeteroGraphPtr graph, NDArray out, NDArray sds,
NDArray back_out, NDArray ufeat) {
// TODO(zhejiang): add gpu op for edge_softmax
const auto& bcast = CalcBcastOff(op, ufeat, sds);
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data_back", {
Edge_softmax_csr_backward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
});
ATEN_FLOAT_TYPE_SWITCH_16BITS(
out->dtype, Dtype, XPU, "edge_softmax out data_back", {
Edge_softmax_csr_backward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
});
});
});
}
NDArray GetEdgeMapping(HeteroGraphRef graph) {
SparseFormat format = graph->SelectFormat(0, CSC_CODE);
if (format == SparseFormat::kCSC) {
......@@ -373,15 +357,13 @@ NDArray GetEdgeMapping(HeteroGraphRef graph) {
}
/** @brief Segment reduce dispatch function. */
void SegmentReduceDispatch(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
void SegmentReduceDispatch(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
});
});
});
......@@ -398,20 +380,21 @@ void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
});
}
/** @brief Update gradients (reduce op max/min) dispatch function on heterogeneous graph. */
void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
/** @brief Update gradients (reduce op max/min) dispatch function on
* heterogeneous graph. */
void UpdateGradMinMaxDispatchHetero(
const HeteroGraphPtr& graph, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
auto pair = graph->meta_graph()->FindEdge(0); // checking the first etype
auto src_id = pair.first;
ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat[src_id]->dtype, Dtype, XPU, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, Dtype>(graph, op, feat, idx, idx_etype, out);
});
ATEN_FLOAT_TYPE_SWITCH_16BITS(
feat[src_id]->dtype, Dtype, XPU, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, Dtype>(
graph, op, feat, idx, idx_etype, out);
});
});
});
}
......@@ -428,20 +411,19 @@ void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
}
std::pair<CSRMatrix, NDArray> CSRMM(
CSRMatrix A,
NDArray A_weights,
CSRMatrix B,
NDArray B_weights) {
CHECK_EQ(A.num_cols, B.num_rows) <<
"The number of nodes of destination node type of the first graph must be the "
"same as the number of nodes of source node type of the second graph.";
CSRMatrix A, NDArray A_weights, CSRMatrix B, NDArray B_weights) {
CHECK_EQ(A.num_cols, B.num_rows)
<< "The number of nodes of destination node type of the first graph must "
"be the "
"same as the number of nodes of source node type of the second graph.";
CheckCtx(
A.indptr->ctx,
{A_weights, B_weights},
A.indptr->ctx, {A_weights, B_weights},
{"A's edge weights", "B's edge weights"});
CHECK_EQ(A.indptr->ctx, B.indptr->ctx) << "Device of two graphs must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype) << "ID types of two graphs must match.";
CHECK_EQ(A_weights->dtype, B_weights->dtype) << "Data types of two edge weights must match.";
CHECK_EQ(A.indptr->dtype, B.indptr->dtype)
<< "ID types of two graphs must match.";
CHECK_EQ(A_weights->dtype, B_weights->dtype)
<< "Data types of two edge weights must match.";
std::pair<CSRMatrix, NDArray> ret;
ATEN_XPU_SWITCH_CUDA(A.indptr->ctx.device_type, XPU, "CSRMM", {
......@@ -455,27 +437,31 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
std::pair<CSRMatrix, NDArray> CSRSum(
const std::vector<CSRMatrix>& A,
const std::vector<NDArray>& A_weights) {
const std::vector<CSRMatrix>& A, const std::vector<NDArray>& A_weights) {
CHECK(A.size() > 0) << "The list of graphs must not be empty.";
CHECK_EQ(A.size(), A_weights.size()) <<
"The list of edge weights must have the same length as the list of graphs.";
CHECK_EQ(A.size(), A_weights.size())
<< "The list of edge weights must have the same length as the list of "
"graphs.";
const auto ctx = A[0].indptr->ctx;
const auto idtype = A[0].indptr->dtype;
const auto dtype = A_weights[0]->dtype;
const auto num_rows = A[0].num_rows;
const auto num_cols = A[0].num_cols;
for (size_t i = 0; i < A.size(); ++i) {
CHECK_EQ(A[i].indptr->ctx, ctx) << "The devices of all graphs must be equal.";
CHECK_EQ(A[i].indptr->dtype, idtype) << "The ID types of all graphs must be equal.";
CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0]) <<
"Shape of edge weights does not match the number of edges.";
CHECK_EQ(A_weights[i]->ctx, ctx) <<
"The devices of edge weights must be the same as that of the graphs.";
CHECK_EQ(A_weights[i]->dtype, dtype) <<
"The data types of all edge weights must be equal.";
CHECK_EQ(A[i].num_rows, num_rows) << "Graphs must have the same number of nodes.";
CHECK_EQ(A[i].num_cols, num_cols) << "Graphs must have the same number of nodes.";
CHECK_EQ(A[i].indptr->ctx, ctx)
<< "The devices of all graphs must be equal.";
CHECK_EQ(A[i].indptr->dtype, idtype)
<< "The ID types of all graphs must be equal.";
CHECK_EQ(A[i].indices->shape[0], A_weights[i]->shape[0])
<< "Shape of edge weights does not match the number of edges.";
CHECK_EQ(A_weights[i]->ctx, ctx) << "The devices of edge weights must be "
"the same as that of the graphs.";
CHECK_EQ(A_weights[i]->dtype, dtype)
<< "The data types of all edge weights must be equal.";
CHECK_EQ(A[i].num_rows, num_rows)
<< "Graphs must have the same number of nodes.";
CHECK_EQ(A[i].num_cols, num_cols)
<< "Graphs must have the same number of nodes.";
}
std::pair<CSRMatrix, NDArray> ret;
......@@ -490,238 +476,246 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
NDArray U = args[3];
NDArray E = args[4];
NDArray V = args[5];
NDArray ArgU = args[6];
NDArray ArgE = args[7];
CheckCtx(graph->Context(), {U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous({U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
{0, 1, 2, 2, 2},
{U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
NDArray U = args[3];
NDArray E = args[4];
NDArray V = args[5];
NDArray ArgU = args[6];
NDArray ArgE = args[7];
CheckCtx(
graph->Context(), {U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous(
{U, E, V, ArgU, ArgE}, {"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair =
graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0),
graph->NumVertices(dst_vtype)},
{0, 1, 2, 2, 2}, {U, E, V, ArgU, ArgE},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
SpMM(op, reduce_op, graph.sptr(), U, E, V, {ArgU, ArgE});
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
GatherMM(A, B, C, idx_a, idx_b);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
GatherMM(A, B, C, idx_a, idx_b);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGATHERMMSCATTER")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
NDArray idx_c = args[5];
GatherMMScatter(A, B, C, idx_a, idx_b, idx_c);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray idx_a = args[3];
NDArray idx_b = args[4];
NDArray idx_c = args[5];
GatherMMScatter(A, B, C, idx_a, idx_b, idx_c);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray seglen_A = args[3];
bool A_trans = args[4];
bool B_trans = args[5];
SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray B = args[1];
NDArray C = args[2];
NDArray seglen_A = args[3];
bool A_trans = args[4];
bool B_trans = args[5];
SegmentMM(A, B, C, seglen_A, A_trans, B_trans);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSEGMENTMMBackwardB")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray dC = args[1];
NDArray dB = args[2];
NDArray seglen = args[3];
SegmentMMBackwardB(A, dC, dB, seglen);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray A = args[0];
NDArray dC = args[1];
NDArray dB = args[2];
NDArray seglen = args[3];
SegmentMMBackwardB(A, dC, dB, seglen);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_forward")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray U = args[2];
NDArray E = args[3];
NDArray V = args[4];
Edge_softmax_forward(op, graph.sptr(), U, E, V);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray U = args[2];
NDArray E = args[3];
NDArray V = args[4];
Edge_softmax_forward(op, graph.sptr(), U, E, V);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelEdge_softmax_backward")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray out = args[2];
NDArray sds = args[3];
NDArray back_out = args[4];
NDArray ufeat = args[5];
Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray out = args[2];
NDArray sds = args[3];
NDArray back_out = args[4];
NDArray ufeat = args[5];
Edge_softmax_backward(op, graph.sptr(), out, sds, back_out, ufeat);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSpMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
List<Value> list_U = args[3];
List<Value> list_E = args[4];
List<Value> list_V = args[5];
List<Value> list_ArgU = args[6];
List<Value> list_ArgE = args[7];
List<Value> list_ArgU_ntype = args[8];
List<Value> list_ArgE_etype = args[9];
std::vector<std::vector<NDArray>> Arg_vec; // ArgU + ArgE
for (int i = 0; i < 4; ++i) { // ArgU + ArgE + ArgU_ntype + ArgE_etype
Arg_vec.push_back(std::vector<NDArray>());
}
std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t src_id = pair.first;
const dgl_id_t dst_id = pair.second;
NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
CheckCtx(graph->Context(), {U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous({U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
}
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
const std::string reduce_op = args[2];
List<Value> list_U = args[3];
List<Value> list_E = args[4];
List<Value> list_V = args[5];
List<Value> list_ArgU = args[6];
List<Value> list_ArgE = args[7];
List<Value> list_ArgU_ntype = args[8];
List<Value> list_ArgE_etype = args[9];
std::vector<std::vector<NDArray>> Arg_vec; // ArgU + ArgE
for (int i = 0; i < 4; ++i) { // ArgU + ArgE + ArgU_ntype + ArgE_etype
Arg_vec.push_back(std::vector<NDArray>());
}
std::vector<NDArray> U_vec = ListValueToVector<NDArray>(list_U);
std::vector<NDArray> V_vec = ListValueToVector<NDArray>(list_V);
std::vector<NDArray> E_vec = ListValueToVector<NDArray>(list_E);
Arg_vec[0] = ListValueToVector<NDArray>(list_ArgU);
Arg_vec[1] = ListValueToVector<NDArray>(list_ArgE);
Arg_vec[2] = ListValueToVector<NDArray>(list_ArgU_ntype);
Arg_vec[3] = ListValueToVector<NDArray>(list_ArgE_etype);
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
auto pair = graph->meta_graph()->FindEdge(etype);
const dgl_id_t src_id = pair.first;
const dgl_id_t dst_id = pair.second;
NDArray U = (U_vec.size() == 0) ? NullArray() : U_vec[src_id];
NDArray E = (E_vec.size() == 0) ? NullArray() : E_vec[etype];
CheckCtx(
graph->Context(),
{U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
CheckContiguous(
{U, E, V_vec[dst_id], Arg_vec[0][dst_id], Arg_vec[1][dst_id]},
{"U_data", "E_data", "out", "Arg_U", "Arg_E"});
}
SpMMHetero(op, reduce_op, graph.sptr(), U_vec, E_vec, &V_vec, &Arg_vec);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray lhs = args[2];
NDArray rhs = args[3];
NDArray out = args[4];
int lhs_target = args[5];
int rhs_target = args[6];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
{lhs_target, rhs_target, 1},
{lhs, rhs, out},
{"U_data", "E_data", "V_data"});
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
NDArray lhs = args[2];
NDArray rhs = args[3];
NDArray out = args[4];
int lhs_target = args[5];
int rhs_target = args[6];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
auto pair =
graph->meta_graph()->FindEdge(0); // only one etype in the graph.
const dgl_type_t src_vtype = pair.first;
const dgl_type_t dst_vtype = pair.second;
CheckShape(
{graph->NumVertices(src_vtype), graph->NumEdges(0),
graph->NumVertices(dst_vtype)},
{lhs_target, rhs_target, 1}, {lhs, rhs, out},
{"U_data", "E_data", "V_data"});
SDDMM(op, graph.sptr(), lhs, rhs, out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSDDMMHetero")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_lhs = args[2];
List<Value> list_rhs = args[3];
List<Value> list_out = args[4];
int lhs_target = args[5];
int rhs_target = args[6];
std::vector<NDArray> vec_lhs;
std::vector<NDArray> vec_rhs;
std::vector<NDArray> vec_out;
vec_lhs.reserve(list_lhs.size());
vec_rhs.reserve(list_rhs.size());
vec_out.reserve(list_out.size());
for (Value val : list_lhs) {
vec_lhs.push_back(val->data);
}
for (Value val : list_rhs) {
vec_rhs.push_back(val->data);
}
for (Value val : list_out) {
vec_out.push_back(val->data);
}
SDDMMHetero(op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_lhs = args[2];
List<Value> list_rhs = args[3];
List<Value> list_out = args[4];
int lhs_target = args[5];
int rhs_target = args[6];
std::vector<NDArray> vec_lhs;
std::vector<NDArray> vec_rhs;
std::vector<NDArray> vec_out;
vec_lhs.reserve(list_lhs.size());
vec_rhs.reserve(list_rhs.size());
vec_out.reserve(list_out.size());
for (Value val : list_lhs) {
vec_lhs.push_back(val->data);
}
for (Value val : list_rhs) {
vec_rhs.push_back(val->data);
}
for (Value val : list_out) {
vec_out.push_back(val->data);
}
SDDMMHetero(
op, graph.sptr(), vec_lhs, vec_rhs, vec_out, lhs_target, rhs_target);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelSegmentReduce")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string op = args[0];
NDArray feat = args[1];
NDArray offsets = args[2];
NDArray out = args[3];
NDArray arg = args[4];
CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"});
CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"});
SegmentReduceDispatch(op, feat, offsets, out, arg);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string op = args[0];
NDArray feat = args[1];
NDArray offsets = args[2];
NDArray out = args[3];
NDArray arg = args[4];
CheckCtx(feat->ctx, {feat, offsets, out}, {"feat", "offsets", "out"});
CheckContiguous({feat, offsets, out}, {"feat", "offsets", "out"});
SegmentReduceDispatch(op, feat, offsets, out, arg);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelScatterAdd")
.set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0];
NDArray idx = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
ScatterAddDispatch(feat, idx, out);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray feat = args[0];
NDArray idx = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
ScatterAddDispatch(feat, idx, out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelUpdateGradMinMaxHetero")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_feat = args[2];
List<Value> list_idx = args[3];
List<Value> list_idx_etype = args[4];
List<Value> list_out = args[5];
std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
std::vector<NDArray> vec_idx_etype = ListValueToVector<NDArray>(list_idx_etype);
std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
// CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
// CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
UpdateGradMinMaxDispatchHetero(graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
const std::string op = args[1];
List<Value> list_feat = args[2];
List<Value> list_idx = args[3];
List<Value> list_idx_etype = args[4];
List<Value> list_out = args[5];
std::vector<NDArray> vec_feat = ListValueToVector<NDArray>(list_feat);
std::vector<NDArray> vec_idx = ListValueToVector<NDArray>(list_idx);
std::vector<NDArray> vec_idx_etype =
ListValueToVector<NDArray>(list_idx_etype);
std::vector<NDArray> vec_out = ListValueToVector<NDArray>(list_out);
// CheckCtx(feat->ctx, {feat, idx, out}, {"feat", "idx", "out"});
// CheckContiguous({feat, idx, out}, {"feat", "idx", "out"});
UpdateGradMinMaxDispatchHetero(
graph.sptr(), op, vec_feat, vec_idx, vec_idx_etype, &vec_out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelBwdSegmentCmp")
.set_body([](DGLArgs args, DGLRetValue *rv) {
NDArray feat = args[0];
NDArray arg = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"});
CheckContiguous({feat, arg, out}, {"feat", "arg", "out"});
BackwardSegmentCmpDispatch(feat, arg, out);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
NDArray feat = args[0];
NDArray arg = args[1];
NDArray out = args[2];
CheckCtx(feat->ctx, {feat, arg, out}, {"feat", "arg", "out"});
CheckContiguous({feat, arg, out}, {"feat", "arg", "out"});
BackwardSegmentCmpDispatch(feat, arg, out);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
.set_body([](DGLArgs args, DGLRetValue *rv) {
HeteroGraphRef graph = args[0];
*rv = GetEdgeMapping(graph);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
*rv = GetEdgeMapping(graph);
});
/**
* @brief Sparse matrix multiplication with graph interface.
......@@ -734,107 +728,111 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_DGLKernelGetEdgeMapping")
* @return A pair consisting of the new graph as well as its edge weights.
*/
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMM")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2];
NDArray B_weights = args[3];
int num_vtypes = args[4];
const HeteroGraphPtr A = A_ref.sptr();
const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "The first graph must have only one edge type.";
CHECK_EQ(B->NumEdgeTypes(), 1) << "The second graph must have only one edge type.";
const auto A_csr = A->GetCSRMatrix(0);
const auto B_csr = B->GetCSRMatrix(0);
auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2];
NDArray B_weights = args[3];
int num_vtypes = args[4];
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
List<HeteroGraphRef> A_refs = args[0];
List<Value> A_weights = args[1];
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
std::vector<CSRMatrix> mats;
mats.reserve(A_refs.size());
int num_vtypes = 0;
for (auto A_ref : A_refs) {
const HeteroGraphPtr A = A_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "Graphs must have only one edge type.";
mats.push_back(A->GetCSRMatrix(0));
if (num_vtypes == 0)
num_vtypes = A->NumVertexTypes();
}
auto result = CSRSum(mats, weights);
List<ObjectRef> ret;
ret.push_back(HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1)
<< "The first graph must have only one edge type.";
CHECK_EQ(B->NumEdgeTypes(), 1)
<< "The second graph must have only one edge type.";
const auto A_csr = A->GetCSRMatrix(0);
const auto B_csr = B->GetCSRMatrix(0);
auto result = CSRMM(A_csr, A_weights, B_csr, B_weights);
List<ObjectRef> ret;
ret.push_back(
HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRSum")
.set_body([](DGLArgs args, DGLRetValue* rv) {
List<HeteroGraphRef> A_refs = args[0];
List<Value> A_weights = args[1];
std::vector<NDArray> weights = ListValueToVector<NDArray>(A_weights);
std::vector<CSRMatrix> mats;
mats.reserve(A_refs.size());
int num_vtypes = 0;
for (auto A_ref : A_refs) {
const HeteroGraphPtr A = A_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1)
<< "Graphs must have only one edge type.";
mats.push_back(A->GetCSRMatrix(0));
if (num_vtypes == 0) num_vtypes = A->NumVertexTypes();
}
auto result = CSRSum(mats, weights);
List<ObjectRef> ret;
ret.push_back(
HeteroGraphRef(CreateFromCSR(num_vtypes, result.first, ALL_CODE)));
ret.push_back(Value(MakeValue(result.second)));
*rv = ret;
});
DGL_REGISTER_GLOBAL("sparse._CAPI_DGLCSRMask")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2];
const HeteroGraphPtr A = A_ref.sptr();
const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1) << "Both graphs must have only one edge type.";
CHECK_EQ(B->NumEdgeTypes(), 1) << "Both graphs must have only one edge type.";
const CSRMatrix& A_csr = A->GetCSRMatrix(0);
const COOMatrix& B_coo = B->GetCOOMatrix(0);
CHECK_EQ(A_csr.num_rows, B_coo.num_rows) <<
"Both graphs must have the same number of nodes.";
CHECK_EQ(A_csr.num_cols, B_coo.num_cols) <<
"Both graphs must have the same number of nodes.";
NDArray result;
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
result = aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
});
*rv = result;
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
const HeteroGraphRef A_ref = args[0];
NDArray A_weights = args[1];
const HeteroGraphRef B_ref = args[2];
const HeteroGraphPtr A = A_ref.sptr();
const HeteroGraphPtr B = B_ref.sptr();
CHECK_EQ(A->NumEdgeTypes(), 1)
<< "Both graphs must have only one edge type.";
CHECK_EQ(B->NumEdgeTypes(), 1)
<< "Both graphs must have only one edge type.";
const CSRMatrix& A_csr = A->GetCSRMatrix(0);
const COOMatrix& B_coo = B->GetCOOMatrix(0);
CHECK_EQ(A_csr.num_rows, B_coo.num_rows)
<< "Both graphs must have the same number of nodes.";
CHECK_EQ(A_csr.num_cols, B_coo.num_cols)
<< "Both graphs must have the same number of nodes.";
NDArray result;
ATEN_FLOAT_TYPE_SWITCH(A_weights->dtype, DType, "Edge weights", {
result =
aten::CSRGetData<DType>(A_csr, B_coo.row, B_coo.col, A_weights, 0.);
});
*rv = result;
});
#ifdef USE_TVM
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_LoadModule")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
const std::string path = args[0];
dgl::featgraph::LoadFeatGraphModule(path);
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
const std::string path = args[0];
dgl::featgraph::LoadFeatGraphModule(path);
});
DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
NDArray lhs = args[1];
NDArray rhs = args[2];
NDArray out = args[3];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
// auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the graph.
// const dgl_type_t src_vtype = pair.first;
// const dgl_type_t dst_vtype = pair.second;
// CheckShape(
// {graph->NumVertices(src_vtype), graph->NumEdges(0), graph->NumVertices(dst_vtype)},
// {lhs_target, rhs_target, 1},
// {lhs, rhs, out},
// {"U_data", "E_data", "V_data"});
COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
dgl::featgraph::SDDMMTreeReduction(
DLPackConvert::ToDLPack(coo.row),
DLPackConvert::ToDLPack(coo.col),
DLPackConvert::ToDLPack(lhs),
DLPackConvert::ToDLPack(rhs),
DLPackConvert::ToDLPack(out));
});
.set_body([](DGLArgs args, DGLRetValue* rv) {
HeteroGraphRef graph = args[0];
NDArray lhs = args[1];
NDArray rhs = args[2];
NDArray out = args[3];
CheckCtx(graph->Context(), {lhs, rhs, out}, {"lhs", "rhs", "out"});
CheckContiguous({lhs, rhs, out}, {"lhs", "rhs", "out"});
CHECK_EQ(graph->NumEdgeTypes(), 1);
// auto pair = graph->meta_graph()->FindEdge(0); // only one etype in the
// graph. const dgl_type_t src_vtype = pair.first; const dgl_type_t
// dst_vtype = pair.second; CheckShape(
// {graph->NumVertices(src_vtype), graph->NumEdges(0),
// graph->NumVertices(dst_vtype)}, {lhs_target, rhs_target, 1}, {lhs,
// rhs, out},
// {"U_data", "E_data", "V_data"});
COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
dgl::featgraph::SDDMMTreeReduction(
DLPackConvert::ToDLPack(coo.row), DLPackConvert::ToDLPack(coo.col),
DLPackConvert::ToDLPack(lhs), DLPackConvert::ToDLPack(rhs),
DLPackConvert::ToDLPack(out));
});
#endif // USE_TVM
} // namespace aten
......