Unverified commit c70b2697, authored by Tong WU and committed by GitHub
Browse files

[BugFix] Implement bfloat16 support in CUDA code generation with min/max...

[BugFix] Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values (#1143)

* Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values

* refactor

* fix prev typo

* bugfix

* lint

* bugfix
parent bc773c56
......@@ -1017,6 +1017,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
}
}
......@@ -1034,6 +1036,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
os << sret;
}
void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard min functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard max functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args,
bool skip_first_arg,
......@@ -2540,12 +2588,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p) { // NOLINT(*)
// Type code is kBFloat
if (op->dtype.is_bfloat16()) {
os << "bfloat16_t";
os << '(' << std::hexfloat << op->value << 'f';
os << "/*" << std::scientific << op->value << "*/";
os << ')';
// Type code is kBFloat or kFloat16,
// both of which are types currently supported by CUTLASS
if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
std::ostringstream temp;
if (std::isinf(op->value)) {
if (op->value < 0) {
temp << "-";
}
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::infinity()";
} else if (std::isnan(op->value)) {
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::quiet_NaN()";
} else {
p->PrintType(op->dtype, temp);
temp << '(' << std::hexfloat << op->value << 'f';
temp << "/*" << std::scientific << op->value << "*/";
temp << ')';
}
p->MarkConst(temp.str());
os << temp.str();
return;
}
// Type code is kFloat8_e5m2 or kE4M3Float
......@@ -2556,7 +2621,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << ')';
return;
}
// Type code is kFloat
// Type code is kFloat64/kFloat32 (kFloat16 is handled above)
switch (op->dtype.bits()) {
case 64:
case 32: {
......@@ -2580,13 +2645,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << temp.str();
break;
}
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
}
......
......@@ -51,6 +51,8 @@ public:
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitExpr_(const MinNode *op, std::ostream &os) final;
void VisitExpr_(const MaxNode *op, std::ostream &os) final;
void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment