"src/vscode:/vscode.git/clone" did not exist on "05b706c003666709471c62bd44b8f40190506000"
Unverified commit acb4eb7e, authored by Ilia Taraban and committed by GitHub

[Feature] Add bfloat16 support for CPU (#5497)


Co-authored-by: Hongzhi (Steve), Chen <chenhongzhi.nkcs@gmail.com>
parent 29e66615
......@@ -203,11 +203,11 @@ endif(NOT MSVC)
# Compile LIBXSMM
if((NOT MSVC) AND USE_LIBXSMM)
if(REBUILD_LIBXSMM)
-add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
else(REBUILD_LIBXSMM)
-add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
+add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
endif(REBUILD_LIBXSMM)
......
......@@ -152,8 +152,13 @@
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \
{ __VA_ARGS__ } \
-} else if (XPU == kDGLCPU) { \
-LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
+LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/bfloat16/float32/float64 on GPU"; \
......@@ -177,8 +182,13 @@
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
-} else if (XPU == kDGLCPU) { \
-LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
+LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
......@@ -187,7 +197,24 @@
#endif // BF16_ENABLED
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
-ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
+do { \
+CHECK((val).code == kDGLFloat || ((val).code == kDGLBfloat)) \
+<< (val_name) << " must be float type"; \
+if ((val).bits == 32) { \
+typedef float FloatType; \
+{ __VA_ARGS__ } \
+} else if ((val).bits == 64) { \
+typedef double FloatType; \
+{ __VA_ARGS__ } \
+} else if ( \
+XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
+typedef BFloat16 FloatType; \
+{ __VA_ARGS__ } \
+} else { \
+LOG(FATAL) << (val_name) \
+<< " can only be bfloat16/float32/float64 on CPU"; \
+} \
+} while (0)
#endif // DGL_USE_CUDA
/**
......
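The CPU branches added above make the 16-bit switch macros dispatch to the new BFloat16 class the same way they already dispatch to float and double. A minimal, compilable sketch of that dispatch pattern; DTypeDesc, the type-code values, and FLOAT_SWITCH_SKETCH are illustrative stand-ins, not DGL's actual API:

#include <cstdint>
#include <iostream>

// Stand-in for DGLDataType and its type codes (illustrative values).
struct DTypeDesc { uint8_t code; uint8_t bits; };
constexpr uint8_t kFloatCode = 2, kBfloatCode = 4;

#define FLOAT_SWITCH_SKETCH(val, FloatType, ...)              \
  if ((val).bits == 32 && (val).code == kFloatCode) {         \
    typedef float FloatType;                                  \
    { __VA_ARGS__ }                                           \
  } else if ((val).bits == 16 && (val).code == kBfloatCode) { \
    typedef uint16_t FloatType; /* stand-in for BFloat16 */   \
    { __VA_ARGS__ }                                           \
  } else {                                                    \
    std::cerr << "unsupported dtype\n";                       \
  }

int main() {
  DTypeDesc dt{kBfloatCode, 16};
  FLOAT_SWITCH_SKETCH(dt, FloatType, {
    std::cout << "kernel instantiated for a " << 8 * sizeof(FloatType)
              << "-bit element type\n";
  });
}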
/**
* Copyright (c) 2023 by Contributors
* @file dgl/runtime/bfloat16.h
* @brief BFloat16 CPU header
*/
#ifndef DGL_RUNTIME_BFLOAT16_H_
#define DGL_RUNTIME_BFLOAT16_H_
#include <cmath>
#include <cstdint>
class BFloat16 {
uint16_t val;
public:
constexpr BFloat16() : val(0) {}
// Disable lint "explicit" warning, since implicit usage on constructor is
// expected.
BFloat16(float f) { // NOLINT
if (std::isnan(f)) {
val = 0x7FC0;
} else {
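// Little-endian layout assumed: iraw16[1] aliases the high half of f32.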
union {
uint16_t iraw16[2];
uint32_t iraw32;
float f32;
};
f32 = f;
const uint32_t rounding_bias = 0x00007FFF + (iraw16[1] & 0x1);
val = static_cast<uint16_t>((iraw32 + rounding_bias) >> 16);
}
}
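// 0xFF80 / 0x7F80 are the bfloat16 bit patterns for -infinity / +infinity;
// Min() and Max() serve as identity values for max-/min-reductions.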
static constexpr BFloat16 Min() {
BFloat16 min;
min.val = 0xFF80;
return min;
}
static constexpr BFloat16 Max() {
BFloat16 max;
max.val = 0x7F80;
return max;
}
BFloat16& operator-=(const float& rhs) {
float lhs = (*this);
(*this) = lhs - rhs;
return *this;
}
BFloat16& operator+=(const float& rhs) {
float lhs = (*this);
(*this) = lhs + rhs;
return *this;
}
operator float() const {
union {
float f;
uint16_t raw[2];
};
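// Little-endian again: the bf16 bits become the high half of the float32.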
raw[0] = 0;
raw[1] = val;
return f;
}
};
#endif // DGL_RUNTIME_BFLOAT16_H_
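The constructor above implements round-to-nearest-even: add a bias, then keep the high 16 bits, where the "+ (high bit & 1)" term makes halfway cases round to the even neighbor. A self-contained sketch of the same arithmetic; memcpy replaces the union (and does not assume little-endian), and the constructor's NaN handling is omitted:

#include <cstdint>
#include <cstring>
#include <iostream>

uint16_t FloatToBF16(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  const uint32_t rounding_bias = 0x00007FFF + ((u >> 16) & 1);
  return static_cast<uint16_t>((u + rounding_bias) >> 16);
}

float BF16ToFloat(uint16_t v) {
  const uint32_t u = static_cast<uint32_t>(v) << 16;  // bf16 = high half
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  for (float f : {1.0f, 1.5f, 3.1415926f, 1e30f}) {
    std::cout << f << " -> " << BF16ToFloat(FloatToBF16(f)) << "\n";
  }
}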
......@@ -12,6 +12,7 @@
#include <utility>
#include <vector>
#include "bfloat16.h"
#include "c_runtime_api.h"
#include "serializer.h"
#include "shared_mem.h"
......
......@@ -153,7 +153,7 @@ def config_cython():
library_dirs=library_dirs,
libraries=libraries,
# Crashes without this flag with GCC 5.3.1
-extra_compile_args=["-std=c++11"],
+extra_compile_args=["-std=c++14"],
language="c++",
)
)
......
......@@ -40,6 +40,12 @@ void GatherMMScatter(
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
......@@ -53,6 +59,12 @@ template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
......@@ -66,6 +78,12 @@ template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
......@@ -79,6 +97,10 @@ template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
......
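These hunks follow the file's existing pattern: the kernels are defined in a .cc file, so every (IdType, DType) combination a caller may request must be explicitly instantiated, and BFloat16 support means one new instantiation per kernel. A minimal sketch of the mechanism; Axpy is a hypothetical stand-in, not DGL code:

#include <cstdint>

// Template defined in a .cc file: its body is invisible to other
// translation units, so the compiler cannot instantiate it on demand.
template <typename IdType, typename DType>
void Axpy(const DType* x, DType* y, IdType n) {
  for (IdType i = 0; i < n; ++i) y[i] += x[i];
}

// Explicit instantiations emit the symbols the linker needs; omitting one
// yields "undefined reference" at link time, not a compile error.
template void Axpy<int32_t, float>(const float*, float*, int32_t);
template void Axpy<int64_t, double>(const double*, double*, int64_t);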
......@@ -78,6 +78,12 @@ void SDDMMCsrHetero(
});
}
template void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
......@@ -91,6 +97,18 @@ template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
......@@ -152,6 +170,12 @@ void SDDMMCooHetero(
});
}
template void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
......@@ -165,6 +189,18 @@ template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
......
......@@ -56,6 +56,12 @@ void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}
template void SegmentReduce<kDGLCPU, int32_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
......@@ -69,6 +75,16 @@ template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template <>
void ScatterAdd<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template <>
void ScatterAdd<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
......@@ -78,6 +94,20 @@ template void ScatterAdd<kDGLCPU, int32_t, double>(
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray arg, NDArray out);
template <>
void UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template <>
void UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
......@@ -95,6 +125,10 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
......
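ScatterAdd and UpdateGradMinMax_hetero are stubbed for BFloat16 with full specializations that LOG(FATAL), so the op fails loudly instead of instantiating a body that does not support the type. A self-contained sketch of that stub-by-specialization pattern; BF16Like and ScatterAddSketch are hypothetical names:

#include <cstdio>

struct BF16Like {};  // stand-in for dgl::BFloat16

template <typename DType>
void ScatterAddSketch() { std::puts("generic kernel runs"); }

// Full specialization replaces the generic body for the unsupported type.
template <>
void ScatterAddSketch<BF16Like>() { std::puts("FATAL: no BF16 ScatterAdd"); }

int main() {
  ScatterAddSketch<float>();     // generic kernel runs
  ScatterAddSketch<BF16Like>();  // stubbed path
}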
......@@ -25,6 +25,8 @@ namespace cpu {
*/
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
if (std::is_same<DType, BFloat16>::value)
LOG(FATAL) << "Unsupported CPU kernel for SegmentSum for BF16.";
int n = out->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
......
......@@ -124,6 +124,14 @@ void SpMMCsrHetero(
}
}
template void SpMMCsr<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
......@@ -141,6 +149,20 @@ template void SpMMCsr<kDGLCPU, int64_t, double>(
const CSRMatrix& csr, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
const std::vector<NDArray>& efeat, std::vector<NDArray>* out,
std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const std::vector<CSRMatrix>& csr, const std::vector<NDArray>& ufeat,
......@@ -191,7 +213,12 @@ void Edge_softmax_csr_backward(
bcast, csr, out, sds, back_out);
});
}
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
......@@ -205,6 +232,12 @@ template void Edge_softmax_csr_forward<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
......@@ -242,6 +275,14 @@ void SpMMCoo(
}
}
template void SpMMCoo<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
std::vector<NDArray> out_aux);
template void SpMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const std::string& reduce, const BcastOff& bcast,
const COOMatrix& coo, NDArray ufeat, NDArray efeat, NDArray out,
......
......@@ -27,6 +27,10 @@ namespace dgl {
namespace aten {
namespace cpu {
template <typename DType>
using AccType = typename std::conditional<
std::is_same<DType, BFloat16>::value, float, DType>::type;
/**
* @brief Naive CPU kernel of SpMM on Csr format.
* @param cpu_spec JIT'ed kernel
......@@ -51,18 +55,20 @@ void SpMMSumCsrNaive(
for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
+for (int64_t k = 0; k < dim; ++k) {
+AccType<DType> acc = 0.;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
-for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
-out_off[k] += Op::Call(lhs_off, rhs_off);
+acc += Op::Call(lhs_off, rhs_off);
}
+out_off[k] += acc;
}
}
});
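Two things change in this kernel: the loops are interchanged so each output element gets a private accumulator, and that accumulator is AccType<DType>, i.e. float when DType is BFloat16. The widening matters because bf16 keeps only 8 significand bits. A sketch of the failure the widened accumulator prevents, emulating bf16 rounding as in bfloat16.h:

#include <cstdint>
#include <cstring>
#include <iostream>

float RoundToBF16(float f) {  // round-trip through the bf16 representation
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000u;
  float r;
  std::memcpy(&r, &u, sizeof(r));
  return r;
}

int main() {
  float bf16_sum = 0.f, f32_acc = 0.f;
  for (int i = 0; i < 1000; ++i) {
    bf16_sum = RoundToBF16(bf16_sum + 1.f);  // out_off[k] += ... in bf16
    f32_acc += 1.f;                          // acc += ..., one store at end
  }
  // The bf16 running sum stalls at 256 (256 + 1 rounds back to 256);
  // the float accumulator reaches 1000, which bf16 represents exactly.
  std::cout << bf16_sum << " vs " << RoundToBF16(f32_acc) << "\n";
}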
......@@ -129,7 +135,8 @@ void SpMMSumCsr(
* we use atomic operators in the reduction phase.
*/
template <typename IdType, typename DType, typename Op>
-void SpMMSumCoo(
+typename std::enable_if<!std::is_same<DType, BFloat16>::value, void>::type
+SpMMSumCoo(
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out) {
const bool has_idx = !IsNullArray(coo.data);
......@@ -166,6 +173,14 @@ void SpMMSumCoo(
}
}
template <typename IdType, typename DType, typename Op>
typename std::enable_if<std::is_same<DType, BFloat16>::value, void>::type
SpMMSumCoo(
const BcastOff& bcast, const COOMatrix& coo, NDArray ufeat, NDArray efeat,
NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for SpMMSumCoo for BF16.";
}
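SpMMSumCoo's reduction phase uses atomic operators (per the comment above the function), which cannot apply to the BFloat16 class type, so the diff splits the template with enable_if: the real kernel is compiled only for non-bf16 types, and a bf16 overload fails loudly without ever instantiating the atomic path. A self-contained sketch of that split, with BF16Like and Kernel as stand-in names:

#include <iostream>
#include <type_traits>

struct BF16Like {};  // stand-in for dgl::BFloat16

// Selected for every type except BF16Like.
template <typename DType>
typename std::enable_if<!std::is_same<DType, BF16Like>::value, void>::type
Kernel() {
  std::cout << "real kernel runs\n";
}

// Selected only for BF16Like; the unsupported body is never compiled above.
template <typename DType>
typename std::enable_if<std::is_same<DType, BF16Like>::value, void>::type
Kernel() {
  std::cout << "FATAL: unsupported for bf16\n";  // mirrors LOG(FATAL)
}

int main() {
  Kernel<float>();     // real kernel runs
  Kernel<BF16Like>();  // stubbed path
}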
/**
* @brief CPU kernel of SpMM-Min/Max on Csr format.
* @param bcast Broadcast information.
......@@ -442,7 +457,7 @@ void Edge_softmax_csr_forward(
runtime::parallel_for(0, csr.num_rows, [&](size_t b, size_t e) {
for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
-std::vector<DType> data_e(row_end - row_start, 0);
+std::vector<AccType<DType>> data_e(row_end - row_start, 0);
std::vector<IdType> num(row_end - row_start, 0);
for (int64_t k = 0; k < dim; ++k) {
DType max_v = -std::numeric_limits<DType>::infinity();
......@@ -481,6 +496,8 @@ template <typename IdType, typename DType, typename Op>
void Edge_softmax_csr_backward(
const BcastOff& bcast, const CSRMatrix& csr, NDArray out, NDArray sds,
NDArray back_out) {
typedef typename std::conditional<
std::is_same<DType, BFloat16>::value, float, DType>::type AccType;
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const IdType* edges =
......@@ -492,7 +509,7 @@ void Edge_softmax_csr_backward(
for (auto rid = b; rid < e; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
for (int64_t k = 0; k < dim; ++k) {
-DType sum_sds = 0;
+AccType sum_sds = 0;
for (IdType j = row_start; j < row_end; ++j) {
const IdType eid = has_idx ? edges[j] : j;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
......
......@@ -102,20 +102,36 @@ constexpr bool CopyRhs<DType>::use_rhs;
//////////////////////////////// Reduce operators on CPU
///////////////////////////////////
+template <typename DType>
+constexpr DType MinDType() {
+if (std::is_same<DType, BFloat16>::value)
+return BFloat16::Min();
+else
+return -std::numeric_limits<DType>::infinity();
+}
template <typename DType>
struct Max {
typedef DType type;
-static constexpr DType zero = -std::numeric_limits<DType>::infinity();
+static constexpr DType zero = MinDType<DType>();
// return true if accum should be replaced
inline static DType Call(DType accum, DType val) { return accum < val; }
};
template <typename DType>
constexpr DType Max<DType>::zero;
+template <typename DType>
+constexpr DType MaxDType() {
+if (std::is_same<DType, BFloat16>::value)
+return BFloat16::Max();
+else
+return std::numeric_limits<DType>::infinity();
+}
template <typename DType>
struct Min {
typedef DType type;
-static constexpr DType zero = std::numeric_limits<DType>::infinity();
+static constexpr DType zero = MaxDType<DType>();
// return true if accum should be replaced
inline static DType Call(DType accum, DType val) { return accum > val; }
};
......
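MinDType/MaxDType special-case BFloat16 presumably because std::numeric_limits is never specialized for the class: the unspecialized template silently returns value-initialized zeros, so -numeric_limits<BFloat16>::infinity() would evaluate to 0 and corrupt the reduction identities. A small sketch of that pitfall, with BF16Like as a stand-in:

#include <iostream>
#include <limits>

struct BF16Like {};  // no numeric_limits specialization exists for this

int main() {
  // Unspecialized numeric_limits compiles fine but reports nothing useful.
  std::cout << std::numeric_limits<BF16Like>::is_specialized << "\n";  // 0
  std::cout << std::numeric_limits<BF16Like>::has_infinity << "\n";   // 0
  std::cout << std::numeric_limits<float>::infinity() << "\n";        // inf
}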
......@@ -257,7 +257,13 @@ inline libxsmm_meltwfunction_opreduce_vecs_idx SpMMCreateLibxsmmKernel(
N, &_ld, &_ld, LIBXSMM_DATATYPE_F32, LIBXSMM_DATATYPE_F32,
(sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
opredop_flags);
} else { // assume bf16
kernel = libxsmm_dispatch_meltw_opreduce_vecs_idx(
N, &_ld, &_ld, LIBXSMM_DATATYPE_BF16, LIBXSMM_DATATYPE_BF16,
(sizeof(IdType) == 8) ? LIBXSMM_DATATYPE_I64 : LIBXSMM_DATATYPE_I32,
opredop_flags);
}
if (kernel == nullptr) {
LOG(FATAL) << "Failed to generate libxsmm kernel for the SpMM operation."
"To disable libxsmm, use dgl.use_libxsmm(false).";
......
......@@ -1426,3 +1426,12 @@ TEST(ArrayTest, Sort) {
_TestSort<int64_t>(GPU);
#endif
}
TEST(ArrayTest, BFloatCast) {
for (int i = -100; i < 100; ++i) {
float a = i;
BFloat16 b = a;
float a_casted = b;
ASSERT_FLOAT_EQ(a, a_casted);
}
}
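The test passes because every integer of magnitude up to 256 is exactly representable in bfloat16 (8 significand bits). A sketch probing that boundary under the same rounding rule as the class; the first round-trip failure is 257, a tie that rounds down to 256:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  for (int i = 0; i < 1000; ++i) {
    float f = static_cast<float>(i);
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000u;  // bf16 round-trip
    float back;
    std::memcpy(&back, &u, sizeof(back));
    if (back != f) {
      std::printf("first mismatch at %d (becomes %g)\n", i, back);
      return 0;
    }
  }
}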
......@@ -105,6 +105,7 @@ void _TestSpmmCopyLhs() {
TEST(SpmmTest, TestSpmmCopyLhs) {
_TestSpmmCopyLhs<float>();
_TestSpmmCopyLhs<double>();
_TestSpmmCopyLhs<BFloat16>();
}
template <typename IDX>
......@@ -130,6 +131,7 @@ void _TestSpmmCopyRhs() {
TEST(SpmmTest, TestSpmmCopyRhs) {
_TestSpmmCopyRhs<float>();
_TestSpmmCopyRhs<double>();
_TestSpmmCopyRhs<BFloat16>();
}
template <typename IDX>
......@@ -156,6 +158,7 @@ void _TestSpmmAdd() {
TEST(SpmmTest, TestSpmmAdd) {
_TestSpmmAdd<float>();
_TestSpmmAdd<double>();
_TestSpmmAdd<BFloat16>();
}
template <typename IDX>
......@@ -182,6 +185,7 @@ void _TestSpmmSub() {
TEST(SpmmTest, TestSpmmSub) {
_TestSpmmSub<float>();
_TestSpmmSub<double>();
_TestSpmmSub<BFloat16>();
}
template <typename IDX>
......@@ -208,6 +212,7 @@ void _TestSpmmMul() {
TEST(SpmmTest, TestSpmmMul) {
_TestSpmmMul<float>();
_TestSpmmMul<double>();
_TestSpmmMul<BFloat16>();
}
template <typename IDX>
......@@ -234,5 +239,6 @@ void _TestSpmmDiv() {
TEST(SpmmTest, TestSpmmDiv) {
_TestSpmmDiv<float>();
_TestSpmmDiv<double>();
_TestSpmmDiv<BFloat16>();
}
#endif // _WIN32
......@@ -176,17 +176,19 @@ def test_spmm(idtype, g, shp, msg, reducer):
dgl.backend.backend_name != "pytorch",
reason="Only support PyTorch for now.",
)
-@unittest.skipIf(
-F._default_context_str == "cpu",
-reason="Don't support half precision on CPU.",
-)
@parametrize_idtype
@pytest.mark.parametrize(
"dtype, rtol, atol",
[(torch.float16, 1e-3, 0.5), (torch.bfloat16, 4e-3, 2.0)],
)
def test_half_spmm(idtype, dtype, rtol, atol):
-if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+if F._default_context_str == "cpu" and dtype == torch.float16:
+pytest.skip("float16 is not supported on CPU.")
+if (
+F._default_context_str == "gpu"
+and dtype == torch.bfloat16
+and not torch.cuda.is_bf16_supported()
+):
pytest.skip("BF16 is not supported.")
# make sure the spmm result is < 512 to match the rtol/atol we set.
......@@ -195,7 +197,7 @@ def test_half_spmm(idtype, dtype, rtol, atol):
idtype=idtype,
device=F.ctx(),
)
-feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(0)
+feat_fp32 = torch.rand((g.num_src_nodes(), 32)).to(F.ctx())
feat_half = feat_fp32.to(dtype)
# test SpMMCSR
......@@ -337,11 +339,8 @@ def test_segment_reduce(reducer):
],
)
def test_segment_mm(idtype, feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype in (
torch.float16,
torch.bfloat16,
):
pytest.skip("Only support float32 and float64 on CPU.")
if F._default_context_str == "cpu" and dtype == torch.float16:
pytest.skip("float16 is not supported on CPU.")
if (
F._default_context_str == "gpu"
and dtype == torch.bfloat16
......@@ -397,11 +396,8 @@ def test_segment_mm(idtype, feat_size, dtype, tol):
],
)
def test_gather_mm_idx_b(feat_size, dtype, tol):
if F._default_context_str == "cpu" and dtype in (
torch.float16,
torch.bfloat16,
):
pytest.skip("Only support float32 and float64 on CPU.")
if F._default_context_str == "cpu" and dtype == torch.float16:
pytest.skip("float16 is not supported on CPU.")
if (
F._default_context_str == "gpu"
and dtype == torch.bfloat16
......