Unverified Commit 92121fc6 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Generalize fp8 process (#1372)

* [Refactor] Update condition for benchmarking in example_gemv.py and simplify cached library path handling in sparse.py

* [Enhancement] Extend support for float8 data types in GEMM operations

- Updated GEMM operations to recognize additional float8 data types: `float8_e4m3fn` and `float8_e5m2fnuz`.
- Refactored condition checks in `checkWgmma` methods to simplify float8 type handling.
- Adjusted test cases to ensure compatibility with the new float8 types in tile language examples.

* lint fix
parent 1da3debf
...@@ -51,7 +51,12 @@ def tl_matmul( ...@@ -51,7 +51,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
]
if out_dtype == "int32" or is_float8: if out_dtype == "int32" or is_float8:
micro_size_k = 32 micro_size_k = 32
......
...@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) { ...@@ -57,7 +57,7 @@ static int to_CUtensorMapDataType(DataType dtype) {
} }
} else if (dtype.is_bfloat16()) { } else if (dtype.is_bfloat16()) {
tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
} else if (dtype.is_float8_e4m3() || dtype.is_float8_e5m2()) { } else if (dtype.is_float8()) {
tp = CU_TENSOR_MAP_DATA_TYPE_UINT8; tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
} else if (dtype.is_int()) { } else if (dtype.is_int()) {
switch (dtype.bits()) { switch (dtype.bits()) {
......
...@@ -361,13 +361,7 @@ bool GemmNode::checkWgmma() const { ...@@ -361,13 +361,7 @@ bool GemmNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) { if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0; return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -380,13 +374,7 @@ bool GemmNode::checkWgmma() const { ...@@ -380,13 +374,7 @@ bool GemmNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) && else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32)) b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0; return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
......
...@@ -182,13 +182,7 @@ bool GemmPyNode::checkWgmma() const { ...@@ -182,13 +182,7 @@ bool GemmPyNode::checkWgmma() const {
if (c_->dtype == DataType::Float(16)) { if (c_->dtype == DataType::Float(16)) {
if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16)) if (a_->dtype == DataType::Float(16) && b_->dtype == DataType::Float(16))
return k_ % 16 == 0; return k_ % 16 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
...@@ -201,13 +195,7 @@ bool GemmPyNode::checkWgmma() const { ...@@ -201,13 +195,7 @@ bool GemmPyNode::checkWgmma() const {
else if (a_->dtype == DataType::Float(32) && else if (a_->dtype == DataType::Float(32) &&
b_->dtype == DataType::Float(32)) b_->dtype == DataType::Float(32))
return (!transA_) && transB_ && k_ % 8 == 0; return (!transA_) && transB_ && k_ % 8 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e4m3()) else if (a_->dtype.is_float8() && b_->dtype.is_float8())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e4m3() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e4m3())
return (!transA_) && transB_ && k_ % 32 == 0;
else if (a_->dtype.is_float8_e5m2() && b_->dtype.is_float8_e5m2())
return (!transA_) && transB_ && k_ % 32 == 0; return (!transA_) && transB_ && k_ % 32 == 0;
else else
return false; return false;
......
...@@ -52,10 +52,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) { ...@@ -52,10 +52,8 @@ GetTCGEN5MMAMeta(int M, int N, int K, DataType ab_dtype, DataType c_dtype) {
} else { } else {
FAIL; FAIL;
} }
} else if ((ab_dtype.is_float8_e4m3fn() || ab_dtype.is_float8_e4m3() || } else if ((ab_dtype.is_float8() || ab_dtype.is_float6_e2m3fn() ||
ab_dtype.is_float8_e5m2() || ab_dtype.is_float8_e5m2fnuz() || ab_dtype.is_float6_e3m2fn() || ab_dtype.is_float4_e2m1fn()) &&
ab_dtype.is_float6_e2m3fn() || ab_dtype.is_float6_e3m2fn() ||
ab_dtype.is_float4_e2m1fn()) &&
((c_dtype.is_float() && c_dtype.bits() == 32) || ((c_dtype.is_float() && c_dtype.bits() == 32) ||
(c_dtype.is_float16() && c_dtype.bits() == 16))) { (c_dtype.is_float16() && c_dtype.bits() == 16))) {
if (K % 32 != 0) if (K % 32 != 0)
......
...@@ -52,7 +52,12 @@ def tl_matmul( ...@@ -52,7 +52,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
]
if out_dtype == "int32" or is_float8: if out_dtype == "int32" or is_float8:
micro_size_k = 32 micro_size_k = 32
......
...@@ -51,7 +51,12 @@ def tl_matmul( ...@@ -51,7 +51,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
]
if out_dtype == "int32" or is_float8: if out_dtype == "int32" or is_float8:
micro_size_k = 32 micro_size_k = 32
......
...@@ -52,7 +52,12 @@ def tl_matmul( ...@@ -52,7 +52,12 @@ def tl_matmul(
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in ["float8_e4m3", "float8_e5m2"] is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
]
if out_dtype == "int32" or is_float8: if out_dtype == "int32" or is_float8:
micro_size_k = 32 micro_size_k = 32
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment