Commit 89725f7f authored by Lei Wang, committed by LeiWang1999

[Enhancement] Enhance FP8/FP4 type handling in CUDA codegen (#323)



* [Enhancement] Introduce CUDA driver module and refactor CUDA device handling

- Added a new `cuda_driver` module to encapsulate CUDA device properties and functionalities.
- Updated the `CUDA` class in `cuda.py` to use the new driver for fetching the device name and per-block shared memory capacity.
- Introduced `get_device_name` and `get_shared_memory_per_block` functions in the `cuda_driver` module for improved device property management (a sketch of such a wrapper appears below).
- This refactor enhances code organization and maintainability while improving the handling of CUDA device attributes.
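
A minimal sketch of what such a driver wrapper can look like, assuming ctypes bindings against libcuda (names and details here are illustrative; the actual `cuda_driver` module may be implemented differently):

    # Illustrative sketch only -- assumes ctypes bindings to libcuda; the real
    # cuda_driver module may differ in naming and error handling.
    import ctypes

    _libcuda = ctypes.CDLL("libcuda.so")
    _libcuda.cuInit(0)

    # CUdevice_attribute enum value from cuda.h.
    CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK = 8

    def get_device_name(device_id: int = 0) -> str:
        """Return the name of the CUDA device at the given ordinal."""
        device = ctypes.c_int()
        _libcuda.cuDeviceGet(ctypes.byref(device), device_id)
        name = ctypes.create_string_buffer(256)
        _libcuda.cuDeviceGetName(name, len(name), device)
        return name.value.decode("utf-8")

    def get_shared_memory_per_block(device_id: int = 0) -> int:
        """Return the per-block shared memory limit, in bytes."""
        device = ctypes.c_int()
        _libcuda.cuDeviceGet(ctypes.byref(device), device_id)
        value = ctypes.c_int()
        _libcuda.cuDeviceGetAttribute(
            ctypes.byref(value), CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK, device)
        return value.value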

* [Refactor] Clean up whitespace in CUDA-related files

- Removed unnecessary blank lines in `cuda.py`, `__init__.py`, and `cuda_driver.py` to improve code readability and maintainability.
- This change enhances the overall organization of the codebase without altering functionality.

* [Benchmark] Add FP8 Matrix Multiplication Benchmark Script

- Introduced a new benchmark script for FP8 matrix multiplication in `benchmark/matmul_fp8/benchmark_matmul.py`.
- The script includes functions for reference matrix multiplication (sketched below), configuration generation for autotuning, and an autotuned kernel for performance measurement.
- Added command-line argument parsing for matrix dimensions and the option to enable BitBLAS roller for search space exploration.
- The benchmark computes and prints the best latency and performance metrics, enhancing the benchmarking capabilities for FP8 operations.
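
The reference check and throughput metric such a benchmark relies on can be sketched as follows (illustrative only; the argument parsing and autotuning setup live in `benchmark_matmul.py`):

    # Illustrative sketch only -- not the exact benchmark script.
    import torch

    def ref_program(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # torch.matmul does not accept float8 inputs directly, so the reference
        # upcasts to float32 and downcasts the result for comparison.
        return (A.to(torch.float32) @ B.to(torch.float32).T).to(torch.float16)

    M = N = K = 4096
    a = torch.randn(M, K, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(N, K, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
    c_ref = ref_program(a, b)

    def tflops(latency_ms: float) -> float:
        # A GEMM performs 2 * M * N * K floating-point operations.
        return 2 * M * N * K / (latency_ms * 1e-3) / 1e12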

* lint fix

* Update submodule and enhance FP8 type handling in CUDA codegen

- Updated the TVM submodule to the latest commit.
- Modified FP8 type handling in `codegen_cuda.cc` to use more descriptive type codes.
- Improved constant printing for FP8 and bfloat16 types, ensuring correct representation in generated code.
- Added error handling for missing configuration keys in the AutoTuner class.

* lint fix

* Remove print statement from example script

* lint fix

* fix

---------
Co-authored-by: LeiWang1999 <wyatuestc@gmail.com>
parent 8fdfdf03
-Subproject commit edd35139a0481e9359aa269e3e50450b95ba2f5a
+Subproject commit 3aca64b65b4c1586031da076ffbaa65280f21dac
@@ -41,7 +41,6 @@ def test_gemm_fp8(M, N, K, dtype):
     func = matmul(M, N, K, 128, 128, 64, dtype)
     kernel = tilelang.compile(func, out_idx=-1)
     a = torch.randn(M, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
     b = torch.randn(N, K, dtype=torch.float16, device='cuda').to(dtype=torch_dtype)
......
@@ -39,9 +39,9 @@ static std::string GetFP8Type(DataType type) {
     LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16) "
                   "for FP8";
   }
-  if (type.code() == DataType::kE4M3Float) {
+  if (type.code() == DataType::kFloat8_e4m3fn) {
     stream << "fp8_e4" << vec << "_t";
-  } else if (type.code() == DataType::kE5M2Float) {
+  } else if (type.code() == DataType::kFloat8_e5m2) {
     stream << "fp8_e5" << vec << "_t";
   } else {
     LOG(FATAL) << "Unsupported FP8 type in CUDA codegen";
@@ -1475,6 +1475,12 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
   }
+  // Type code is kFloat8_e5m2 or kFloat8_e4m3fn
+  if (op->dtype.is_float8() || op->dtype.is_float4()) {
+    p->PrintType(op->dtype, os);
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
+  }
   // Type code is kFloat
   switch (op->dtype.bits()) {
     case 64:
@@ -1485,8 +1491,10 @@ inline void PrintConst(const FloatImmNode *op, std::ostream &os,
           temp << "-";
         }
         temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
+        p->need_math_constants_h_ = true;
       } else if (std::isnan(op->value)) {
         temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
+        p->need_math_constants_h_ = true;
       } else {
         temp << std::scientific << op->value;
         if (op->dtype.bits() == 32)
......
@@ -259,6 +259,8 @@ class AutoTuner:
                 if name not in keys:
                     new_args.append(value)
                 else:
+                    if name not in config:
+                        raise ValueError(f"Configuration {config} does not contain key {name}")
                     new_args.append(config[name])
             new_args = tuple(new_args)
             config_args.append(new_args)
......