Commit 61de5288 authored by Lei Wang, committed by GitHub

[Dev] Support FP8 Codegen for cuda backend (#64)

* [Enhancement] Add VectorizeLoop function and update imports for compatibility

* [CI][Test] Improve test cases for vectorization and fix typos in parser comments

* lint fix

* Fix incorrect module reference for VectorizeLoop transformation

* Refactor vectorize_loop transformation by removing unused extent mutation logic

* [Enhancement] Add support for FP8 data types and global barriers in CUDA codegen

* Fix formatting in CUDA FP8 header file for consistency

* Refactor CI workflow to use 'tilelang_ci' virtual environment and update CUDA type printing for better clarity

* Update submodule 'tvm' to latest commit for improved functionality

* Refactor execution backend references from 'dl_pack' to 'dlpack' for consistency and clarity; add apply_simplify function to simplify PrimFunc or IRModule.

* Refactor CUDA code for improved readability; clean up formatting and remove unnecessary whitespace in multiple files.

* Refactor import statement in test_tilelang_kernel_dequantize_gemm.py to use 'tilelang.language' for consistency

* Add CUDA requirements to FP8 test cases and update references for clarity

* Add a blank line for improved readability in test_tilelang_kernel_fp8_gemm_mma.py

* Fix data type in reference result calculation for consistency in test_tilelang_kernel_gemm_mma_intrinsic.py

* Add CUDA requirements and FP8 test cases for matmul and gemv simulations

* Remove debug print statements and use tilelang's testing assertion for result validation in test_tilelang_kernel_gemm_mma_intrinsic.py

* Remove outdated comment regarding FP8 tests in test_tilelang_kernel_gemv_simt.py
parent 7111239d
@@ -37,3 +37,8 @@ def simplify_prim_func(func: Callable) -> Callable:
         return _Simplify(stmt)
 
     return wrapper
+
+
+def apply_simplify(stmt: Union[PrimFunc, IRModule]) -> Union[PrimFunc, IRModule]:
+    """Apply the Simplify pass to a PrimFunc or IRModule."""
+    return _Simplify(stmt)
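The new `apply_simplify` exposes the same `_Simplify` pass that the `simplify_prim_func` decorator applies, but as a direct call on an already-built function or module. A minimal usage sketch follows; the import path and the example function are assumptions for illustration, not part of this commit:

```python
from tvm.script import tir as T

# Import path is assumed; adjust to wherever apply_simplify lives in tilelang.
from tilelang.transform.simplify import apply_simplify


@T.prim_func
def main(A: T.Buffer((16,), "int32")):
    for i in T.serial(16):
        A[i] = A[i] + 0  # trivially foldable; Simplify should drop the "+ 0"


simplified = apply_simplify(main)  # also accepts an IRModule
print(simplified.script())
```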
@@ -4,6 +4,8 @@
 from enum import Enum
 import torch
 from tvm.relay import TensorType
+from tvm.runtime import ndarray
+from torch.utils.dlpack import to_dlpack
 
 
 class TensorSupplyType(Enum):
@@ -15,10 +17,40 @@ class TensorSupplyType(Enum):
     One = 6
 
 
+def map_torch_type(intype):
+    typemap = {
+        'e4m3_float8': torch.float8_e4m3fn,
+        'e5m2_float8': torch.float8_e5m2,
+    }
+    if intype in typemap:
+        return typemap[intype]
+    else:
+        return getattr(torch, intype)
+
+
+float8_dtype_map = {
+    torch.float8_e4m3fn: "e4m3_float8",
+    torch.float8_e4m3fnuz: "e4m3_float8",
+    torch.float8_e5m2: "e5m2_float8",
+    torch.float8_e5m2fnuz: "e5m2_float8",
+}
+
+
+def adapt_torch2tvm(arg):
+    if isinstance(arg, torch.Tensor):
+        if arg.dtype in {
+                torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2,
+                torch.float8_e5m2fnuz
+        }:
+            return ndarray.from_dlpack(to_dlpack(arg.view(torch.int8)))._create_view(
+                shape=arg.shape, dtype=float8_dtype_map[arg.dtype])
+        return ndarray.from_dlpack(to_dlpack(arg))
+    return arg
+
+
 def get_tensor_supply(supply_type: TensorSupplyType):
 
     def get_tensor(tensor: TensorType) -> torch.Tensor:
-        dtype = torch.__getattribute__(str(tensor.dtype))
+        dtype = map_torch_type(str(tensor.dtype))
         device = torch.cuda.current_device()
         shape = list(map(int, tensor.shape))
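`adapt_torch2tvm` works around the fact that DLPack has no FP8 type code at this point: the tensor is viewed as `int8` for the DLPack handoff, then re-viewed on the TVM side via `_create_view` with the matching TVM dtype string from `float8_dtype_map`. A hedged round-trip sketch, assuming a CUDA build with torch >= 2.1 float8 dtypes (the printed values are expectations, not verified output of this commit):

```python
import torch

# map_torch_type resolves TVM dtype strings to torch dtypes, with special
# cases for the FP8 spellings that have no same-named torch attribute.
assert map_torch_type("e4m3_float8") is torch.float8_e4m3fn
assert map_torch_type("float16") is torch.float16

# Build an FP8 tensor, hand it to the adapter, and inspect the TVM view.
x = torch.randn(4, 4, device="cuda").to(torch.float8_e4m3fn)
nd = adapt_torch2tvm(x)

print(nd.dtype)   # expected: "e4m3_float8", TVM's spelling of the dtype
print(nd.shape)   # expected: (4, 4); the int8 detour does not change shape

# Non-tensor arguments pass through unchanged.
assert adapt_torch2tvm(3.14) == 3.14
```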
@@ -30,8 +62,12 @@ def get_tensor_supply(supply_type: TensorSupplyType):
         if supply_type == TensorSupplyType.Integer:
             is_unsigned = tensor.dtype.startswith("uint")
+            is_float8 = tensor.dtype.endswith("float8")
             if is_unsigned:
                 return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
+            elif is_float8:
+                return torch.randint(
+                    low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
             else:
                 return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype)
         elif supply_type == TensorSupplyType.Uniform:
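For the `Integer` supply, `torch.randint` cannot emit float8 directly (random kernels for the float8 types are generally unavailable), so the new branch samples `int8` and value-casts with `.to(dtype)`. Note this is a value conversion, not a bitwise reinterpretation: each sampled integer is rounded to the nearest representable FP8 number. A small sketch of what the branch produces, with illustrative shapes and expected prints:

```python
import torch

shape, device = (8, 8), "cuda"
raw = torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8)
fp8 = raw.to(torch.float8_e4m3fn)  # value cast; e.g. 127 rounds to 128

print(fp8.dtype)                # torch.float8_e4m3fn
print(fp8.float().abs().max())  # <= 128, well inside e4m3's +/-448 range
```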