Unverified Commit 96297fb8 authored by Xin Yao, committed by GitHub

[Feature] Add bfloat16 (bf16) support (#4648)

* add bf16 specializations

* remove SWITCH_BITS

* enable amp for bf16

* remove SWITCH_BITS for cpu kernels

* enable bf16 based on CUDART

* fix compiling for sm<80

* fix cpu build

* enable unit tests

* update doc

* disable test for CUDA < 11.0

* address comments

* address comments
parent 1d229194
@@ -10,6 +10,7 @@
#include <limits>
#include "./atomic.cuh"
#include "./fp16.cuh"
+#include "bf16.cuh"
namespace dgl {
namespace aten {

@@ -108,7 +109,7 @@ struct Dot {
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
-    DType rst = static_cast<DType>(0);
+    DType rst = static_cast<DType>(0.0f);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
    }

@@ -159,14 +160,21 @@ template <typename Idx,
          bool atomic = false>
struct Sum: _Sum<Idx, DType, atomic> { };
-#ifdef USE_FP16
template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
  static constexpr __host__ __device__ __forceinline__ half zero() {
    return __float2half_rn(0.);
  }
};
-#endif  // USE_FP16
+#if BF16_ENABLED
+template <typename Idx, bool atomic>
+struct Sum<Idx, __nv_bfloat16, atomic>: _Sum<Idx, __nv_bfloat16, atomic> {
+  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
+    return __float2bfloat16_rn(0.);
+  }
+};
+#endif  // BF16_ENABLED
template <typename Idx,
          typename DType,

@@ -220,7 +228,6 @@ template <typename Idx,
          bool atomic = false>
struct Max : _Max<Idx, DType, atomic> { };
-#ifdef USE_FP16
template <typename Idx,
          bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
@@ -228,7 +235,16 @@ struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
    return __float2half_rn(-6.550400e+04f);
  }
};
-#endif
+#if BF16_ENABLED
+template <typename Idx,
+          bool atomic>
+struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
+  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
+    return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
+  }
+};
+#endif  // BF16_ENABLED
template <typename Idx,
          typename DType,

@@ -282,7 +298,6 @@ template <typename Idx,
          bool atomic = false>
struct Min : _Min<Idx, DType, atomic> { };
-#ifdef USE_FP16
template <typename Idx,
          bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
@@ -290,7 +305,16 @@ struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
    return __float2half_rn(6.550400e+04f);
  }
};
-#endif  // USE_FP16
+#if BF16_ENABLED
+template <typename Idx,
+          bool atomic>
+struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
+  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
+    return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
+  }
+};
+#endif  // BF16_ENABLED
}  // namespace reduce
...
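The new bfloat16 reducers mirror the fp16 ones, but their Max/Min identities are negative and positive infinity rather than the -65504/+65504 used for half: bf16 keeps float's 8-bit exponent, so infinities are representable, whereas 65504 is fp16's largest finite value. A small host-side check of the two Max identities (a sketch assuming CUDA 11+, compiled with nvcc; not DGL code):

#include <cstdio>
#include <limits>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

int main() {
  // Identity used by Max<.., __nv_bfloat16, ..>::zero() above: negative infinity.
  __nv_bfloat16 bf16_id =
      __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
  // Identity used by Max<.., half, ..>::zero(): the most negative finite half.
  half fp16_id = __float2half_rn(-6.550400e+04f);
  std::printf("bf16 max identity: %f\n", __bfloat162float(bf16_id));  // -inf
  std::printf("fp16 max identity: %f\n", __half2float(fp16_id));      // -65504
  return 0;
}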
@@ -26,7 +26,6 @@ cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa,
  return CUBLAS_STATUS_EXECUTION_FAILED;
}
-#ifdef USE_FP16
template <>
cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t transa,
                                  cublasOperation_t transb, int m, int n, int k,
@@ -36,7 +35,23 @@ cublasStatus_t cublasGemm<__half>(cublasHandle_t handle, cublasOperation_t trans
  return cublasHgemm(handle, transa, transb, m, n, k, alpha, A, lda,
                     B, ldb, beta, C, ldc);
}
-#endif
+#if BF16_ENABLED
+template <>
+cublasStatus_t cublasGemm<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
+                                         cublasOperation_t transb, int m, int n, int k,
+                                         const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
+                                         const __nv_bfloat16* B, int ldb, const __nv_bfloat16* beta,
+                                         __nv_bfloat16* C, int ldc) {
+  float alpha_float = __bfloat162float(*alpha);
+  float beta_float = __bfloat162float(*beta);
+  return cublasGemmEx(handle, transa, transb, m, n, k,
+                      &alpha_float, A, CUDA_R_16BF, lda,
+                      B, CUDA_R_16BF, ldb,
+                      &beta_float, C, CUDA_R_16BF, ldc,
+                      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+}
+#endif  // BF16_ENABLED
template <>
cublasStatus_t cublasGemm<float>(cublasHandle_t handle, cublasOperation_t transa,

@@ -102,7 +117,7 @@ __global__ void GatherMMScatterKernel(
  __syncwarp();
  for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
-    DType out_reg = 0;  // thread private
+    DType out_reg = static_cast<DType>(0.0f);  // thread private
    const unsigned int l = laneId;
    if (l < out_len) {
      // iterate over elements of a row of A

@@ -163,7 +178,7 @@ __global__ void GatherMMScatterKernel2(
  __syncwarp();
  for (unsigned int outloop = 0; outloop < out_len; outloop +=32) {
-    DType out_reg = 0;  // thread private
+    DType out_reg = static_cast<DType>(0.0f);  // thread private
    const unsigned int l = laneId;
    if (l < out_len) {
      const DType b_val = B[row_b * out_len + (outloop + l)];
@@ -192,108 +207,104 @@ __global__ void GatherMMScatterKernel2(
 * \param a_trans Matrix A to be transposed
 * \param b_trans Matrix B to be transposed
 */
-template <int XPU, typename IdType, int bits>
+template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
               const NDArray B,
               NDArray C,
               const NDArray seglen_A,
               bool a_trans, bool b_trans) {
-  SWITCH_BITS(bits, DType, {
  auto device = runtime::DeviceAPI::Get(A->ctx);
  cudaStream_t stream = runtime::getCurrentCUDAStream();
  const DType *A_data = A.Ptr<DType>();
  const DType *B_data = B.Ptr<DType>();
  const IdType* seglen_A_data = seglen_A.Ptr<IdType>();
  DType *C_data = C.Ptr<DType>();
  int64_t A_offset = 0, B_offset = 0, C_offset = 0;
  int64_t m, n, k;
  int64_t num_rel = seglen_A.NumElements();
  DType alpha = 1., beta = 0.;
  auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
  if (!thr_entry->cublas_handle)
    CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
  CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
  IdType m_offset = 0;
  for (IdType etype = 0; etype < num_rel; ++etype) {
    m = seglen_A_data[etype];  // rows of A
    CHECK_LE(m_offset + m, A->shape[0]) << "Segment index out of bound of A->shape[0].";
    n = B->shape[2];  // cols of B
    k = B->shape[1];  // cols of A == rows of B
    int ldb = n, lda = k, ldc = n;
    cublasOperation_t transB = CUBLAS_OP_N;
    cublasOperation_t transA = CUBLAS_OP_N;
    if (b_trans) {
      transB = CUBLAS_OP_T;
      ldb = n, lda = n, ldc = k;
      std::swap(n, k);
    }
    CUBLAS_CALL(cublasGemm<DType>(
        thr_entry->cublas_handle,
        transB,
        transA,
        n, m, k,
        &alpha,
        B_data + B_offset, ldb,
        A_data + A_offset, lda,
        &beta,
        C_data + C_offset, ldc));
    A_offset += m * k;
    B_offset += k * n;
    C_offset += m * n;
    m_offset += m;
  }
-  });
}
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(const NDArray A,
const NDArray dC, const NDArray dC,
NDArray dB, NDArray dB,
const NDArray seglen) { const NDArray seglen) {
SWITCH_BITS(bits, DType, { auto device = runtime::DeviceAPI::Get(A->ctx);
auto device = runtime::DeviceAPI::Get(A->ctx); cudaStream_t stream = runtime::getCurrentCUDAStream();
cudaStream_t stream = runtime::getCurrentCUDAStream(); const DType *A_data = A.Ptr<DType>();
const DType *A_data = A.Ptr<DType>(); const DType *dC_data = dC.Ptr<DType>();
const DType *dC_data = dC.Ptr<DType>(); const IdType* seglen_data = seglen.Ptr<IdType>();
const IdType* seglen_data = seglen.Ptr<IdType>(); DType *dB_data = dB.Ptr<DType>();
DType *dB_data = dB.Ptr<DType>(); int64_t A_offset = 0, dC_offset = 0, dB_offset = 0;
int64_t A_offset = 0, dC_offset = 0, dB_offset = 0; int64_t m, n, k;
int64_t m, n, k; int64_t num_rel = seglen.NumElements();
int64_t num_rel = seglen.NumElements(); DType alpha = 1., beta = 1.;
DType alpha = 1., beta = 1.;
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
if (!thr_entry->cublas_handle) if (!thr_entry->cublas_handle)
CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle))); CUBLAS_CALL(cublasCreate(&(thr_entry->cublas_handle)));
CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream)); CUBLAS_CALL(cublasSetStream(thr_entry->cublas_handle, stream));
IdType k_offset = 0; IdType k_offset = 0;
for (IdType etype = 0; etype < num_rel; ++etype) { for (IdType etype = 0; etype < num_rel; ++etype) {
m = dC->shape[1]; m = dC->shape[1];
n = A->shape[1]; n = A->shape[1];
k = seglen_data[etype]; k = seglen_data[etype];
CHECK_LE(k_offset + k, A->shape[0]) << "Segement index out of bound of A->shape[0]."; CHECK_LE(k_offset + k, A->shape[0]) << "Segement index out of bound of A->shape[0].";
int lddC = m, ldA = n, lddB = m; int lddC = m, ldA = n, lddB = m;
cublasOperation_t trans_dC = CUBLAS_OP_N; cublasOperation_t trans_dC = CUBLAS_OP_N;
cublasOperation_t trans_A = CUBLAS_OP_T; cublasOperation_t trans_A = CUBLAS_OP_T;
CUBLAS_CALL(cublasGemm<DType>( CUBLAS_CALL(cublasGemm<DType>(
thr_entry->cublas_handle, thr_entry->cublas_handle,
trans_dC, trans_dC,
trans_A, trans_A,
m, n, k, m, n, k,
&alpha, &alpha,
dC_data + dC_offset, lddC, dC_data + dC_offset, lddC,
A_data + A_offset, ldA, A_data + A_offset, ldA,
&beta, &beta,
dB_data + dB_offset, lddB)); dB_data + dB_offset, lddB));
dC_offset += m * k; dC_offset += m * k;
A_offset += n * k; A_offset += n * k;
dB_offset += m * n; dB_offset += m * n;
k_offset += k; k_offset += k;
} }
});
} }
/*! /*!
@@ -306,33 +317,31 @@ void SegmentMMBackwardB(const NDArray A,
* \param idx_b The input vector to gather right hand operand on * \param idx_b The input vector to gather right hand operand on
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(const NDArray A,
const NDArray B, const NDArray B,
NDArray C, NDArray C,
const NDArray idx_a, const NDArray idx_a,
const NDArray idx_b) { const NDArray idx_b) {
SWITCH_BITS(bits, DType, { auto device = runtime::DeviceAPI::Get(A->ctx);
auto device = runtime::DeviceAPI::Get(A->ctx); cudaStream_t stream = runtime::getCurrentCUDAStream();
cudaStream_t stream = runtime::getCurrentCUDAStream(); int64_t out_len = B->shape[2]; // cols of B
int64_t out_len = B->shape[2]; // cols of B int64_t in_len = A->shape[1]; // cols of A
int64_t in_len = A->shape[1]; // cols of A const int64_t tot_num_rows = A->shape[0];
const int64_t tot_num_rows = A->shape[0]; const int ntx = 128;
const int ntx = 128; const int warp_size = 32;
const int warp_size = 32; const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx); const dim3 nblks(nbx);
const dim3 nblks(nbx); const dim3 nthrs(ntx);
const dim3 nthrs(ntx); CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
nblks, nthrs, 0, stream, A.Ptr<DType>(),
A.Ptr<DType>(), B.Ptr<DType>(),
B.Ptr<DType>(), C.Ptr<DType>(),
C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_a.Ptr<IdType>(), idx_b.Ptr<IdType>(),
idx_b.Ptr<IdType>(), nullptr,
nullptr, tot_num_rows, in_len, out_len);
tot_num_rows, in_len, out_len);
});
} }
/*! /*!
@@ -348,120 +357,147 @@ void GatherMM(const NDArray A,
* \param a_trans Matrix A to be transposed * \param a_trans Matrix A to be transposed
* \param b_trans Matrix B to be transposed * \param b_trans Matrix B to be transposed
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(const NDArray A,
const NDArray B, const NDArray B,
NDArray C, NDArray C,
const NDArray idx_a, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_b,
const NDArray idx_c) { const NDArray idx_c) {
SWITCH_BITS(bits, DType, { auto device = runtime::DeviceAPI::Get(A->ctx);
auto device = runtime::DeviceAPI::Get(A->ctx); cudaStream_t stream = runtime::getCurrentCUDAStream();
cudaStream_t stream = runtime::getCurrentCUDAStream(); const IdType *idx_c_data = idx_c.Ptr<IdType>();
const IdType *idx_c_data = idx_c.Ptr<IdType>(); int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2]; // cols of B
int64_t out_len = (B->ndim == 2)? B->shape[1] : B->shape[2]; // cols of B int64_t in_len = A->shape[1]; // cols of A
int64_t in_len = A->shape[1]; // cols of A int64_t tot_num_rows = A->shape[0];
int64_t tot_num_rows = A->shape[0]; const int ntx = 128;
const int ntx = 128; const int warp_size = 32;
const int warp_size = 32; const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx);
const int nbx = ((tot_num_rows * warp_size + ntx - 1) / ntx); const dim3 nblks(nbx);
const dim3 nblks(nbx); const dim3 nthrs(ntx);
const dim3 nthrs(ntx); if (B->ndim == 3) {
if (B->ndim == 3) { CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>),
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel<IdType, DType>), nblks, nthrs, 0, stream,
nblks, nthrs, 0, stream, A.Ptr<DType>(),
A.Ptr<DType>(), B.Ptr<DType>(),
B.Ptr<DType>(), C.Ptr<DType>(),
C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_a.Ptr<IdType>(), idx_b.Ptr<IdType>(),
idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(),
idx_c.Ptr<IdType>(), tot_num_rows, in_len, out_len);
tot_num_rows, in_len, out_len); } else {
} else { // Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i]
// Custom kernel for W_grad[idx_c[i]] = H^T[i] * C.grad[i] // This kernel accesses rows of A in a transposed way w/o explicitly converting A
// This kernel accesses rows of A in a transposed way w/o explicitly converting A CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>),
CUDA_KERNEL_CALL((cuda::GatherMMScatterKernel2<IdType, DType>), nblks, nthrs, 0, stream,
nblks, nthrs, 0, stream, A.Ptr<DType>(),
A.Ptr<DType>(), B.Ptr<DType>(),
B.Ptr<DType>(), C.Ptr<DType>(),
C.Ptr<DType>(), idx_a.Ptr<IdType>(),
idx_a.Ptr<IdType>(), idx_b.Ptr<IdType>(),
idx_b.Ptr<IdType>(), idx_c.Ptr<IdType>(),
idx_c.Ptr<IdType>(), tot_num_rows, in_len, out_len);
tot_num_rows, in_len, out_len); }
}
});
} }
template void GatherMM<kDGLCUDA, int32_t, __half>(
template void GatherMM<kDGLCUDA, int32_t, 16>( const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, 16>( #if BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, 32>( template void GatherMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, 32>( #endif // BF16_ENABLED
template void GatherMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int32_t, 64>( template void GatherMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, 64>( template void GatherMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b); const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDGLCUDA, int32_t, 16>( template void GatherMMScatter<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
#if BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, 16>( template void GatherMMScatter<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void GatherMMScatter<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, 32>( template void GatherMMScatter<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int32_t, 64>( template void GatherMMScatter<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCUDA, int64_t, 64>( template void GatherMMScatter<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c); const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCUDA, int32_t, 16>( template void SegmentMM<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, 16>( template void SegmentMM<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, 32>( template void SegmentMM<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SegmentMM<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, 64>( template void SegmentMM<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans); const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 16>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, __half>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, __half>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
#if BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, __nv_bfloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 16>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, __nv_bfloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void SegmentMMBackwardB<kDGLCUDA, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 32>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 64>( template void SegmentMMBackwardB<kDGLCUDA, int32_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 64>( template void SegmentMMBackwardB<kDGLCUDA, int64_t, double>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen); const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten } // namespace aten
...
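cuBLAS has no dedicated bf16 GEMM entry point, so the __nv_bfloat16 specialization above converts alpha/beta to float and calls cublasGemmEx with CUDA_R_16BF inputs and CUBLAS_COMPUTE_32F accumulation, keeping the accumulator in fp32. A minimal standalone sketch of that call pattern (assumes cuBLAS from CUDA 11+; the helper name is illustrative, not DGL API):

#include <cublas_v2.h>
#include <cuda_bf16.h>

// Illustrative helper (not DGL code): C = A * B with bf16 storage and fp32
// accumulation. cuBLAS expects column-major matrices, which is why the
// SegmentMM code above passes B before A when working on row-major data.
cublasStatus_t Bf16GemmExample(cublasHandle_t handle, int m, int n, int k,
                               const __nv_bfloat16* A, int lda,
                               const __nv_bfloat16* B, int ldb,
                               __nv_bfloat16* C, int ldc) {
  float alpha = 1.0f, beta = 0.0f;  // scalars are fp32 when computing in fp32
  return cublasGemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                      &alpha, A, CUDA_R_16BF, lda,
                      B, CUDA_R_16BF, ldb,
                      &beta, C, CUDA_R_16BF, ldc,
                      CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}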
@@ -13,7 +13,7 @@ namespace aten {
/*!
 * \brief CUDA implementation of g-SDDMM on Csr format.
 */
-template <int XPU, typename IdType, int bits>
+template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
              const BcastOff& bcast,
              const CSRMatrix& csr,
@@ -22,11 +22,9 @@ void SDDMMCsr(const std::string& op,
              NDArray out,
              int lhs_target,
              int rhs_target) {
-  SWITCH_BITS(bits, DType, {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
    });
  });
-  });
}

@@ -35,7 +33,7 @@ void SDDMMCsr(const std::string& op,
/*!
 * \brief CUDA implementation of g-SDDMM on Coo format.
 */
-template <int XPU, typename IdType, int bits>
+template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
              const BcastOff& bcast,
              const COOMatrix& coo,
@@ -44,62 +42,79 @@ void SDDMMCoo(const std::string& op,
              NDArray out,
              int lhs_target,
              int rhs_target) {
-  SWITCH_BITS(bits, DType, {
  SWITCH_OP(op, Op, {
    SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
      cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
    });
  });
-  });
}
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
template void SDDMMCsr<kDGLCUDA, int32_t, 16>( const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 16>( template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 32>( template void SDDMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, 64>( template void SDDMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 64>( template void SDDMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr, const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 16>( template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 16>( template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 32>( template void SDDMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 64>( template void SDDMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 64>( template void SDDMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo, const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target); int lhs_target, int rhs_target);
...
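A recurring change in this commit is dropping the bit-width dispatch (SWITCH_BITS(bits, DType, ...)) in favour of instantiating the kernels on concrete element types (__half, __nv_bfloat16, float, double), since a 16-bit width alone can no longer distinguish fp16 from bf16. Below is a hypothetical sketch of what a dtype-based switch at the call site could look like; the descriptor and macro names are made up for illustration and are not DGL's actual dispatch API (assumes CUDA 11+):

#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Hypothetical runtime dtype descriptor; `code` distinguishes fp16 from bf16.
enum ExampleTypeCode { kExampleFloat = 0, kExampleBfloat = 1 };
struct ExampleDType { int bits; ExampleTypeCode code; };

// Pick a concrete DType from the descriptor and run the body with it bound.
#define EXAMPLE_FLOAT_TYPE_SWITCH(dtype, DType, ...)            \
  do {                                                          \
    if ((dtype).bits == 16 && (dtype).code == kExampleBfloat) { \
      typedef __nv_bfloat16 DType; { __VA_ARGS__ }              \
    } else if ((dtype).bits == 16) {                            \
      typedef __half DType; { __VA_ARGS__ }                     \
    } else if ((dtype).bits == 32) {                            \
      typedef float DType; { __VA_ARGS__ }                      \
    } else {                                                    \
      typedef double DType; { __VA_ARGS__ }                     \
    }                                                           \
  } while (0)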
@@ -11,6 +11,7 @@
#include "atomic.cuh"
#include "functor.cuh"
#include "fp16.cuh"
+#include "bf16.cuh"
#include "./utils.h"
#include "./functor.cuh"
#include "../selector.h"
...
@@ -13,7 +13,7 @@ namespace aten {
* \brief CUDA implementation of g-SDDMM on heterograph using * \brief CUDA implementation of g-SDDMM on heterograph using
Csr format. Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
@@ -24,60 +24,73 @@ void SDDMMCooHetero(const std::string& op,
int rhs_target, int rhs_target,
const std::vector<dgl_type_t>& lhs_eid, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) { const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { /* Call SDDMM CUDA kernel for each relation type sequentially */
/* Call SDDMM CUDA kernel for each relation type sequentially */ for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) { COOMatrix coo = vec_coo[etype];
COOMatrix coo = vec_coo[etype]; NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray lhs = vec_lhs[lhs_eid[etype]]; NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]]; NDArray out = vec_out[etype];
NDArray out = vec_out[etype]; cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>( bcast, coo, lhs, rhs, out);
bcast, coo, lhs, rhs, out); }
}
});
}); });
}); });
} }
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, 16>( const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 16>( template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 32>( template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, 64>( template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 64>( template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
...
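In the heterograph variants above, each relation type etype reads its operands through lhs_eid[etype] and rhs_eid[etype] rather than through etype itself, because several relation types can share the same node-type feature tensor while each writes its own output. A small host-side sketch of that bookkeeping (illustrative values only, not DGL code):

#include <cstdio>
#include <vector>

int main() {
  // Two relation types that both read lhs features from tensor 0,
  // but rhs features from tensors 0 and 1 respectively.
  std::vector<int> lhs_eid = {0, 0};
  std::vector<int> rhs_eid = {0, 1};
  for (int etype = 0; etype < static_cast<int>(lhs_eid.size()); ++etype) {
    std::printf("etype %d: lhs = vec_lhs[%d], rhs = vec_rhs[%d], out = vec_out[%d]\n",
                etype, lhs_eid[etype], rhs_eid[etype], etype);
  }
  return 0;
}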
@@ -13,7 +13,7 @@ namespace aten {
* \brief CUDA implementation of g-SDDMM on heterograph using * \brief CUDA implementation of g-SDDMM on heterograph using
Csr format. Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
@@ -24,59 +24,73 @@ void SDDMMCsrHetero(const std::string& op,
int rhs_target, int rhs_target,
const std::vector<dgl_type_t>& lhs_eid, const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) { const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, { /* Call SDDMM CUDA kernel for each relation type sequentially */
/* Call SDDMM CUDA kernel for each relation type sequentially */ for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) { CSRMatrix csr = vec_csr[etype];
CSRMatrix csr = vec_csr[etype]; NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray lhs = vec_lhs[lhs_eid[etype]]; NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]]; NDArray out = vec_out[etype];
NDArray out = vec_out[etype]; cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>( bcast, csr, lhs, rhs, out);
bcast, csr, lhs, rhs, out); }
}
});
}); });
}); });
} }
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 16>( template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 16>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 32>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target, std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid); const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 64>( template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs, const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
...
@@ -17,169 +17,206 @@ using namespace cuda;
namespace aten { namespace aten {
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op, void SegmentReduce(const std::string& op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg) { NDArray arg) {
SWITCH_BITS(bits, DType, { if (op == "sum") {
if (op == "sum") { cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>( feat, offsets, out, arg);
feat, offsets, out, arg); } else if (op == "max") {
} else if (op == "max") { cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>( feat, offsets, out, arg);
feat, offsets, out, arg); } else if (op == "min") {
} else if (op == "min") { cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>( feat, offsets, out, arg);
feat, offsets, out, arg); } else {
} else { LOG(FATAL) << "Not implemented";
LOG(FATAL) << "Not implemented"; }
}
});
} }
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat,
NDArray idx, NDArray idx,
NDArray out) { NDArray out) {
SWITCH_BITS(bits, DType, { cuda::ScatterAdd<IdType, DType>(feat, idx, out);
cuda::ScatterAdd<IdType, DType>(feat, idx, out);
});
} }
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) { std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, { cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
} }
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, void BackwardSegmentCmp(NDArray feat,
NDArray arg, NDArray arg,
NDArray out) { NDArray out) {
SWITCH_BITS(bits, DType, { cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
});
} }
template void SegmentReduce<kDGLCUDA, int32_t, 16>( template void SegmentReduce<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 16>( template void SegmentReduce<kDGLCUDA, int64_t, __half>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 32>( template void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 64>( template void SegmentReduce<kDGLCUDA, int32_t, double>(
const std::string &op, const std::string &op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
NDArray out, NDArray out,
NDArray arg); NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, 16>( template void SegmentReduce<kDGLCUDA, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, __half>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 16>( template void ScatterAdd<kDGLCUDA, int64_t, __half>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 32>( template void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 64>( template void ScatterAdd<kDGLCUDA, int64_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>(
NDArray feat, NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 16>( template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __half>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#if BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 16>( template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 32>( template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, float>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 64>( template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, double>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 64>( template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
const HeteroGraphPtr& g, const std::string& op, const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx, const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out); const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 16>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
#if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 16>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 32>( #endif // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 32>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 64>( template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 64>( template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
NDArray feat, NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
......
...@@ -15,30 +15,12 @@ using namespace cuda; ...@@ -15,30 +15,12 @@ using namespace cuda;
namespace aten { namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value)
if (bits > 16)
return true;
return false;
#else
if (bits == 16)
return false; // cusparse's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
/*! /*!
* \brief CUDA implementation of g-SpMM on Csr format. * \brief CUDA implementation of g-SpMM on Csr format.
* \note use cusparse if the reduce operator is `sum` and there is * \note use cusparse if the reduce operator is `sum` and there is
* no broadcast, use dgl's kernel in other cases. * no broadcast, use dgl's kernel in other cases.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const CSRMatrix& csr, const CSRMatrix& csr,
...@@ -51,58 +33,46 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -51,58 +33,46 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (reduce == "sum") { if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols); bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) { if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse // cusparse
int64_t x_length = 1; int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i]; x_length *= ufeat->shape[i];
SWITCH_BITS(bits, DType, { CusparseCsrmm2<DType, IdType>(
CusparseCsrmm2<DType, IdType>( ufeat->ctx, csr,
ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
static_cast<DType*>(ufeat->data), nullptr,
nullptr, static_cast<DType*>(out->data),
static_cast<DType*>(out->data), x_length);
x_length); } else if (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(more_nnz)) {
});
} else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(more_nnz)) {
// cusparse // cusparse
int64_t x_length = 1; int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i) for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i]; x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) { if (!IsNullArray(csr.data)) {
SWITCH_BITS(bits, DType, { efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
});
} }
SWITCH_BITS(bits, DType, { CusparseCsrmm2<DType, IdType>(
CusparseCsrmm2<DType, IdType>( ufeat->ctx, csr,
ufeat->ctx, csr, static_cast<DType*>(ufeat->data),
static_cast<DType*>(ufeat->data), static_cast<DType*>(efeat->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(out->data),
static_cast<DType*>(out->data), x_length);
x_length);
});
} else { // general kernel } else { // general kernel
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >( bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
});
}); });
} }
} else if (reduce == "max") { } else if (reduce == "max") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >( bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else if (reduce == "min") { } else if (reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >( bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else { } else {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
...@@ -113,7 +83,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -113,7 +83,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
/*! /*!
* \brief CUDA implementation of g-SpMM on Coo format. * \brief CUDA implementation of g-SpMM on Coo format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const COOMatrix& coo, const COOMatrix& coo,
...@@ -122,82 +92,94 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -122,82 +92,94 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
NDArray out, NDArray out,
std::vector<NDArray> out_aux) { std::vector<NDArray> out_aux) {
if (reduce == "sum") { if (reduce == "sum") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > ( bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
});
}); });
} else if (reduce == "max") { } else if (reduce == "max") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > ( bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else if (reduce == "min") { } else if (reduce == "min") {
SWITCH_BITS(bits, DType, { SWITCH_OP(op, Op, {
SWITCH_OP(op, Op, { cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > (
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > ( bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
}); });
} else { } else {
LOG(FATAL) << "Not implemented"; LOG(FATAL) << "Not implemented";
} }
} }
template void SpMMCsr<kDGLCUDA, int32_t, 16>( template void SpMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 16>( template void SpMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 32>( template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 64>( template void SpMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, __half>(
template void SpMMCoo<kDGLCUDA, int32_t, 16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 16>( template void SpMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 32>( template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 64>( template void SpMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo, const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux); NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <limits> #include <limits>
#include "macro.cuh" #include "macro.cuh"
#include "fp16.cuh" #include "fp16.cuh"
#include "bf16.cuh"
#include "atomic.cuh" #include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h" #include "../../runtime/cuda/cuda_common.h"
#include "./utils.h" #include "./utils.h"
...@@ -20,6 +21,24 @@ using namespace cuda; ...@@ -20,6 +21,24 @@ using namespace cuda;
namespace aten { namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <typename DType, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value &&
(std::is_same<DType, float>::value || std::is_same<DType, double>::value))
return true;
return false;
#else
if (std::is_same<DType, __half>::value || std::is_same<DType, __nv_bfloat16>::value)
return false; // cusparse's SpMM on fp16/bf16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
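// Illustrative sketch (not part of this change): with CUDA >= 11.0 the check above
// is purely type-based, e.g.
//   cusparse_available<float, int32_t>(/*more_nnz=*/false);          // -> true
//   cusparse_available<__half, int32_t>(/*more_nnz=*/false);         // -> false (fp16 path disabled)
//   cusparse_available<__nv_bfloat16, int64_t>(/*more_nnz=*/false);  // -> false (bf16 path disabled)
//   cusparse_available<double, int32_t>(/*more_nnz=*/true);          // -> false (too many non-zeros)
// When it returns false, the callers fall back to DGL's own SpMM kernels.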
namespace { namespace {
/*! \brief Call cuBLAS geam API for transpose operation for float and double. */ /*! \brief Call cuBLAS geam API for transpose operation for float and double. */
...@@ -33,7 +52,6 @@ cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa, ...@@ -33,7 +52,6 @@ cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
return CUBLAS_STATUS_EXECUTION_FAILED; return CUBLAS_STATUS_EXECUTION_FAILED;
} }
#ifdef USE_FP16
template <> template <>
cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n, cublasOperation_t transb, int m, int n,
...@@ -45,7 +63,20 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa, ...@@ -45,7 +63,20 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
LOG(FATAL) << "Xgeam does not support dtype half (FP16)"; LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
return CUBLAS_STATUS_EXECUTION_FAILED; return CUBLAS_STATUS_EXECUTION_FAILED;
} }
#endif
#if BF16_ENABLED
template <>
cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb,
__nv_bfloat16* C, int ldc) {
// TODO(ndickson): There is no cublasHgeam, so a different
// implementation would be required.
LOG(FATAL) << "Xgeam does not support dtype bfloat16 (BF16)";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
#endif // BF16_ENABLED
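// Illustrative sketch only (not part of this change): one way to resolve the TODO
// above is to widen the bf16 operands to fp32, call cublasSgeam, and narrow the
// result back to bf16. The kernel and wrapper names below are hypothetical; buffer
// allocation and error handling are omitted.
#if BF16_ENABLED
__global__ void _Bf16ToFloatKernelSketch(const __nv_bfloat16* in, float* out, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) out[i] = __bfloat162float(in[i]);
}

__global__ void _FloatToBf16KernelSketch(const float* in, __nv_bfloat16* out, int64_t n) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < n) out[i] = __float2bfloat16(in[i]);
}

// a_fp32/b_fp32/c_fp32 are temporary fp32 copies of A, B and C produced by the
// kernels above; the geam itself then runs entirely in fp32.
inline cublasStatus_t XgeamViaFloatSketch(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, float alpha, const float* a_fp32, int lda,
    float beta, const float* b_fp32, int ldb, float* c_fp32, int ldc) {
  return cublasSgeam(handle, transa, transb, m, n,
                     &alpha, a_fp32, lda, &beta, b_fp32, ldb, c_fp32, ldc);
}
#endif  // BF16_ENABLED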
template <> template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa, cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
...@@ -131,6 +162,21 @@ void _Transpose<half>(const half* in, half* out, ...@@ -131,6 +162,21 @@ void _Transpose<half>(const half* in, half* out,
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row); CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
} }
#if BF16_ENABLED
/*
 * \brief Transpose the input matrix for data type bfloat16.
 * \note cuBLAS has no geam API for the bf16 data type; fall back to our kernel.
*/
template <>
void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out,
int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
#endif // BF16_ENABLED
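// Illustrative sketch (not part of this change): _TransposeKernel is defined earlier
// in this file; a minimal kernel consistent with the launch configuration above
// (one block per input column, threads striding over the rows) could look like the
// following. The name and the row-major layout assumption are hypothetical.
template <typename DType>
__global__ void _TransposeKernelSketch(const DType* in, DType* out,
                                       int col, int row) {
  const int c = blockIdx.x;                       // one block per input column
  for (int r = threadIdx.x; r < row; r += blockDim.x)
    out[c * row + r] = in[r * col + c];           // out becomes a col x row matrix
}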
/* /*
* \brief * \brief
*/ */
......
...@@ -15,30 +15,12 @@ using namespace cuda; ...@@ -15,30 +15,12 @@ using namespace cuda;
namespace aten { namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value)
if (bits > 16)
return true;
return false;
#else
if (bits == 16)
return false; // cusparse's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
/*! /*!
* \brief CUDA implementation of g-SpMM on Csr format. * \brief CUDA implementation of g-SpMM on Csr format.
* \note use cusparse if the reduce operator is `sum` and there is * \note use cusparse if the reduce operator is `sum` and there is
* no broadcast, use dgl's kernel in other cases. * no broadcast, use dgl's kernel in other cases.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
...@@ -51,192 +33,202 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -51,192 +33,202 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0]; bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs"; bool use_efeat = op != "copy_lhs";
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx); auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
SWITCH_BITS(bits, DType, { std::vector<DType*> trans_out((*vec_out).size(), NULL);
std::vector<DType*> trans_out((*vec_out).size(), NULL);
bool use_legacy_cusparsemm = bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) && (reduce == "sum") && (CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false". // legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) || ((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false))); (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(false)));
// Create temporary output buffer to store non-transposed output // Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) { if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) { for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0]; const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1]; const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue; if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx, DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType))); m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType))); CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out; trans_out[ntype] = out;
}
} }
// Check shape of ufeat for all relation type and compute feature size }
int64_t x_length = 1; // Check shape of ufeat for all relation type and compute feature size
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) { int64_t x_length = 1;
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]]; for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]]; NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes"; NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
for (int i = 1; i < ufeat->ndim; ++i) { CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
if (ufeat->shape[i] != next_ufeat->shape[i]) { for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1) if (ufeat->shape[i] != next_ufeat->shape[i]) {
LOG(FATAL) << if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
"Homogenized message passing on heterogeneous graphs does not support " << LOG(FATAL) <<
"automatic broadcasting. Please manually broadcast it before calling " << "Homogenized message passing on heterogeneous graphs does not support " <<
"message passing functions."; "automatic broadcasting. Please manually broadcast it before calling " <<
else "message passing functions.";
LOG(FATAL) << "Input features have different shapes."; else
return; LOG(FATAL) << "Input features have different shapes.";
} return;
if (etype == 0)
x_length *= ufeat->shape[i];
} }
if (etype == 0)
x_length *= ufeat->shape[i];
} }
// TODO(Israt): Can python do the following initializations while creating the tensors? }
if (reduce == "max" || reduce == "min") { // TODO(Israt): Can python do the following initializations while creating the tensors?
const int64_t dim = bcast.out_len; if (reduce == "max" || reduce == "min") {
std::vector<bool> updated((*vec_out).size(), false); const int64_t dim = bcast.out_len;
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) { std::vector<bool> updated((*vec_out).size(), false);
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>(); for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
if (reduce == "max") DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero()); if (reduce == "max")
else // min _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero()); else // min
const dgl_type_t dst_id = out_ntids[etype]; _Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
if (!updated[dst_id]) { const dgl_type_t dst_id = out_ntids[etype];
updated[dst_id] = true; if (!updated[dst_id]) {
if (op == "copy_lhs") { updated[dst_id] = true;
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>(); if (op == "copy_lhs") {
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1)); IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
} _Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
if (op == "copy_rhs") { }
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>(); if (op == "copy_rhs") {
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1)); IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
} _Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
} }
} }
} }
}
cudaStream_t stream = runtime::getCurrentCUDAStream(); cudaStream_t stream = runtime::getCurrentCUDAStream();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) { for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype]; const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype]; const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype]; CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") { if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols); bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */ /* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) { // cusparse if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */ /* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] : DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>((*vec_out)[dst_id]->data); static_cast<DType*>((*vec_out)[dst_id]->data);
CusparseCsrmm2Hetero<DType, IdType>( CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data), static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr, nullptr,
out, out,
x_length, stream); x_length, stream);
} else if (op == "mul" && is_scalar_efeat && } else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>(more_nnz)) { // cusparse cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype]; NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data)) if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data); efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>( CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr, csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data), static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data), static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11 // TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data), static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, stream); x_length, stream);
} else { // general kernel } else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ? NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id]; NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ? NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype]; NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, { cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >( bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray()); (*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
}); src_id, etype);
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
}); });
} else { } else if (reduce == "min") {
LOG(FATAL) << "Not implemented"; SWITCH_OP(op, Op, {
} NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else {
LOG(FATAL) << "Not implemented";
} }
}
if (use_legacy_cusparsemm) { if (use_legacy_cusparsemm) {
// transpose output // transpose output
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) { for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0]; const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1]; const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue; if (m == 0) continue;
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data); DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m); _Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]); device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
} }
}); }
} }
template void SpMMCsrHetero<kDGLCUDA, int32_t, 16>( template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 16>( template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, 32>( #if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 32>( template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, 64>( #endif // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 64>( template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce, const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr, const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat, const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux, std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids); const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
} // namespace aten } // namespace aten
} // namespace dgl } // namespace dgl
...@@ -22,36 +22,6 @@ namespace cuda { ...@@ -22,36 +22,6 @@ namespace cuda {
// The max number of threads per block // The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256 #define CUDA_MAX_NUM_THREADS 256
#ifdef USE_FP16
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16) { \
typedef half DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
#else // USE_FP16
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
#endif // USE_FP16
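// Illustrative sketch (not part of this change): the bits-based SWITCH_BITS macro
// removed above is superseded by the type-based ATEN_FLOAT_TYPE_SWITCH_16BITS macro
// defined in the dtype headers and used in the dispatch code later in this diff.
// A minimal sketch of such a switch, assuming a DLPack-style dtype with `code`/`bits`
// fields; the constant and macro names here are assumptions, and the real macro also
// takes the target device and guards the bf16 branch behind BF16_ENABLED.
#define FLOAT_TYPE_SWITCH_SKETCH(val, DType, ...)                \
  do {                                                           \
    if ((val).code == kDGLFloat && (val).bits == 16) {           \
      typedef __half DType;                                      \
      { __VA_ARGS__ }                                            \
    } else if ((val).code == kDGLBfloat && (val).bits == 16) {   \
      typedef __nv_bfloat16 DType;                               \
      { __VA_ARGS__ }                                            \
    } else if ((val).code == kDGLFloat && (val).bits == 32) {    \
      typedef float DType;                                       \
      { __VA_ARGS__ }                                            \
    } else if ((val).code == kDGLFloat && (val).bits == 64) {    \
      typedef double DType;                                      \
      { __VA_ARGS__ }                                            \
    } else {                                                     \
      LOG(FATAL) << "Unsupported floating-point dtype";          \
    }                                                            \
  } while (0)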
/*! \brief Calculate the number of threads needed given the dimension length. /*! \brief Calculate the number of threads needed given the dimension length.
* *
...@@ -185,6 +155,40 @@ __global__ void _LinearSearchKernel( ...@@ -185,6 +155,40 @@ __global__ void _LinearSearchKernel(
} }
} }
#if BF16_ENABLED
/*!
* \brief Specialization for bf16 because conversion from long long to bfloat16
* doesn't exist before SM80.
*/
template <typename IdType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride, int64_t length,
const __nv_bfloat16* weights, __nv_bfloat16 filler, __nv_bfloat16* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
int rpos = tx * row_stride, cpos = tx * col_stride;
IdType v = -1;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
v = data ? data[i] : i;
break;
}
}
if (v == -1) {
out[tx] = filler;
} else {
// Since the result is stored in bf16 anyway, converting the index through float first loses no extra precision.
out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));
}
tx += stride_x;
}
}
#endif // BF16_ENABLED
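// Illustrative note (not part of this change): the generic kernel writes the matched
// index directly, roughly `out[tx] = static_cast<DType>(v);`. For __nv_bfloat16 that
// would need an integer-to-bf16 conversion that cuda_bf16.h only provides on SM80 and
// newer, so the specialization above widens through float instead:
//   out[tx] = __nv_bfloat16(static_cast<float>(v));  // available on all supported archs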
template <typename DType> template <typename DType>
inline DType GetCUDAScalar( inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api, runtime::DeviceAPI* device_api,
......
...@@ -36,13 +36,13 @@ void SpMM(const std::string& op, const std::string& reduce, ...@@ -36,13 +36,13 @@ void SpMM(const std::string& op, const std::string& reduce,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
SpMMCsr<XPU, IdType, bits>( SpMMCsr<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCSCMatrix(0), op, reduce, bcast, graph->GetCSCMatrix(0),
ufeat, efeat, out, out_aux); ufeat, efeat, out, out_aux);
} else if (format == SparseFormat::kCOO) { } else if (format == SparseFormat::kCOO) {
SpMMCoo<XPU, IdType, bits>( SpMMCoo<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCOOMatrix(0), op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux); ufeat, efeat, out, out_aux);
} else { } else {
...@@ -76,8 +76,8 @@ void SegmentMM(const NDArray A, ...@@ -76,8 +76,8 @@ void SegmentMM(const NDArray A,
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device"; CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, { ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
SegmentMM<XPU, IdType, bits>(A, B, C, seglen_A, A_trans, B_trans); SegmentMM<XPU, IdType, Dtype>(A, B, C, seglen_A, A_trans, B_trans);
}); });
}); });
}); });
...@@ -94,8 +94,8 @@ void SegmentMMBackwardB(const NDArray A, ...@@ -94,8 +94,8 @@ void SegmentMMBackwardB(const NDArray A,
<< "segment_mm expects seglen to be on CPU."; << "segment_mm expects seglen to be on CPU.";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, { ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
SegmentMMBackwardB<XPU, IdType, bits>(A, dC, dB, seglen); SegmentMMBackwardB<XPU, IdType, Dtype>(A, dC, dB, seglen);
}); });
}); });
}); });
...@@ -131,8 +131,8 @@ void GatherMM(const NDArray A, ...@@ -131,8 +131,8 @@ void GatherMM(const NDArray A,
const auto idtype = aten::IsNullArray(idx_a)? idx_b->dtype : idx_a->dtype; const auto idtype = aten::IsNullArray(idx_a)? idx_b->dtype : idx_a->dtype;
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, { ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
GatherMM<XPU, IdType, bits>(A, B, C, idx_a, idx_b); GatherMM<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b);
}); });
}); });
}); });
...@@ -171,8 +171,8 @@ void GatherMMScatter(const NDArray A, ...@@ -171,8 +171,8 @@ void GatherMMScatter(const NDArray A,
} }
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", { ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, { ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
GatherMMScatter<XPU, IdType, bits>(A, B, C, idx_a, idx_b, idx_c); GatherMMScatter<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b, idx_c);
}); });
}); });
}); });
...@@ -210,9 +210,9 @@ void SpMMHetero(const std::string& op, const std::string& reduce, ...@@ -210,9 +210,9 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH((*out)[out_eid[0]]->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS((*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) { if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, bits>( SpMMCsrHetero<XPU, IdType, Dtype>(
op, reduce, bcast, vec_graph, op, reduce, bcast, vec_graph,
ufeat_vec, efeat_vec, out, out_aux, ufeat_vec, efeat_vec, out, out_aux,
ufeat_eid, out_eid); ufeat_eid, out_eid);
...@@ -241,13 +241,13 @@ void SDDMM(const std::string& op, ...@@ -241,13 +241,13 @@ void SDDMM(const std::string& op,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) { if (format == SparseFormat::kCSR) {
SDDMMCsr<XPU, IdType, bits>( SDDMMCsr<XPU, IdType, Dtype>(
op, bcast, graph->GetCSRMatrix(0), op, bcast, graph->GetCSRMatrix(0),
lhs, rhs, out, lhs_target, rhs_target); lhs, rhs, out, lhs_target, rhs_target);
} else if (format == SparseFormat::kCOO) { } else if (format == SparseFormat::kCOO) {
SDDMMCoo<XPU, IdType, bits>( SDDMMCoo<XPU, IdType, Dtype>(
op, bcast, graph->GetCOOMatrix(0), op, bcast, graph->GetCOOMatrix(0),
lhs, rhs, out, lhs_target, rhs_target); lhs, rhs, out, lhs_target, rhs_target);
} else { } else {
...@@ -294,13 +294,13 @@ void SDDMMHetero(const std::string& op, ...@@ -294,13 +294,13 @@ void SDDMMHetero(const std::string& op,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", { ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) { if (format == SparseFormat::kCSR) {
std::vector<CSRMatrix> vec_csr; std::vector<CSRMatrix> vec_csr;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype)); vec_csr.push_back(graph->GetCSRMatrix(etype));
} }
SDDMMCsrHetero<XPU, IdType, bits>( SDDMMCsrHetero<XPU, IdType, Dtype>(
op, bcast, vec_csr, op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target, lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid); lhs_eid, rhs_eid);
...@@ -309,7 +309,7 @@ void SDDMMHetero(const std::string& op, ...@@ -309,7 +309,7 @@ void SDDMMHetero(const std::string& op,
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) { for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_coo.push_back(graph->GetCOOMatrix(etype)); vec_coo.push_back(graph->GetCOOMatrix(etype));
} }
SDDMMCooHetero<XPU, IdType, bits>( SDDMMCooHetero<XPU, IdType, Dtype>(
op, bcast, vec_coo, op, bcast, vec_coo,
lhs, rhs, out, lhs_target, rhs_target, lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid); lhs_eid, rhs_eid);
...@@ -333,8 +333,8 @@ void Edge_softmax_forward(const std::string& op, ...@@ -333,8 +333,8 @@ void Edge_softmax_forward(const std::string& op,
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", { ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data", {
Edge_softmax_csr_forward<XPU, IdType, bits>( Edge_softmax_csr_forward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out); op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
}); });
}); });
...@@ -354,8 +354,8 @@ void Edge_softmax_backward(const std::string& op, ...@@ -354,8 +354,8 @@ void Edge_softmax_backward(const std::string& op,
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", { ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, { ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data_back", { ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data_back", {
Edge_softmax_csr_backward<XPU, IdType, bits>( Edge_softmax_csr_backward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), out, sds, back_out); op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
}); });
}); });
...@@ -380,8 +380,8 @@ void SegmentReduceDispatch(const std::string& op, ...@@ -380,8 +380,8 @@ void SegmentReduceDispatch(const std::string& op,
NDArray arg) { NDArray arg) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", { ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, { ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
SegmentReduce<XPU, IdType, bits>(op, feat, offsets, out, arg); SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
}); });
}); });
}); });
...@@ -391,8 +391,8 @@ void SegmentReduceDispatch(const std::string& op, ...@@ -391,8 +391,8 @@ void SegmentReduceDispatch(const std::string& op,
void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) { void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "ScatterAdd", { ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, { ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
ScatterAdd<XPU, IdType, bits>(feat, idx, out); ScatterAdd<XPU, IdType, Dtype>(feat, idx, out);
}); });
}); });
}); });
...@@ -409,8 +409,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph, ...@@ -409,8 +409,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
auto src_id = pair.first; auto src_id = pair.first;
ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", { ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, { ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat[src_id]->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(feat[src_id]->dtype, Dtype, XPU, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, bits>(graph, op, feat, idx, idx_etype, out); UpdateGradMinMax_hetero<XPU, IdType, Dtype>(graph, op, feat, idx, idx_etype, out);
}); });
}); });
}); });
...@@ -420,8 +420,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph, ...@@ -420,8 +420,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) { void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", { ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, { ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", { ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
BackwardSegmentCmp<XPU, IdType, bits>(feat, arg, out); BackwardSegmentCmp<XPU, IdType, Dtype>(feat, arg, out);
}); });
}); });
}); });
......
...@@ -20,7 +20,7 @@ namespace aten { ...@@ -20,7 +20,7 @@ namespace aten {
/*! /*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format. * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce, void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const aten::CSRMatrix& csr, const aten::CSRMatrix& csr,
...@@ -33,7 +33,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce, ...@@ -33,7 +33,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
with heterograph support. with heterograph support.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce, void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<CSRMatrix>& csr,
...@@ -46,7 +46,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce, ...@@ -46,7 +46,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
/*! /*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format. * \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce, void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast, const BcastOff& bcast,
const aten::COOMatrix& coo, const aten::COOMatrix& coo,
...@@ -58,7 +58,7 @@ void SpMMCoo(const std::string& op, const std::string& reduce, ...@@ -58,7 +58,7 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
/*! /*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format. * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op, void SDDMMCsr(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const aten::CSRMatrix& csr, const aten::CSRMatrix& csr,
...@@ -71,7 +71,7 @@ void SDDMMCsr(const std::string& op, ...@@ -71,7 +71,7 @@ void SDDMMCsr(const std::string& op,
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr
format with heterograph support. format with heterograph support.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op, void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<CSRMatrix>& vec_csr,
...@@ -86,7 +86,7 @@ void SDDMMCsrHetero(const std::string& op, ...@@ -86,7 +86,7 @@ void SDDMMCsrHetero(const std::string& op,
/*! /*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format. * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op, void SDDMMCoo(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const aten::COOMatrix& coo, const aten::COOMatrix& coo,
...@@ -100,7 +100,7 @@ void SDDMMCoo(const std::string& op, ...@@ -100,7 +100,7 @@ void SDDMMCoo(const std::string& op,
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo * \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo
format with heterograph support. format with heterograph support.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op, void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<COOMatrix>& vec_coo,
...@@ -115,7 +115,7 @@ void SDDMMCooHetero(const std::string& op, ...@@ -115,7 +115,7 @@ void SDDMMCooHetero(const std::string& op,
/*! /*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A, void GatherMM(const NDArray A,
const NDArray B, const NDArray B,
NDArray out, NDArray out,
...@@ -125,7 +125,7 @@ void GatherMM(const NDArray A, ...@@ -125,7 +125,7 @@ void GatherMM(const NDArray A,
/*! /*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types. * \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A, void GatherMMScatter(const NDArray A,
const NDArray B, const NDArray B,
NDArray out, NDArray out,
...@@ -136,14 +136,14 @@ void GatherMMScatter(const NDArray A, ...@@ -136,14 +136,14 @@ void GatherMMScatter(const NDArray A,
/*! /*!
* \brief Generalized segmented dense Matrix-Matrix Multiplication. * \brief Generalized segmented dense Matrix-Matrix Multiplication.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A, void SegmentMM(const NDArray A,
const NDArray B, const NDArray B,
NDArray out, NDArray out,
const NDArray seglen_A, const NDArray seglen_A,
bool a_trans, bool b_trans); bool a_trans, bool b_trans);
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A, void SegmentMMBackwardB(const NDArray A,
const NDArray dC, const NDArray dC,
NDArray dB, NDArray dB,
...@@ -152,7 +152,7 @@ void SegmentMMBackwardB(const NDArray A, ...@@ -152,7 +152,7 @@ void SegmentMMBackwardB(const NDArray A,
/*! /*!
* \brief Segment reduce. * \brief Segment reduce.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op, void SegmentReduce(const std::string& op,
NDArray feat, NDArray feat,
NDArray offsets, NDArray offsets,
...@@ -162,7 +162,7 @@ void SegmentReduce(const std::string& op, ...@@ -162,7 +162,7 @@ void SegmentReduce(const std::string& op,
/*! /*!
* \brief Scatter Add on first dimension. * \brief Scatter Add on first dimension.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat, void ScatterAdd(NDArray feat,
NDArray idx, NDArray idx,
NDArray out); NDArray out);
...@@ -170,7 +170,7 @@ void ScatterAdd(NDArray feat, ...@@ -170,7 +170,7 @@ void ScatterAdd(NDArray feat,
/*! /*!
* \brief Update gradients for reduce operator max and min on first dimension. * \brief Update gradients for reduce operator max and min on first dimension.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& feat,
...@@ -181,7 +181,7 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& g, ...@@ -181,7 +181,7 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
/*! /*!
* \brief Backward function of segment cmp. * \brief Backward function of segment cmp.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat, void BackwardSegmentCmp(NDArray feat,
NDArray arg, NDArray arg,
NDArray out); NDArray out);
...@@ -223,7 +223,7 @@ std::pair<CSRMatrix, NDArray> CSRSum( ...@@ -223,7 +223,7 @@ std::pair<CSRMatrix, NDArray> CSRSum(
/*! /*!
* \brief Edge_softmax_csr forward function on Csr format. * \brief Edge_softmax_csr forward function on Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op, void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const aten::CSRMatrix& csr, const aten::CSRMatrix& csr,
...@@ -233,7 +233,7 @@ void Edge_softmax_csr_forward(const std::string& op, ...@@ -233,7 +233,7 @@ void Edge_softmax_csr_forward(const std::string& op,
/*! /*!
* \brief Edge_softmax_csr backward function on Csr format. * \brief Edge_softmax_csr backward function on Csr format.
*/ */
template <int XPU, typename IdType, int bits> template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op, void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast, const BcastOff& bcast,
const aten::CSRMatrix& csr, const aten::CSRMatrix& csr,
......
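The declarations above drop the int bits template parameter in favor of an explicit typename DType, so each operator (SDDMM, GatherMM, SegmentMM, SegmentReduce, edge softmax, ...) is instantiated once per element type, which is what lets __nv_bfloat16 slot in next to __half, float and double. Below is a minimal, self-contained sketch of how a runtime dtype tag can be mapped onto such a DType-templated call; the Fill kernel, the FloatCode enum and the dispatch helper are illustrative stand-ins under assumption, not DGL's actual API or switch macros.

#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11000  // mirrors the BF16_ENABLED guard
#include <cuda_bf16.h>
#endif

// Illustrative DType-templated operator standing in for the real kernels.
template <typename DType>
__global__ void Fill(DType* out, int64_t n, DType val) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = val;
}

// Hypothetical runtime tag; the real code goes through DGL's dtype switch macros.
enum class FloatCode { kHalf, kBFloat16, kFloat32, kFloat64 };

template <typename DType>
void LaunchFill(void* out, int64_t n) {
  const unsigned int nblocks = static_cast<unsigned int>((n + 255) / 256);
  Fill<DType><<<nblocks, 256>>>(
      static_cast<DType*>(out), n, static_cast<DType>(1.0f));
}

void FillDispatch(FloatCode code, void* out, int64_t n) {
  switch (code) {
    case FloatCode::kHalf:     LaunchFill<__half>(out, n); break;
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11000
    case FloatCode::kBFloat16: LaunchFill<__nv_bfloat16>(out, n); break;
#endif
    case FloatCode::kFloat32:  LaunchFill<float>(out, n); break;
    case FloatCode::kFloat64:  LaunchFill<double>(out, n); break;
    default: throw std::runtime_error("unsupported dtype");
  }
}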
...@@ -22,9 +22,12 @@ constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype; ...@@ -22,9 +22,12 @@ constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype; constexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype;
#ifdef USE_FP16 #ifdef DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<__half>::dtype; constexpr DGLDataType DGLDataTypeTraits<__half>::dtype;
#endif #if BF16_ENABLED
constexpr DGLDataType DGLDataTypeTraits<__nv_bfloat16>::dtype;
#endif // BF16_ENABLED
#endif // DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<float>::dtype; constexpr DGLDataType DGLDataTypeTraits<float>::dtype;
constexpr DGLDataType DGLDataTypeTraits<double>::dtype; constexpr DGLDataType DGLDataTypeTraits<double>::dtype;
......
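The traits definitions above register a static dtype descriptor for __half and, when the toolkit provides it, __nv_bfloat16, so NDArray dtype checks recognize both 16-bit formats in CUDA builds. A hedged sketch of what such a specialization can look like follows; the struct layout mimics the DLPack-style code/bits/lanes convention, and everything here other than the __nv_bfloat16 type itself is an illustrative stand-in rather than DGL's actual definitions.

#include <cstdint>
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11000
#include <cuda_bf16.h>

// DLPack-style dtype descriptor: a type code, a bit width and a lane count.
struct DataType {
  uint8_t code;    // 2 = float, 4 = bfloat under the DLPack convention
  uint8_t bits;
  uint16_t lanes;
};

template <typename T>
struct DataTypeTraits;  // primary template intentionally left undefined

// Specialization for the native CUDA bfloat16 type.
template <>
struct DataTypeTraits<__nv_bfloat16> {
  static constexpr DataType dtype{4 /* bfloat */, 16, 1};
};
// Out-of-line definition, matching the constexpr definitions in the diff above
// (needed for pre-C++17 ODR-use of the static member).
constexpr DataType DataTypeTraits<__nv_bfloat16>::dtype;
#endif  // CUDART_VERSION >= 11000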
from distutils.version import LooseVersion
import random import random
import unittest import unittest
import backend as F import backend as F
import networkx as nx
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
...@@ -325,13 +325,20 @@ def test_segment_reduce(reducer): ...@@ -325,13 +325,20 @@ def test_segment_reduce(reducer):
@parametrize_idtype @parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256]) @pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"dtype,tol", "dtype, tol",
[(torch.float16, 1e-2), (torch.float32, 3e-3), (torch.float64, 1e-4)], [(torch.float16, 1e-2), (torch.bfloat16, 1e-2),
(torch.float32, 3e-3), (torch.float64, 1e-4)],
) )
def test_segment_mm(idtype, feat_size, dtype, tol): def test_segment_mm(idtype, feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype == torch.float16: if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
pytest.skip( pytest.skip(
"fp16 support for CPU linalg functions has been removed in PyTorch." "Only support float32 and float64 on CPU."
)
if F._default_context_str == "gpu" \
and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
and dtype == torch.bfloat16:
pytest.skip(
"BF16 requires CUDA >= 11.0."
) )
dev = F.ctx() dev = F.ctx()
# input # input
...@@ -343,7 +350,7 @@ def test_segment_mm(idtype, feat_size, dtype, tol): ...@@ -343,7 +350,7 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
.to(dtype) .to(dtype)
) )
b.requires_grad_() b.requires_grad_()
seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]) seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype) dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
# compute # compute
c = dgl.ops.segment_mm(a, b, seglen_a) c = dgl.ops.segment_mm(a, b, seglen_a)
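The new skips encode two constraints: the CPU linear-algebra paths only run in float32/float64, and the native CUDA bfloat16 type only exists from CUDA 11.0 onward. On the C++ side that second constraint is typically expressed as a compile-time guard; the sketch below shows the kind of check the BF16_ENABLED macro used throughout this change presumably expands to (the actual definition lives in bf16.cuh and is not shown in this diff).

// Sketch: compile bf16 code paths only when the toolkit ships <cuda_bf16.h>,
// i.e. CUDA runtime 11.0 or newer. Device intrinsics that need hardware
// support (sm_80+) get additional __CUDA_ARCH__ guards inside the kernels.
#include <cuda_runtime.h>

#if defined(CUDART_VERSION) && CUDART_VERSION >= 11000
#define BF16_ENABLED 1
#include <cuda_bf16.h>
#else
#define BF16_ENABLED 0
#endif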
...@@ -371,19 +378,28 @@ def test_segment_mm(idtype, feat_size, dtype, tol): ...@@ -371,19 +378,28 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
@unittest.skipIf( @unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now" dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
) )
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256]) @pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def test_gather_mm_idx_b(idtype, feat_size): @pytest.mark.parametrize(
import torch "dtype, tol",
[(torch.float16, 1e-2), (torch.bfloat16, 2e-2),
(torch.float32, 3e-3), (torch.float64, 1e-4)]
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
pytest.skip("Only support float32 and float64 on CPU.")
if F._default_context_str == "gpu" \
and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
and dtype == torch.bfloat16:
pytest.skip("BF16 requires CUDA >= 11.0.")
dev = F.ctx() dev = F.ctx()
# input # input
a = torch.tensor(np.random.rand(100, feat_size)).to(dev) a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
a.requires_grad_() a.requires_grad_()
b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev) b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
b.requires_grad_() b.requires_grad_()
idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long() idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev) dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
# compute # compute
c = dgl.ops.gather_mm(a, b, idx_b=idx) c = dgl.ops.gather_mm(a, b, idx_b=idx)
c.backward(dc) c.backward(dc)
...@@ -397,9 +413,9 @@ def test_gather_mm_idx_b(idtype, feat_size): ...@@ -397,9 +413,9 @@ def test_gather_mm_idx_b(idtype, feat_size):
da_t = a.grad da_t = a.grad
db_t = b.grad db_t = b.grad
assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4) assert torch.allclose(c, c_t, atol=tol, rtol=tol)
assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4) assert torch.allclose(da, da_t, atol=tol, rtol=tol)
assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4) assert torch.allclose(db, db_t, atol=tol, rtol=tol)
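The assertions switch from a fixed 1e-4 tolerance to the per-dtype tol from the parametrization. bfloat16 keeps 8 significand bits (7 stored), so each value already carries a relative rounding error of up to roughly 2^-8 ≈ 0.4%, and the matmul plus its backward pass accumulates it, which is why the bf16 entries use 1e-2 and 2e-2. The small host-side round-trip below illustrates that per-value error; it needs nvcc with CUDA 11+ for cuda_bf16.h, and the printed numbers are illustrative, not taken from the test run.

#include <cuda_bf16.h>
#include <cmath>
#include <cstdio>

int main() {
  const float values[] = {0.1f, 3.14159f, 123.456f};
  for (float x : values) {
    // Round-trip float -> bf16 -> float; cuda_bf16.h exposes host versions
    // of these conversions with CUDA 11 and newer.
    const float y = __bfloat162float(__float2bfloat16_rn(x));
    std::printf("x=%g  bf16(x)=%g  rel.err=%.2e\n", x, y, std::fabs(x - y) / x);
  }
  return 0;
}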
@unittest.skipIf( @unittest.skipIf(
......
...@@ -25,7 +25,7 @@ if [[ $arch == *"x86"* ]]; then ...@@ -25,7 +25,7 @@ if [[ $arch == *"x86"* ]]; then
fi fi
if [[ $1 != "cpu" ]]; then if [[ $1 != "cpu" ]]; then
CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_FP16=ON $CMAKE_VARS" CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON $CMAKE_VARS"
fi fi
if [ -d build ]; then if [ -d build ]; then
......