Unverified Commit 96297fb8 authored by Xin Yao, committed by GitHub

[Feature] Add bfloat16 (bf16) support (#4648)

* add bf16 specializations

* remove SWITCH_BITS

* enable amp for bf16

* remove SWITCH_BITS for cpu kernels

* enable bf16 based on CUDART

* fix compiling for sm<80

* fix cpu build

* enable unit tests

* update doc

* disable test for CUDA < 11.0

* address comments

* address comments
parent 1d229194
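
The new bf16.cuh header itself does not appear in the hunks below, but given how BF16_ENABLED and __nv_bfloat16 are used throughout the change (and the "enable bf16 based on CUDART" / "fix compiling for sm<80" bullets above), here is a minimal sketch of the guard it presumably provides; the guard name and the exact version check are assumptions, not the committed code.

// Hypothetical sketch of the bf16.cuh guard; names and the exact condition are assumptions.
#ifndef EXAMPLE_DGL_ARRAY_CUDA_BF16_CUH_
#define EXAMPLE_DGL_ARRAY_CUDA_BF16_CUH_

#include <cuda_runtime.h>  // defines CUDART_VERSION

// __nv_bfloat16 and its conversion intrinsics ship with CUDA 11.0+ in <cuda_bf16.h>,
// so bf16 support is keyed off the CUDA runtime version; older toolkits build with
// bf16 compiled out, and sm<80 devices rely on the float-conversion fallbacks.
#if defined(CUDART_VERSION) && CUDART_VERSION >= 11000
#define BF16_ENABLED 1
#include <cuda_bf16.h>
#else
#define BF16_ENABLED 0
#endif

#endif  // EXAMPLE_DGL_ARRAY_CUDA_BF16_CUH_
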
......@@ -10,6 +10,7 @@
#include <limits>
#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"
namespace dgl {
namespace aten {
......@@ -108,7 +109,7 @@ struct Dot {
static constexpr bool reduce_last_dim = true;
static __device__ __forceinline__ DType Call(
const DType *lhs, const DType *rhs, int64_t len = 1) {
DType rst = static_cast<DType>(0);
DType rst = static_cast<DType>(0.0f);
for (int64_t i = 0; i < len; ++i) {
rst += lhs[i] * rhs[i];
}
......@@ -159,14 +160,21 @@ template <typename Idx,
bool atomic = false>
struct Sum: _Sum<Idx, DType, atomic> { };
#ifdef USE_FP16
template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic>: _Sum<Idx, half, atomic> {
static constexpr __host__ __device__ __forceinline__ half zero() {
return __float2half_rn(0.);
}
};
#endif // USE_FP16
#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic>: _Sum<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(0.);
}
};
#endif // BF16_ENABLED
template <typename Idx,
typename DType,
......@@ -220,7 +228,6 @@ template <typename Idx,
bool atomic = false>
struct Max : _Max<Idx, DType, atomic> { };
#ifdef USE_FP16
template <typename Idx,
bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
......@@ -228,7 +235,16 @@ struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
return __float2half_rn(-6.550400e+04f);
}
};
#endif
#if BF16_ENABLED
template <typename Idx,
bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
}
};
#endif // BF16_ENABLED
template <typename Idx,
typename DType,
......@@ -282,7 +298,6 @@ template <typename Idx,
bool atomic = false>
struct Min : _Min<Idx, DType, atomic> { };
#ifdef USE_FP16
template <typename Idx,
bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
......@@ -290,7 +305,16 @@ struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
return __float2half_rn(6.550400e+04f);
}
};
#endif // USE_FP16
#if BF16_ENABLED
template <typename Idx,
bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
}
};
#endif // BF16_ENABLED
} // namespace reduce
......
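
A note on the reducer identities added above: fp16's largest finite magnitude is 65504, so the half Max/Min specializations use -65504 and +65504, whereas bf16 shares float's 8-bit exponent and can represent infinities directly, which is why the bf16 specializations use negative and positive infinity. A small standalone host-side check of the same conversions (illustrative only; compile with nvcc, assumes CUDA 11.0+ for <cuda_bf16.h>):

// check_identities.cu -- illustrative only, not part of the commit.
#include <cmath>
#include <cstdio>
#include <limits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

int main() {
  // Same conversions as the Max reducer identities above.
  const __half h = __float2half_rn(-6.550400e+04f);
  const __nv_bfloat16 b =
      __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
  std::printf("half Max identity = %g\n", __half2float(h));                  // -65504
  std::printf("bf16 Max identity = %g (isinf=%d)\n", __bfloat162float(b),
              std::isinf(__bfloat162float(b)) ? 1 : 0);                      // -inf, isinf=1
  return 0;
}
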
This diff is collapsed.
......@@ -13,7 +13,7 @@ namespace aten {
/*!
* \brief CUDA implementation of g-SDDMM on Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const CSRMatrix& csr,
......@@ -22,11 +22,9 @@ void SDDMMCsr(const std::string& op,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, csr, lhs, rhs, out);
});
});
}
......@@ -35,7 +33,7 @@ void SDDMMCsr(const std::string& op,
/*!
* \brief CUDA implementation of g-SDDMM on Coo format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const COOMatrix& coo,
......@@ -44,62 +42,79 @@ void SDDMMCoo(const std::string& op,
NDArray out,
int lhs_target,
int rhs_target) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(bcast, coo, lhs, rhs, out);
});
});
}
template void SDDMMCsr<kDGLCUDA, int32_t, 16>(
template void SDDMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 16>(
template void SDDMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, 32>(
#endif // BF16_ENABLED
template void SDDMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 32>(
template void SDDMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int32_t, 64>(
template void SDDMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCUDA, int64_t, 64>(
template void SDDMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 16>(
template void SDDMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
#if BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 16>(
template void SDDMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 32>(
#endif // BF16_ENABLED
template void SDDMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 32>(
template void SDDMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int32_t, 64>(
template void SDDMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCUDA, int64_t, 64>(
template void SDDMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
......
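
The long lists above are explicit template instantiations: because the kernels are defined in a .cu translation unit, every (IdType, DType) combination reachable from the dispatch layer has to be instantiated there, and the bf16 instantiations are fenced by BF16_ENABLED so that builds on CUDA < 11.0 still compile and link (the same pattern repeats in the files below). A condensed, self-contained sketch of the idea with a hypothetical ToyOp (not DGL code; BF16_ENABLED is assumed to come from bf16.cuh):

// pattern_sketch.cu -- one explicit instantiation per supported (IdType, DType) pair.
#include <cstdint>
#include <cuda_fp16.h>
#if BF16_ENABLED
#include <cuda_bf16.h>
#endif

template <typename IdType, typename DType>
void ToyOp(const DType* x, DType* y, IdType n) {  // definition lives in this .cu only
  for (IdType i = 0; i < n; ++i) y[i] = x[i];
}

// Each instantiation below is what other translation units link against.
template void ToyOp<int32_t, float>(const float*, float*, int32_t);
template void ToyOp<int64_t, double>(const double*, double*, int64_t);
template void ToyOp<int32_t, __half>(const __half*, __half*, int32_t);
#if BF16_ENABLED
template void ToyOp<int32_t, __nv_bfloat16>(
    const __nv_bfloat16*, __nv_bfloat16*, int32_t);
#endif  // BF16_ENABLED
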
......@@ -11,6 +11,7 @@
#include "atomic.cuh"
#include "functor.cuh"
#include "fp16.cuh"
#include "bf16.cuh"
#include "./utils.h"
#include "./functor.cuh"
#include "../selector.h"
......
......@@ -13,7 +13,7 @@ namespace aten {
* \brief CUDA implementation of g-SDDMM on heterograph using
Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
......@@ -24,60 +24,73 @@ void SDDMMCooHetero(const std::string& op,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
COOMatrix coo = vec_coo[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}
});
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
COOMatrix coo = vec_coo[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCoo<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, coo, lhs, rhs, out);
}
});
});
}
template void SDDMMCooHetero<kDGLCUDA, int32_t, 16>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
#if BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 16>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, 32>(
#endif // BF16_ENABLED
template void SDDMMCooHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 32>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int32_t, 64>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCUDA, int64_t, 64>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
......@@ -13,7 +13,7 @@ namespace aten {
* \brief CUDA implementation of g-SDDMM on heterograph using
Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
......@@ -24,59 +24,73 @@ void SDDMMCsrHetero(const std::string& op,
int rhs_target,
const std::vector<dgl_type_t>& lhs_eid,
const std::vector<dgl_type_t>& rhs_eid) {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}
});
SWITCH_OP(op, Op, {
SWITCH_TARGET(lhs_target, rhs_target, LhsTarget, RhsTarget, {
/* Call SDDMM CUDA kernel for each relation type sequentially */
for (dgl_type_t etype = 0; etype < lhs_eid.size(); ++etype) {
CSRMatrix csr = vec_csr[etype];
NDArray lhs = vec_lhs[lhs_eid[etype]];
NDArray rhs = vec_rhs[rhs_eid[etype]];
NDArray out = vec_out[etype];
cuda::SDDMMCsr<IdType, DType, Op, LhsTarget, RhsTarget>(
bcast, csr, lhs, rhs, out);
}
});
});
}
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 16>(
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 16>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 32>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void SDDMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 64>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
......
......@@ -17,169 +17,206 @@ using namespace cuda;
namespace aten {
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg) {
SWITCH_BITS(bits, DType, {
if (op == "sum") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "max") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "min") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
feat, offsets, out, arg);
} else {
LOG(FATAL) << "Not implemented";
}
});
if (op == "sum") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Sum<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "max") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Max<IdType, DType>>(
feat, offsets, out, arg);
} else if (op == "min") {
cuda::SegmentReduce<IdType, DType, cuda::reduce::Min<IdType, DType>>(
feat, offsets, out, arg);
} else {
LOG(FATAL) << "Not implemented";
}
}
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out) {
SWITCH_BITS(bits, DType, {
cuda::ScatterAdd<IdType, DType>(feat, idx, out);
});
cuda::ScatterAdd<IdType, DType>(feat, idx, out);
}
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype,
std::vector<NDArray>* out) {
SWITCH_BITS(bits, DType, {
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
});
cuda::UpdateGradMinMax_hetero<IdType, DType>(g, op, feat, idx, idx_etype, out);
}
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out) {
SWITCH_BITS(bits, DType, {
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
});
cuda::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
template void SegmentReduce<kDGLCUDA, int32_t, 16>(
template void SegmentReduce<kDGLCUDA, int32_t, __half>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 16>(
template void SegmentReduce<kDGLCUDA, int64_t, __half>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 32>(
template void SegmentReduce<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void SegmentReduce<kDGLCUDA, int32_t, float>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, float>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCUDA, int64_t, 64>(
template void SegmentReduce<kDGLCUDA, int32_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, 16>(
template void SegmentReduce<kDGLCUDA, int64_t, double>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDGLCUDA, int32_t, __half>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 16>(
template void ScatterAdd<kDGLCUDA, int64_t, __half>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 32>(
template void ScatterAdd<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void ScatterAdd<kDGLCUDA, int32_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, 64>(
template void ScatterAdd<kDGLCUDA, int64_t, float>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int32_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDGLCUDA, int64_t, double>(
NDArray feat,
NDArray idx,
NDArray out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __half>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __half>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
#if BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 32>(
#endif // BF16_ENABLED
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, double>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, double>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 16>(
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __half>(
NDArray feat,
NDArray arg,
NDArray out);
#if BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, __nv_bfloat16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 16>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, __nv_bfloat16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 32>(
#endif // BF16_ENABLED
template void BackwardSegmentCmp<kDGLCUDA, int32_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 32>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, float>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 64>(
template void BackwardSegmentCmp<kDGLCUDA, int32_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 64>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, double>(
NDArray feat,
NDArray arg,
NDArray out);
......
......@@ -15,30 +15,12 @@ using namespace cuda;
namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value)
if (bits > 16)
return true;
return false;
#else
if (bits == 16)
return false; // cuSPARSE's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
/*!
* \brief CUDA implementation of g-SpMM on Csr format.
* \note use cusparse if the reduce operator is `sum` and there is
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const CSRMatrix& csr,
......@@ -51,58 +33,46 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) {
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
SWITCH_BITS(bits, DType, {
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
nullptr,
static_cast<DType*>(out->data),
x_length);
});
} else if (op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(more_nnz)) {
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
nullptr,
static_cast<DType*>(out->data),
x_length);
} else if (op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(more_nnz)) {
// cusparse
int64_t x_length = 1;
for (int i = 1; i < ufeat->ndim; ++i)
x_length *= ufeat->shape[i];
if (!IsNullArray(csr.data)) {
SWITCH_BITS(bits, DType, {
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
});
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
}
SWITCH_BITS(bits, DType, {
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(out->data),
x_length);
});
CusparseCsrmm2<DType, IdType>(
ufeat->ctx, csr,
static_cast<DType*>(ufeat->data),
static_cast<DType*>(efeat->data),
static_cast<DType*>(out->data),
x_length);
} else { // general kernel
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
});
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, out, NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Not implemented";
......@@ -113,7 +83,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
/*!
* \brief CUDA implementation of g-SpMM on Coo format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const COOMatrix& coo,
......@@ -122,82 +92,94 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
NDArray out,
std::vector<NDArray> out_aux) {
if (reduce == "sum") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
});
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Sum<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, NullArray(), NullArray());
});
} else if (reduce == "max") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Max<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else if (reduce == "min") {
SWITCH_BITS(bits, DType, {
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
SWITCH_OP(op, Op, {
cuda::SpMMCoo<IdType, DType, Op, cuda::reduce::Min<IdType, DType, true> > (
bcast, coo, ufeat, efeat, out, out_aux[0], out_aux[1]);
});
} else {
LOG(FATAL) << "Not implemented";
}
}
template void SpMMCsr<kDGLCUDA, int32_t, 16>(
template void SpMMCsr<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 16>(
template void SpMMCsr<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 32>(
template void SpMMCsr<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void SpMMCsr<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, 64>(
template void SpMMCsr<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, 16>(
template void SpMMCoo<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 16>(
template void SpMMCoo<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 32>(
template void SpMMCoo<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void SpMMCoo<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, 64>(
template void SpMMCoo<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
} // namespace aten
} // namespace dgl
......@@ -10,6 +10,7 @@
#include <limits>
#include "macro.cuh"
#include "fp16.cuh"
#include "bf16.cuh"
#include "atomic.cuh"
#include "../../runtime/cuda/cuda_common.h"
#include "./utils.h"
......@@ -20,6 +21,24 @@ using namespace cuda;
namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <typename DType, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value &&
(std::is_same<DType, float>::value || std::is_same<DType, double>::value))
return true;
return false;
#else
if (std::is_same<DType, __half>::value || std::is_same<DType, __nv_bfloat16>::value)
return false; // cuSPARSE's SpMM on fp16/bf16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
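
For reference, a hypothetical helper (not part of the commit) showing what the rewritten predicate returns on CUDA 11.0+; the expected values in the comments follow directly from the branches above:

// Illustrative only: exercises cusparse_available<DType, IdType> as defined above.
inline void _ExampleCusparseAvailability() {
  const bool more_nnz = false;  // CSR has no more NNZ than num_rows * num_cols
  const bool f32  = cusparse_available<float,  int32_t>(more_nnz);   // true: cuSPARSE path
  const bool f64  = cusparse_available<double, int64_t>(more_nnz);   // true: cuSPARSE path
  const bool f16  = cusparse_available<__half, int32_t>(more_nnz);   // false: DGL's own kernel
#if BF16_ENABLED
  const bool bf16 = cusparse_available<__nv_bfloat16, int32_t>(more_nnz);  // false: DGL's own kernel
  (void)bf16;
#endif
  (void)f32; (void)f64; (void)f16;
}
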
namespace {
/*! \brief Call cuBLAS geam API for transpose operation for float and double. */
......@@ -33,7 +52,6 @@ cublasStatus_t Xgeam(cublasHandle_t handle, cublasOperation_t transa,
return CUBLAS_STATUS_EXECUTION_FAILED;
}
#ifdef USE_FP16
template <>
cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
......@@ -45,7 +63,20 @@ cublasStatus_t Xgeam<__half>(cublasHandle_t handle, cublasOperation_t transa,
LOG(FATAL) << "Xgeam does not support dtype half (FP16)";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
#endif
#if BF16_ENABLED
template <>
cublasStatus_t Xgeam<__nv_bfloat16>(cublasHandle_t handle, cublasOperation_t transa,
cublasOperation_t transb, int m, int n,
const __nv_bfloat16* alpha, const __nv_bfloat16* A, int lda,
const __nv_bfloat16* beta, const __nv_bfloat16* B, int ldb,
__nv_bfloat16* C, int ldc) {
// TODO(ndickson): cuBLAS has no geam routine for bf16, so a different
// implementation would be required.
LOG(FATAL) << "Xgeam does not support dtype bfloat16 (BF16)";
return CUBLAS_STATUS_EXECUTION_FAILED;
}
#endif // BF16_ENABLED
template <>
cublasStatus_t Xgeam<float>(cublasHandle_t handle, cublasOperation_t transa,
......@@ -131,6 +162,21 @@ void _Transpose<half>(const half* in, half* out,
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
#if BF16_ENABLED
/*
* \brief Transpose the input matrix for data type bf16.
* \note cuBLAS has no geam API for the bf16 data type, so we fall back to our own kernel.
*/
template <>
void _Transpose<__nv_bfloat16>(const __nv_bfloat16* in, __nv_bfloat16* out,
int row, int col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = FindNumThreads(row);
int nb = col;
CUDA_KERNEL_CALL(_TransposeKernel, nb, nt, 0, stream, in, out, col, row);
}
#endif // BF16_ENABLED
/*
* \brief
*/
......
......@@ -15,30 +15,12 @@ using namespace cuda;
namespace aten {
/*!
* \brief Determine whether cusparse SpMM function is applicable.
*/
template <int bits, typename IdType>
inline bool cusparse_available(bool more_nnz_than_matrix_size) {
#if CUDART_VERSION < 11000
if (std::is_same<IdType, int>::value)
if (bits > 16)
return true;
return false;
#else
if (bits == 16)
return false; // cuSPARSE's SpMM on fp16 is slow, temporarily disabled.
// If the CSR matrix has more NNZ than matrix size, we should not use cuSPARSE 11.1.
return !more_nnz_than_matrix_size;
#endif
}
/*!
* \brief CUDA implementation of g-SpMM on Csr format.
* \note use cusparse if the reduce operator is `sum` and there is
* no broadcast, use dgl's kernel in other cases.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
......@@ -51,192 +33,202 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
bool is_scalar_efeat = vec_efeat[0].NumElements() == vec_csr[0].indices->shape[0];
bool use_efeat = op != "copy_lhs";
auto device = runtime::DeviceAPI::Get(vec_csr[0].indptr->ctx);
SWITCH_BITS(bits, DType, {
std::vector<DType*> trans_out((*vec_out).size(), NULL);
std::vector<DType*> trans_out((*vec_out).size(), NULL);
bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<bits, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<bits, IdType>(false)));
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out;
}
bool use_legacy_cusparsemm =
(CUDART_VERSION < 11000) && (reduce == "sum") &&
// legacy cuSPARSE does not care about NNZ, hence the argument "false".
((op == "copy_lhs" && cusparse_available<DType, IdType>(false)) ||
(op == "mul" && is_scalar_efeat && cusparse_available<DType, IdType>(false)));
// Create temporary output buffer to store non-transposed output
if (use_legacy_cusparsemm) {
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *out = static_cast<DType*>(device->AllocWorkspace(vec_csr[0].indptr->ctx,
m * n * sizeof(DType)));
CUDA_CALL(cudaMemset(out, 0, m * n * sizeof(DType)));
trans_out[ntype] = out;
}
// Check shape of ufeat for all relation type and compute feature size
int64_t x_length = 1;
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] != next_ufeat->shape[i]) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
LOG(FATAL) <<
"Homogenized message passing on heterogeneous graphs does not support " <<
"automatic broadcasting. Please manually broadcast it before calling " <<
"message passing functions.";
else
LOG(FATAL) << "Input features have different shapes.";
return;
}
if (etype == 0)
x_length *= ufeat->shape[i];
}
// Check shape of ufeat for all relation type and compute feature size
int64_t x_length = 1;
for (dgl_type_t etype = 0; etype < (ufeat_ntids.size() - 1); ++etype) {
NDArray ufeat = vec_ufeat[ufeat_ntids[etype]];
NDArray next_ufeat = vec_ufeat[ufeat_ntids[etype + 1]];
CHECK_EQ(ufeat->ndim, next_ufeat->ndim) << "Input features have different shapes";
for (int i = 1; i < ufeat->ndim; ++i) {
if (ufeat->shape[i] != next_ufeat->shape[i]) {
if (ufeat->shape[i] == 1 || next_ufeat->shape[i] == 1)
LOG(FATAL) <<
"Homogenized message passing on heterogeneous graphs does not support " <<
"automatic broadcasting. Please manually broadcast it before calling " <<
"message passing functions.";
else
LOG(FATAL) << "Input features have different shapes.";
return;
}
if (etype == 0)
x_length *= ufeat->shape[i];
}
// TODO(Israt): Can python do the following initializations while creating the tensors?
if (reduce == "max" || reduce == "min") {
const int64_t dim = bcast.out_len;
std::vector<bool> updated((*vec_out).size(), false);
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
if (reduce == "max")
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
else // min
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
const dgl_type_t dst_id = out_ntids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (op == "copy_lhs") {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
if (op == "copy_rhs") {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
}
// TODO(Israt): Can python do the following initializations while creating the tensors?
if (reduce == "max" || reduce == "min") {
const int64_t dim = bcast.out_len;
std::vector<bool> updated((*vec_out).size(), false);
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
DType *out_off = (*vec_out)[out_ntids[etype]].Ptr<DType>();
if (reduce == "max")
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Max<IdType, DType>::zero());
else // min
_Fill(out_off, vec_csr[etype].num_rows * dim, cuda::reduce::Min<IdType, DType>::zero());
const dgl_type_t dst_id = out_ntids[etype];
if (!updated[dst_id]) {
updated[dst_id] = true;
if (op == "copy_lhs") {
IdType *argu_ntype = (*out_aux)[2][dst_id].Ptr<IdType>();
_Fill(argu_ntype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
if (op == "copy_rhs") {
IdType *arge_etype = (*out_aux)[3][dst_id].Ptr<IdType>();
_Fill(arge_etype, vec_csr[etype].num_rows * dim, static_cast<IdType>(-1));
}
}
}
}
cudaStream_t stream = runtime::getCurrentCUDAStream();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<bits, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>((*vec_out)[dst_id]->data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
out,
x_length, stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<bits, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, stream);
} else { // general kernel
cudaStream_t stream = runtime::getCurrentCUDAStream();
for (dgl_type_t etype = 0; etype < ufeat_ntids.size(); ++etype) {
const dgl_type_t src_id = ufeat_ntids[etype];
const dgl_type_t dst_id = out_ntids[etype];
CSRMatrix csr = vec_csr[etype];
if (reduce == "sum") {
bool more_nnz = (csr.indices->shape[0] > csr.num_rows * csr.num_cols);
/* Call SpMM for each relation type */
if (op == "copy_lhs" && cusparse_available<DType, IdType>(more_nnz)) { // cusparse
/* If CUDA is less than 11.0, put the output in trans_out for later transposition */
DType *out = (CUDART_VERSION < 11000) ? trans_out[dst_id] :
static_cast<DType*>((*vec_out)[dst_id]->data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
nullptr,
out,
x_length, stream);
} else if (op == "mul" && is_scalar_efeat &&
cusparse_available<DType, IdType>(more_nnz)) { // cusparse
NDArray efeat = vec_efeat[etype];
if (!IsNullArray(csr.data))
efeat = _IndexSelect<DType, IdType>(efeat, csr.data);
CusparseCsrmm2Hetero<DType, IdType>(
csr.indptr->ctx, csr,
static_cast<DType*>(vec_ufeat[src_id]->data),
static_cast<DType*>(efeat->data),
// TODO(Israt): Change (*vec_out) to trans_out to support CUDA version < 11
static_cast<DType*>((*vec_out)[dst_id]->data),
x_length, stream);
} else { // general kernel
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
SWITCH_OP(op, Op, {
cuda::SpMMCsr<IdType, DType, Op, cuda::reduce::Sum<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], NullArray(), NullArray());
});
}
} else if (reduce == "max") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Max<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else {
LOG(FATAL) << "Not implemented";
}
} else if (reduce == "min") {
SWITCH_OP(op, Op, {
NDArray ufeat = (vec_ufeat.size() == 0) ?
NullArray() : vec_ufeat[src_id];
NDArray efeat = (vec_efeat.size() == 0) ?
NullArray() : vec_efeat[etype];
cuda::SpMMCmpCsrHetero<IdType, DType, Op, cuda::reduce::Min<IdType, DType> >(
bcast, csr, ufeat, efeat, (*vec_out)[dst_id], (*out_aux)[0][dst_id],
(*out_aux)[1][dst_id], (*out_aux)[2][dst_id], (*out_aux)[3][dst_id],
src_id, etype);
});
} else {
LOG(FATAL) << "Not implemented";
}
}
if (use_legacy_cusparsemm) {
// transpose output
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
if (use_legacy_cusparsemm) {
// transpose output
for (dgl_type_t ntype = 0; ntype < (*vec_out).size(); ++ntype) {
const int m = (*vec_out)[ntype]->shape[0];
const int n = (*vec_out)[ntype]->shape[1];
if (m == 0) continue;
DType *C_data = static_cast<DType*>((*vec_out)[ntype]->data);
_Transpose(trans_out[ntype], C_data, n, m);
device->FreeWorkspace(vec_csr[0].indptr->ctx, trans_out[ntype]);
}
});
}
}
template void SpMMCsrHetero<kDGLCUDA, int32_t, 16>(
template void SpMMCsrHetero<kDGLCUDA, int32_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 16>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, __half>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, 32>(
#if BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 32>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, __nv_bfloat16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, 64>(
#endif // BF16_ENABLED
template void SpMMCsrHetero<kDGLCUDA, int32_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, 64>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, float>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int32_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDGLCUDA, int64_t, double>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
} // namespace aten
} // namespace dgl
......@@ -22,36 +22,6 @@ namespace cuda {
// The max number of threads per block
#define CUDA_MAX_NUM_THREADS 256
#ifdef USE_FP16
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 16) { \
typedef half DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
#else // USE_FP16
#define SWITCH_BITS(bits, DType, ...) \
do { \
if ((bits) == 32) { \
typedef float DType; \
{ __VA_ARGS__ } \
} else if ((bits) == 64) { \
typedef double DType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << "Data type not recognized with bits " << bits; \
} \
} while (0)
#endif // USE_FP16
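
SWITCH_BITS call sites are replaced throughout by a dtype-keyed switch, ATEN_FLOAT_TYPE_SWITCH_16BITS, used by the dispatch code in spmm.cc further below. Its real definition is not part of these hunks; the following is only a sketch of how such a macro could look for the BF16_ENABLED case, assuming a DLPack-style dtype with integer code/bits fields. The Example* names are placeholders, and the error handling simply mirrors the removed macro's LOG(FATAL).

// Sketch only -- not the real ATEN_FLOAT_TYPE_SWITCH_16BITS.
enum ExampleTypeCode { kExampleFloat = 2, kExampleBfloat = 4 };  // DLPack-style codes
struct ExampleDataType { int code; int bits; };

#define EXAMPLE_FLOAT_TYPE_SWITCH_16BITS(dtype, DType, XPU, val_name, ...)     \
  do {                                                                         \
    /* (XPU could be used to reject 16-bit types on backends lacking them.) */ \
    if ((dtype).code == kExampleFloat && (dtype).bits == 16) {                 \
      typedef __half DType;                                                    \
      { __VA_ARGS__ }                                                          \
    } else if ((dtype).code == kExampleBfloat && (dtype).bits == 16) {         \
      typedef __nv_bfloat16 DType;                                             \
      { __VA_ARGS__ }                                                          \
    } else if ((dtype).code == kExampleFloat && (dtype).bits == 32) {          \
      typedef float DType;                                                     \
      { __VA_ARGS__ }                                                          \
    } else if ((dtype).code == kExampleFloat && (dtype).bits == 64) {          \
      typedef double DType;                                                    \
      { __VA_ARGS__ }                                                          \
    } else {                                                                   \
      LOG(FATAL) << (val_name) << " must be half, bfloat16, float or double";  \
    }                                                                          \
  } while (0)

A build without bf16 support would define the same macro minus the __nv_bfloat16 branch, exactly as the removed SWITCH_BITS did for USE_FP16.
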
/*! \brief Calculate the number of threads needed given the dimension length.
*
......@@ -185,6 +155,40 @@ __global__ void _LinearSearchKernel(
}
}
#if BF16_ENABLED
/*!
* \brief Specialization for bf16 because conversion from long long to bfloat16
* doesn't exist before SM80.
*/
template <typename IdType>
__global__ void _LinearSearchKernel(
const IdType* indptr, const IdType* indices, const IdType* data,
const IdType* row, const IdType* col,
int64_t row_stride, int64_t col_stride, int64_t length,
const __nv_bfloat16* weights, __nv_bfloat16 filler, __nv_bfloat16* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
const int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
int rpos = tx * row_stride, cpos = tx * col_stride;
IdType v = -1;
const IdType r = row[rpos], c = col[cpos];
for (IdType i = indptr[r]; i < indptr[r + 1]; ++i) {
if (indices[i] == c) {
v = data ? data[i] : i;
break;
}
}
if (v == -1) {
out[tx] = filler;
} else {
// The result is stored as bf16 anyway, so it is fine to convert it to float first.
out[tx] = weights ? weights[v] : __nv_bfloat16(static_cast<float>(v));
}
tx += stride_x;
}
}
#endif // BF16_ENABLED
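
The float round-trip in the kernel above exists because, as its comment notes, converting a 64-bit integer directly to __nv_bfloat16 is not available before SM80; narrowing through float works on every architecture where bf16 is usable, and since bf16 keeps far fewer mantissa bits than float, the intermediate step loses nothing the final bf16 value could have represented. A tiny device helper distilling the pattern (sketch, not part of the commit):

#if BF16_ENABLED
// Convert an integer id to bf16 portably: the float constructor is available wherever
// bf16 itself is, while __nv_bfloat16(int64_t) requires SM80+.
__device__ __forceinline__ __nv_bfloat16 ExampleIdToBf16(int64_t v) {
  return __nv_bfloat16(static_cast<float>(v));
}
#endif  // BF16_ENABLED
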
template <typename DType>
inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api,
......
......@@ -36,13 +36,13 @@ void SpMM(const std::string& op, const std::string& reduce,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "Feature data", {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsr<XPU, IdType, bits>(
SpMMCsr<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCSCMatrix(0),
ufeat, efeat, out, out_aux);
} else if (format == SparseFormat::kCOO) {
SpMMCoo<XPU, IdType, bits>(
SpMMCoo<XPU, IdType, Dtype>(
op, reduce, bcast, graph->GetCOOMatrix(0),
ufeat, efeat, out, out_aux);
} else {
......@@ -76,8 +76,8 @@ void SegmentMM(const NDArray A,
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
ATEN_ID_TYPE_SWITCH(seglen_A->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
SegmentMM<XPU, IdType, bits>(A, B, C, seglen_A, A_trans, B_trans);
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
SegmentMM<XPU, IdType, Dtype>(A, B, C, seglen_A, A_trans, B_trans);
});
});
});
......@@ -94,8 +94,8 @@ void SegmentMMBackwardB(const NDArray A,
<< "segment_mm expects seglen to be on CPU.";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
SegmentMMBackwardB<XPU, IdType, bits>(A, dC, dB, seglen);
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
SegmentMMBackwardB<XPU, IdType, Dtype>(A, dC, dB, seglen);
});
});
});
......@@ -131,8 +131,8 @@ void GatherMM(const NDArray A,
const auto idtype = aten::IsNullArray(idx_a)? idx_b->dtype : idx_a->dtype;
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
GatherMM<XPU, IdType, bits>(A, B, C, idx_a, idx_b);
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
GatherMM<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b);
});
});
});
......@@ -171,8 +171,8 @@ void GatherMMScatter(const NDArray A,
}
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "GatherMM", {
ATEN_ID_TYPE_SWITCH(idx_c->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(A->dtype, bits, "Feature data", {
GatherMMScatter<XPU, IdType, bits>(A, B, C, idx_a, idx_b, idx_c);
ATEN_FLOAT_TYPE_SWITCH_16BITS(A->dtype, Dtype, XPU, "Feature data", {
GatherMMScatter<XPU, IdType, Dtype>(A, B, C, idx_a, idx_b, idx_c);
});
});
});
......@@ -210,9 +210,9 @@ void SpMMHetero(const std::string& op, const std::string& reduce,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SpMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH((*out)[out_eid[0]]->dtype, bits, "Feature data", {
ATEN_FLOAT_TYPE_SWITCH_16BITS((*out)[out_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSC) {
SpMMCsrHetero<XPU, IdType, bits>(
SpMMCsrHetero<XPU, IdType, Dtype>(
op, reduce, bcast, vec_graph,
ufeat_vec, efeat_vec, out, out_aux,
ufeat_eid, out_eid);
......@@ -241,13 +241,13 @@ void SDDMM(const std::string& op,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "Feature data", {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) {
SDDMMCsr<XPU, IdType, bits>(
SDDMMCsr<XPU, IdType, Dtype>(
op, bcast, graph->GetCSRMatrix(0),
lhs, rhs, out, lhs_target, rhs_target);
} else if (format == SparseFormat::kCOO) {
SDDMMCoo<XPU, IdType, bits>(
SDDMMCoo<XPU, IdType, Dtype>(
op, bcast, graph->GetCOOMatrix(0),
lhs, rhs, out, lhs_target, rhs_target);
} else {
......@@ -294,13 +294,13 @@ void SDDMMHetero(const std::string& op,
ATEN_XPU_SWITCH_CUDA(graph->Context().device_type, XPU, "SDDMM", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out[rhs_eid[0]]->dtype, bits, "Feature data", {
ATEN_FLOAT_TYPE_SWITCH_16BITS(out[rhs_eid[0]]->dtype, Dtype, XPU, "Feature data", {
if (format == SparseFormat::kCSR) {
std::vector<CSRMatrix> vec_csr;
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_csr.push_back(graph->GetCSRMatrix(etype));
}
SDDMMCsrHetero<XPU, IdType, bits>(
SDDMMCsrHetero<XPU, IdType, Dtype>(
op, bcast, vec_csr,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
......@@ -309,7 +309,7 @@ void SDDMMHetero(const std::string& op,
for (dgl_type_t etype = 0; etype < graph->NumEdgeTypes(); ++etype) {
vec_coo.push_back(graph->GetCOOMatrix(etype));
}
SDDMMCooHetero<XPU, IdType, bits>(
SDDMMCooHetero<XPU, IdType, Dtype>(
op, bcast, vec_coo,
lhs, rhs, out, lhs_target, rhs_target,
lhs_eid, rhs_eid);
......@@ -333,8 +333,8 @@ void Edge_softmax_forward(const std::string& op,
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data", {
Edge_softmax_csr_forward<XPU, IdType, bits>(
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data", {
Edge_softmax_csr_forward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), ufeat, efeat, out);
});
});
......@@ -354,8 +354,8 @@ void Edge_softmax_backward(const std::string& op,
ATEN_XPU_SWITCH(graph->Context().device_type, XPU, "edge_softmax_back", {
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
ATEN_FLOAT_BITS_SWITCH(out->dtype, bits, "edge_softmax out data_back", {
Edge_softmax_csr_backward<XPU, IdType, bits>(
ATEN_FLOAT_TYPE_SWITCH_16BITS(out->dtype, Dtype, XPU, "edge_softmax out data_back", {
Edge_softmax_csr_backward<XPU, IdType, Dtype>(
op, bcast, graph->GetCSCMatrix(0), out, sds, back_out);
});
});
......@@ -380,8 +380,8 @@ void SegmentReduceDispatch(const std::string& op,
NDArray arg) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "SegmentReduce", {
ATEN_ID_TYPE_SWITCH(offsets->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", {
SegmentReduce<XPU, IdType, bits>(op, feat, offsets, out, arg);
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
SegmentReduce<XPU, IdType, Dtype>(op, feat, offsets, out, arg);
});
});
});
......@@ -391,8 +391,8 @@ void SegmentReduceDispatch(const std::string& op,
void ScatterAddDispatch(NDArray feat, NDArray idx, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", {
ScatterAdd<XPU, IdType, bits>(feat, idx, out);
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
ScatterAdd<XPU, IdType, Dtype>(feat, idx, out);
});
});
});
......@@ -409,8 +409,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
auto src_id = pair.first;
ATEN_XPU_SWITCH_CUDA(feat[src_id]->ctx.device_type, XPU, "ScatterAdd", {
ATEN_ID_TYPE_SWITCH(idx[src_id]->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat[src_id]->dtype, bits, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, bits>(graph, op, feat, idx, idx_etype, out);
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat[src_id]->dtype, Dtype, XPU, "Feature data", {
UpdateGradMinMax_hetero<XPU, IdType, Dtype>(graph, op, feat, idx, idx_etype, out);
});
});
});
......@@ -420,8 +420,8 @@ void UpdateGradMinMaxDispatchHetero(const HeteroGraphPtr& graph,
void BackwardSegmentCmpDispatch(NDArray feat, NDArray arg, NDArray out) {
ATEN_XPU_SWITCH_CUDA(feat->ctx.device_type, XPU, "BackwardSegmentCmp", {
ATEN_ID_TYPE_SWITCH(arg->dtype, IdType, {
ATEN_FLOAT_BITS_SWITCH(feat->dtype, bits, "Feature data", {
BackwardSegmentCmp<XPU, IdType, bits>(feat, arg, out);
ATEN_FLOAT_TYPE_SWITCH_16BITS(feat->dtype, Dtype, XPU, "Feature data", {
BackwardSegmentCmp<XPU, IdType, Dtype>(feat, arg, out);
});
});
});
......
......@@ -20,7 +20,7 @@ namespace aten {
/*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsr(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
......@@ -33,7 +33,7 @@ void SpMMCsr(const std::string& op, const std::string& reduce,
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Csr format
with heterograph support.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCsrHetero(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const std::vector<CSRMatrix>& csr,
......@@ -46,7 +46,7 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
/*!
* \brief Generalized Sparse Matrix Dense Matrix Multiplication on Coo format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SpMMCoo(const std::string& op, const std::string& reduce,
const BcastOff& bcast,
const aten::COOMatrix& coo,
......@@ -58,7 +58,7 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsr(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
......@@ -71,7 +71,7 @@ void SDDMMCsr(const std::string& op,
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Csr
format with heterograph support.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCsrHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
......@@ -86,7 +86,7 @@ void SDDMMCsrHetero(const std::string& op,
/*!
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCoo(const std::string& op,
const BcastOff& bcast,
const aten::COOMatrix& coo,
......@@ -100,7 +100,7 @@ void SDDMMCoo(const std::string& op,
* \brief Generalized Sampled Dense-Dense Matrix Multiplication on Coo
format with heterograph support.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SDDMMCooHetero(const std::string& op,
const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
......@@ -115,7 +115,7 @@ void SDDMMCooHetero(const std::string& op,
/*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void GatherMM(const NDArray A,
const NDArray B,
NDArray out,
......@@ -125,7 +125,7 @@ void GatherMM(const NDArray A,
/*!
* \brief Generalized Dense Matrix-Matrix Multiplication according to relation types.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void GatherMMScatter(const NDArray A,
const NDArray B,
NDArray out,
......@@ -136,14 +136,14 @@ void GatherMMScatter(const NDArray A,
/*!
* \brief Generalized segmented dense Matrix-Matrix Multiplication.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentMM(const NDArray A,
const NDArray B,
NDArray out,
const NDArray seglen_A,
bool a_trans, bool b_trans);
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentMMBackwardB(const NDArray A,
const NDArray dC,
NDArray dB,
......@@ -152,7 +152,7 @@ void SegmentMMBackwardB(const NDArray A,
/*!
* \brief Segment reduce.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void SegmentReduce(const std::string& op,
NDArray feat,
NDArray offsets,
......@@ -162,7 +162,7 @@ void SegmentReduce(const std::string& op,
/*!
* \brief Scatter Add on first dimension.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void ScatterAdd(NDArray feat,
NDArray idx,
NDArray out);
......@@ -170,7 +170,7 @@ void ScatterAdd(NDArray feat,
/*!
* \brief Update gradients for reduce operator max and min on first dimension.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
const std::string& op,
const std::vector<NDArray>& feat,
......@@ -181,7 +181,7 @@ void UpdateGradMinMax_hetero(const HeteroGraphPtr& g,
/*!
* \brief Backward function of segment cmp.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void BackwardSegmentCmp(NDArray feat,
NDArray arg,
NDArray out);
......@@ -223,7 +223,7 @@ std::pair<CSRMatrix, NDArray> CSRSum(
/*!
* \brief Edge_softmax_csr forward function on Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_forward(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
......@@ -233,7 +233,7 @@ void Edge_softmax_csr_forward(const std::string& op,
/*!
* \brief Edge_softmax_csr backward function on Csr format.
*/
template <int XPU, typename IdType, int bits>
template <int XPU, typename IdType, typename DType>
void Edge_softmax_csr_backward(const std::string& op,
const BcastOff& bcast,
const aten::CSRMatrix& csr,
......
......@@ -22,9 +22,12 @@ constexpr DGLDataType DGLDataTypeTraits<int32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<int64_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint32_t>::dtype;
constexpr DGLDataType DGLDataTypeTraits<uint64_t>::dtype;
#ifdef USE_FP16
#ifdef DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<__half>::dtype;
#endif
#if BF16_ENABLED
constexpr DGLDataType DGLDataTypeTraits<__nv_bfloat16>::dtype;
#endif // BF16_ENABLED
#endif // DGL_USE_CUDA
constexpr DGLDataType DGLDataTypeTraits<float>::dtype;
constexpr DGLDataType DGLDataTypeTraits<double>::dtype;
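// Hedged sketch (illustration only) of the in-class declaration the bf16
// definition above pairs with; the field order {code, bits, lanes} and the
// kDGLBfloat code value are assumptions, the real DGL header may differ:
// template <>
// struct DGLDataTypeTraits<__nv_bfloat16> {
//   static constexpr DGLDataType dtype{kDGLBfloat, 16, 1};
// };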
......
from distutils.version import LooseVersion
import random
import unittest
import backend as F
import networkx as nx
import numpy as np
import pytest
import torch
......@@ -325,13 +325,20 @@ def test_segment_reduce(reducer):
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
@pytest.mark.parametrize(
"dtype,tol",
[(torch.float16, 1e-2), (torch.float32, 3e-3), (torch.float64, 1e-4)],
"dtype, tol",
[(torch.float16, 1e-2), (torch.bfloat16, 1e-2),
(torch.float32, 3e-3), (torch.float64, 1e-4)],
)
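# 16-bit dtypes (fp16/bf16) get looser tolerances than fp32/fp64.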
def test_segment_mm(idtype, feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype == torch.float16:
if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
pytest.skip(
"fp16 support for CPU linalg functions has been removed in PyTorch."
"Only support float32 and float64 on CPU."
)
if F._default_context_str == "gpu" \
and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
and dtype == torch.bfloat16:
pytest.skip(
"BF16 requires CUDA >= 11.0."
)
dev = F.ctx()
# input
......@@ -343,7 +350,7 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
.to(dtype)
)
b.requires_grad_()
seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0])
seglen_a = torch.tensor([10, 15, 8, 0, 1, 9, 18, 24, 15, 0]).to(idtype)
dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
# compute
c = dgl.ops.segment_mm(a, b, seglen_a)
......@@ -371,19 +378,28 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
@unittest.skipIf(
dgl.backend.backend_name != "pytorch", reason="Only support PyTorch for now"
)
@parametrize_idtype
@pytest.mark.parametrize("feat_size", [1, 8, 16, 64, 256])
def test_gather_mm_idx_b(idtype, feat_size):
import torch
@pytest.mark.parametrize(
"dtype, tol",
[(torch.float16, 1e-2), (torch.bfloat16, 2e-2),
(torch.float32, 3e-3), (torch.float64, 1e-4)]
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype in (torch.float16, torch.bfloat16):
pytest.skip("Only support float32 and float64 on CPU.")
if F._default_context_str == "gpu" \
and LooseVersion(torch.version.cuda) < LooseVersion("11.0") \
and dtype == torch.bfloat16:
pytest.skip("BF16 requires CUDA >= 11.0.")
dev = F.ctx()
# input
a = torch.tensor(np.random.rand(100, feat_size)).to(dev)
a = torch.tensor(np.random.rand(100, feat_size)).to(dev).to(dtype)
a.requires_grad_()
b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev)
b = torch.tensor(np.random.rand(10, feat_size, feat_size + 1)).to(dev).to(dtype)
b.requires_grad_()
idx = torch.tensor(np.random.randint(0, 10, 100)).to(dev).long()
dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev)
dc = torch.tensor(np.random.rand(100, feat_size + 1)).to(dev).to(dtype)
# compute
c = dgl.ops.gather_mm(a, b, idx_b=idx)
c.backward(dc)
......@@ -397,9 +413,9 @@ def test_gather_mm_idx_b(idtype, feat_size):
da_t = a.grad
db_t = b.grad
assert torch.allclose(c, c_t, atol=1e-4, rtol=1e-4)
assert torch.allclose(da, da_t, atol=1e-4, rtol=1e-4)
assert torch.allclose(db, db_t, atol=1e-4, rtol=1e-4)
assert torch.allclose(c, c_t, atol=tol, rtol=tol)
assert torch.allclose(da, da_t, atol=tol, rtol=tol)
assert torch.allclose(db, db_t, atol=tol, rtol=tol)
@unittest.skipIf(
......
......@@ -25,7 +25,7 @@ if [[ $arch == *"x86"* ]]; then
fi
if [[ $1 != "cpu" ]]; then
CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON -DUSE_FP16=ON $CMAKE_VARS"
CMAKE_VARS="-DUSE_CUDA=ON -DUSE_NCCL=ON $CMAKE_VARS"
fi
if [ -d build ]; then
......