Unverified commit c70b2697, authored by Tong WU, committed by GitHub
Browse files

[BugFix] Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values (#1143)

* Implement bfloat16 support in CUDA code generation with min/max functions and inf/nan values

* refactor

* fix prev typo

* bugfix

* lint

* bugfix
parent bc773c56
...@@ -1017,6 +1017,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { ...@@ -1017,6 +1017,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< "))+1), __NV_SATFINITE, " << "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2") << (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n"; << ");\n";
os << sret;
return;
} }
} }
...@@ -1034,6 +1036,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) { ...@@ -1034,6 +1036,52 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
os << sret; os << sret;
} }
void CodeGenTileLangCUDA::VisitExpr_(const MinNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard min functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "min(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::VisitExpr_(const MaxNode *op, std::ostream &os) {
// TODO(wt): Consider vectorized reduction and impl for other dtypes
DataType t = op->dtype;
// Standard min/max functions don't support bfloat16 or float16
if ((t.is_bfloat16() || t.is_float16()) && t.is_scalar()) {
os << "cutlass::fast_max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b)
<< ")";
return;
}
// For float32 and float64 scalar, use standard max functions
if (t.is_float() && t.is_scalar()) {
if (t.bits() == 32 || t.bits() == 64) {
os << "max(" << PrintExpr(op->a) << ", " << PrintExpr(op->b) << ")";
return;
}
}
// For all other scalar types (int, uint), use default implementation
CodeGenC::VisitExpr_(op, os);
}
void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol, void CodeGenTileLangCUDA::PrintCallExtern(Type ret_type, String global_symbol,
const Array<PrimExpr> &args, const Array<PrimExpr> &args,
bool skip_first_arg, bool skip_first_arg,
...@@ -2540,12 +2588,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op, ...@@ -2540,12 +2588,29 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
inline void PrintConst(const FloatImmNode *op, std::ostream &os, inline void PrintConst(const FloatImmNode *op, std::ostream &os,
CodeGenTileLangCUDA *p) { // NOLINT(*) CodeGenTileLangCUDA *p) { // NOLINT(*)
// Type code is kBFloat // Type code is kBFloat/kFloat16
if (op->dtype.is_bfloat16()) { // which is indeed CUTLASS supported types currently
os << "bfloat16_t"; if (op->dtype.is_bfloat16() || op->dtype.is_float16()) {
os << '(' << std::hexfloat << op->value << 'f'; std::ostringstream temp;
os << "/*" << std::scientific << op->value << "*/"; if (std::isinf(op->value)) {
os << ')'; if (op->value < 0) {
temp << "-";
}
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::infinity()";
} else if (std::isnan(op->value)) {
temp << "std::numeric_limits<";
p->PrintType(op->dtype, temp);
temp << ">::quiet_NaN()";
} else {
p->PrintType(op->dtype, temp);
temp << '(' << std::hexfloat << op->value << 'f';
temp << "/*" << std::scientific << op->value << "*/";
temp << ')';
}
p->MarkConst(temp.str());
os << temp.str();
return; return;
} }
// Type code is kFloat8_e5m2 or kE4M4Float // Type code is kFloat8_e5m2 or kE4M4Float
...@@ -2556,7 +2621,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os, ...@@ -2556,7 +2621,7 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << ')'; os << ')';
return; return;
} }
// Type code is kFloat // Type code is kFloat64/kFloat32 (kFloat16 is handled above)
switch (op->dtype.bits()) { switch (op->dtype.bits()) {
case 64: case 64:
case 32: { case 32: {
...@@ -2580,13 +2645,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os, ...@@ -2580,13 +2645,6 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
os << temp.str(); os << temp.str();
break; break;
} }
case 16: {
os << "half_t" << '(';
FloatImm const_f32 = FloatImm(DataType::Float(32), op->value);
PrintConst(const_f32.get(), os, p);
os << ')';
break;
}
default: default:
LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n"; LOG(FATAL) << "Bad bit-width for float: " << op->dtype << "\n";
} }
......
...@@ -51,6 +51,8 @@ public: ...@@ -51,6 +51,8 @@ public:
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;
void VisitExpr_(const CallNode *op, std::ostream &os) final; void VisitExpr_(const CallNode *op, std::ostream &os) final;
void VisitExpr_(const CastNode *op, std::ostream &os) final; void VisitExpr_(const CastNode *op, std::ostream &os) final;
void VisitExpr_(const MinNode *op, std::ostream &os) final;
void VisitExpr_(const MaxNode *op, std::ostream &os) final;
void VisitStmt_(const EvaluateNode *op) final; void VisitStmt_(const EvaluateNode *op) final;
void VisitStmt_(const AllocateNode *op) final; void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final; void VisitStmt_(const AttrStmtNode *op) final;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment