Commit 89725f7f authored by Lei Wang, committed by LeiWang1999

[Enhancement] Enhance FP8/FP4 type handling in CUDA codegen (#323)



* [Enhancement] Introduce CUDA driver module and refactor CUDA device handling

- Added a new `cuda_driver` module to encapsulate CUDA device properties and functionalities.
- Updated `CUDA` class in `cuda.py` to utilize the new driver for fetching device name and shared memory capabilities.
- Introduced `get_device_name` and `get_shared_memory_per_block` functions in the `cuda_driver` module for improved device property management (see the sketch after this list).
- This refactor enhances code organization and maintainability while improving the handling of CUDA device attributes.
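
The commit message names the new helpers but not their implementation, so the following is a minimal, hypothetical sketch of how a `cuda_driver`-style module could expose these two properties through the CUDA driver API via `ctypes`. Only the two function names come from the list above; the library name, attribute handling, and module layout are assumptions, not the actual tilelang code.

```python
# Hypothetical sketch of a cuda_driver-style helper module (not the actual
# tilelang implementation); it queries device properties via the CUDA driver API.
import ctypes

_libcuda = ctypes.CDLL("libcuda.so.1")  # assumed Linux driver library name
assert _libcuda.cuInit(0) == 0, "cuInit failed"

# CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK from cuda.h
_CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8


def get_device_name(device_id: int = 0) -> str:
    """Return the name of the CUDA device, e.g. 'NVIDIA A100-SXM4-80GB'."""
    device = ctypes.c_int()
    _libcuda.cuDeviceGet(ctypes.byref(device), device_id)
    name = ctypes.create_string_buffer(256)
    _libcuda.cuDeviceGetName(name, len(name), device)
    return name.value.decode()


def get_shared_memory_per_block(device_id: int = 0) -> int:
    """Return the shared memory available per block, in bytes."""
    device = ctypes.c_int()
    _libcuda.cuDeviceGet(ctypes.byref(device), device_id)
    value = ctypes.c_int()
    _libcuda.cuDeviceGetAttribute(
        ctypes.byref(value),
        _CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK,
        device,
    )
    return value.value
```

With helpers like these, the `CUDA` class in `cuda.py` can resolve the device name and shared-memory capacity at runtime instead of hard-coding them.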

* [Refactor] Clean up whitespace in CUDA-related files

- Removed unnecessary blank lines in `cuda.py`, `__init__.py`, and `cuda_driver.py` to improve code readability and maintainability.
- This change enhances the overall organization of the codebase without altering functionality.

* [Benchmark] Add FP8 Matrix Multiplication Benchmark Script

- Introduced a new benchmark script for FP8 matrix multiplication in `benchmark/matmul_fp8/benchmark_matmul.py`.
- The script includes functions for reference matrix multiplication, configuration generation for autotuning, and an autotuned kernel for performance measurement (a reference-check sketch follows this list).
- Added command-line argument parsing for matrix dimensions and the option to enable BitBLAS roller for search space exploration.
- The benchmark computes and prints the best latency and performance metrics, enhancing the benchmarking capabilities for FP8 operations.
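
Since FP8 tensors generally cannot be passed to `torch.matmul` directly, the reference path typically upcasts the operands first. The snippet below is a hedged sketch of what such a reference program could look like; the function name, operand layout, and output dtype are assumptions rather than the benchmark's actual code.

```python
import torch

def ref_program(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    # A: (M, K), B: (N, K), both in an FP8 dtype such as torch.float8_e4m3fn.
    # Upcast to half precision so the reference matmul is well supported,
    # then return a half-precision result to compare against the kernel output.
    return (A.to(torch.half) @ B.to(torch.half).T).to(torch.half)
```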

* lint fix

* Update submodule and enhance FP8 type handling in CUDA codegen

- Updated the TVM submodule to the latest commit.
- Modified FP8 type handling in `codegen_cuda.cc` to use more descriptive type codes.
- Improved constant printing for FP8 and bfloat16 types, ensuring correct representation in generated code.
- Added error handling for missing configuration keys in the AutoTuner class.

* lint fix

* Remove print statement from example script

* lint fix

* fix

---------
Co-authored-by: LeiWang1999 <wyatuestc@gmail.com>
parent 8fdfdf03
-Subproject commit edd35139a0481e9359aa269e3e50450b95ba2f5a
+Subproject commit 3aca64b65b4c1586031da076ffbaa65280f21dac
@@ -41,7 +41,6 @@ def test_gemm_fp8(M, N, K, dtype):
     func = matmul(M, N, K, 128, 128, 64, dtype)
     kernel = tilelang.compile(func, out_idx=-1)
     a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
     b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
@@ -39,9 +39,9 @@ static std::string GetFP8Type(DataType type) {
     LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                   "for FP8";
   }
-  if (type.code() == DataType::kE4M3Float) {
+  if (type.code() == DataType::kFloat8_e4m3fn) {
     stream << "fp8_e4" << vec << "_t";
-  } else if (type.code() == DataType::kE5M2Float) {
+  } else if (type.code() == DataType::kFloat8_e5m2) {
     stream << "fp8_e5" << vec << "_t";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -1475,6 +1475,12 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
   }
+  // Type code is kFloat8_e4m3fn, kFloat8_e5m2, or an FP4 variant
+  if (op->dtype.is_float8() || op->dtype.is_float4()) {
+    p->PrintType(op->dtype, os);
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
+  }
   // Type code is kFloat
   switch (op->dtype.bits()) {
     case 64:
@@ -1485,8 +1491,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
       temp << "-";
     }
     temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
+    p->need_math_constants_h_ = true;
   } else if (std::isnan(op->value)) {
     temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
+    p->need_math_constants_h_ = true;
   } else {
     temp << std::scientific << op->value;
     if (op->dtype.bits() == 32)
@@ -259,6 +259,8 @@ class AutoTuner:
                 if name not in keys:
                     new_args.append(value)
                 else:
+                    if name not in config:
+                        raise ValueError(f"Configuration {config} does not contain key {name}")
                     new_args.append(config[name])
             new_args = tuple(new_args)
             config_args.append(new_args)
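
For context, the added check sits inside the AutoTuner's argument-substitution loop. The standalone sketch below uses hypothetical function and variable names (only the raised message comes from the diff) and reproduces that logic: non-tunable arguments pass through unchanged, tunable ones are taken from the candidate config, and a config that omits a tunable key now fails with a descriptive `ValueError` instead of a bare `KeyError` later on.

```python
def bind_config(param_names, default_args, keys, config):
    """Substitute tunable parameters (listed in `keys`) with values from `config`."""
    new_args = []
    for name, value in zip(param_names, default_args):
        if name not in keys:
            new_args.append(value)  # non-tunable: keep the caller-supplied value
        else:
            if name not in config:
                raise ValueError(f"Configuration {config} does not contain key {name}")
            new_args.append(config[name])  # tunable: take it from the candidate config
    return tuple(new_args)


# Example: 'block_M' is tunable but missing from the candidate config -> ValueError.
# bind_config(["M", "block_M"], [1024, None], {"block_M"}, {"block_N": 64})
```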