Unverified Commit 7c61d31a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Legalize Datatype for mma intrinsic codegen (#1179)

* fix

* lint fix

* Enhance CUDA code generation by updating register type handling for float data types. Introduce a workaround for TF32 type compatibility and improve the registration of MMA register types for A and B operands.
parent d99853b6
......@@ -1749,6 +1749,9 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
"reinterpret_cast<const (ARegType)*>((A_ptr) + (A_offset)), "
"reinterpret_cast<const (BRegType)*>((B_ptr) + (B_offset)));\n";
tl::codegen::Replacer replacer;
// TODO(lei): Type workaround for TF32; should be removed once
// tfloat32_t is introduced in the frontend.
std::string AType = tl::codegen::ptx::DTypeEnumToString(dtype_a_enum);
if (AType == "tl::DataType::kFloat32") {
AType = "tl::DataType::kTensorFloat32";
......@@ -1757,11 +1760,17 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
if (BType == "tl::DataType::kFloat32") {
BType = "tl::DataType::kTensorFloat32";
}
std::string ARegType = tl::codegen::GetMMARegisterType(dtype_a_enum);
if (ARegType == "float") {
ARegType = "uint32_t";
}
std::string BRegType = tl::codegen::GetMMARegisterType(dtype_b_enum);
if (BRegType == "float") {
BRegType = "uint32_t";
}
replacer.register_rule("(AType)",
tl::codegen::ptx::DTypeEnumToString(AType));
replacer.register_rule("(BType)",
tl::codegen::ptx::DTypeEnumToString(BType));
replacer.register_rule("(AType)", AType);
replacer.register_rule("(BType)", BType);
replacer.register_rule("(CType)",
tl::codegen::ptx::DTypeEnumToString(dtype_c_enum));
replacer.register_rule("(M)", std::to_string(m));
......@@ -1769,10 +1778,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
replacer.register_rule("(K)", std::to_string(k));
replacer.register_rule("(TransA)", A_layout == "row" ? "false" : "true");
replacer.register_rule("(TransB)", B_layout == "row" ? "false" : "true");
replacer.register_rule("(ARegType)",
tl::codegen::GetMMARegisterType(dtype_a_enum));
replacer.register_rule("(BRegType)",
tl::codegen::GetMMARegisterType(dtype_b_enum));
replacer.register_rule("(ARegType)", ARegType);
replacer.register_rule("(BRegType)", BRegType);
replacer.register_rule("(CRegType)",
tl::codegen::GetMMARegisterType(dtype_c_enum));
replacer.register_rule("(A_ptr)", a_ref);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment