Commit a686f0f1 authored by Yu Cheng, committed by LeiWang1999

[Enhancement] Update group_per_split_token_cast_to_fp8 to support multiple data types (#356)

- Modified the `group_per_split_token_cast_to_fp8` function to support the `bfloat16`, `float`, and `float16` input data types (a reference sketch of the group-wise cast appears below this list).
- Updated local fragment allocations to use the new `accum_dtype` for consistency.
- Enhanced the `__main__` block to construct the input tensor with the torch dtype matching the configured `dtype`, so the same script exercises every supported type.
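For orientation, here is a minimal PyTorch sketch of what a group-wise e4m3 cast computes, assuming the usual amax-based scaling implied by `fp8_max = 448.0` and the `X_amax` output below. The helper name and the `1e-4` amax floor are invented for illustration; this is not the repo's `ref_program` and it ignores the per-batch-group splitting:

```python
import torch

def group_cast_to_fp8_sketch(x: torch.Tensor, group_size: int = 128):
    # Quantize each contiguous group of `group_size` elements along the last
    # dim to e4m3, scaling by the group's absolute max so values land in
    # [-448, 448] (the representable range of e4m3).
    m, n = x.shape
    assert n % group_size == 0
    xg = x.float().view(m, n // group_size, group_size)          # compute in fp32 (accum_dtype)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4)   # the 1e-4 floor is an assumption
    x_fp8 = (xg * (448.0 / amax)).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return x_fp8.view(m, n), amax.squeeze(-1)                    # fp8 data + fp32 per-group amax
```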
parent c58cbfbb
@@ -4,9 +4,12 @@ import tilelang.language as T
 from typing import Tuple
 from tilelang.utils.tensor import torch_assert_close
+# support bfloat16, float, float16
+dtype = "bfloat16"
+accum_dtype = "float"
 def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
-    dtype = "float"
     group_size = 128
     fp8_min = -448.0
     fp8_max = 448.0
@@ -14,16 +17,16 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
     @T.prim_func
     def main(X: T.Tensor((M, N), dtype), batch_sizes: T.Tensor((BG,), "int32"), X_fp8: T.Tensor(
             (BG, M_max, N), "e4m3_float8"), X_amax: T.Tensor((BG, M_max, T.ceildiv(N, group_size)),
-            dtype)):
+            accum_dtype)):
         with T.Kernel(
                 T.ceildiv(M_max, blk_m), T.ceildiv(N, group_size), BG, threads=128) as (bx, by, bz):
             row = bx
             row_g_id = by
             bg = bz
-            y_local = T.alloc_fragment((blk_m, group_size), dtype)
-            y_amax_local = T.alloc_fragment((blk_m,), dtype)
-            y_s_local = T.alloc_fragment((blk_m,), dtype)
-            y_q_local = T.alloc_fragment((blk_m, group_size), dtype)
+            y_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
+            y_amax_local = T.alloc_fragment((blk_m,), accum_dtype)
+            y_s_local = T.alloc_fragment((blk_m,), accum_dtype)
+            y_q_local = T.alloc_fragment((blk_m, group_size), accum_dtype)
             y_q_local_fp8 = T.alloc_fragment((blk_m, group_size), "e4m3_float8")
             row_offset = T.alloc_local((1,), "int32")
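A brief note on the `accum_dtype` switch in this hunk, under the assumption that `y_amax_local`/`y_s_local` hold per-group amax and scale values: keeping those fragments in fp32 avoids rounding the scale itself, which would otherwise perturb every element it rescales. A tiny illustration:

```python
import torch

# Illustrative only: bfloat16 carries roughly 3 significant decimal digits
# (7 stored mantissa bits), so a per-group scale stored in bf16 is rounded
# before it is ever applied.
amax = torch.tensor(3.1415926)
scale_fp32 = amax / 448.0
scale_bf16 = scale_fp32.to(torch.bfloat16).float()
print(scale_fp32.item())  # full fp32 scale
print(scale_bf16.item())  # bf16 round-trip shifts the scale slightly
```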
@@ -159,7 +162,14 @@ def ref_program(x: torch.Tensor, batch_sizes: torch.Tensor) -> \
 if __name__ == "__main__":
     M, N, BG, blk_m = 8192, 8192, 2, 8
-    x = torch.randn(M, N, device="cuda", dtype=torch.float32)
+    if dtype == "float":
+        x = torch.randn(M, N, device="cuda", dtype=torch.float32)
+    elif dtype == "float16":
+        x = torch.randn(M, N, device="cuda", dtype=torch.float16)
+    elif dtype == "bfloat16":
+        x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
+    else:
+        raise ValueError(f"Unsupported dtype: {dtype}")
     batch_sizes = torch.tensor([2048, 6144], device="cuda", dtype=torch.int32)
     M_max = int(ceil_div(batch_sizes.max(), 128) * 128)
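As a side note, the string-to-torch-dtype dispatch added above could also be table-driven; a minimal sketch (the `_TORCH_DTYPES` mapping is invented here, not part of this commit):

```python
import torch

# Hypothetical tidier alternative to the if/elif chain above.
_TORCH_DTYPES = {"float": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}
if dtype not in _TORCH_DTYPES:
    raise ValueError(f"Unsupported dtype: {dtype}")
x = torch.randn(M, N, device="cuda", dtype=_TORCH_DTYPES[dtype])
```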