Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling and functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
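The motivation for a float32 accumulator can be seen in a small, self-contained experiment (pure Python, using the `struct` module's IEEE-754 half-precision `'e'` format code; this is an illustration, not tilelang code):

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip a Python float through IEEE-754 half precision.
    return struct.unpack('e', struct.pack('e', x))[0]

# Accumulate 4096 ones. In fp16 the running sum saturates at 2048, because
# representable values are 2 apart there and 2048 + 1 rounds back to 2048.
acc_fp16 = 0.0
for _ in range(4096):
    acc_fp16 = to_fp16(acc_fp16 + 1.0)

acc_fp32 = float(sum(1 for _ in range(4096)))  # a wider accumulator is exact

print(acc_fp16, acc_fp32)  # 2048.0 4096.0
```

This is exactly why GEMM-style kernels accumulate in a wider type than their inputs.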

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix lint/typing issues from the previous commits

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.
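Since the codebase now mixes string spellings ("float16") and dtype objects (T.float16), call sites that take a dtype generally need to accept both. A minimal sketch of that duality; the names here (`DType`, `normalize_dtype`) are hypothetical stand-ins, not tilelang's actual API:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class DType:
    """Hypothetical stand-in for a T.float16-style dtype object."""
    name: str

    def __str__(self) -> str:
        return self.name

float16 = DType("float16")

def normalize_dtype(dtype) -> str:
    # Accept both "float16" (string form) and DType("float16") (object form),
    # reducing either to the canonical string spelling.
    return dtype if isinstance(dtype, str) else dtype.name

print(normalize_dtype("float16"), normalize_dtype(float16))
```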

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
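The canonical-mapping idea can be sketched as follows; the table contents and the function name `make_dtype` are illustrative assumptions, not the actual dtypes module:

```python
# Map accepted alias strings to their canonical display strings.
_CANONICAL = {
    "float": "float32",
    "half": "float16",
    "int": "int32",
}
# Canonical names the module recognizes (assumed set, per the dtypes added above).
_KNOWN = {"float16", "bfloat16", "float32", "int4", "int8", "int16", "int32", "bool"}

def make_dtype(name: str) -> str:
    """Resolve a user-supplied string to its canonical dtype name."""
    canonical = _CANONICAL.get(name, name)
    if canonical not in _KNOWN:
        # Clear error naming the offending input, in the spirit of the
        # refined error messages described above.
        raise ValueError(f"invalid dtype string: {name!r}")
    return canonical
```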

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected a typo in the data type conversion call, from the invalid double-dot form dtype..as_torch() to dtype.as_torch(), across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
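A hypothetical sketch of what such an `as_torch()` accessor reduces to. To keep the example runnable without torch installed, it returns the torch dtype's dotted name as a string rather than the `torch.dtype` object itself:

```python
# Assumed name table mapping canonical dtype strings to torch dtype names.
_TORCH_NAMES = {
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float32": "float32",
    "int8": "int8",
    "int32": "int32",
}

class DType:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Correct spelling: dtype.as_torch() with a single dot. The fixed bug
        # was the invalid double-dot call dtype..as_torch(), a syntax error.
        return "torch." + _TORCH_NAMES[self.name]

print(DType("float16").as_torch())  # torch.float16
```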

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
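The rename amounts to a small alias normalization (a sketch; the `fn` suffix follows torch's naming, where e4m3fn encodes NaN but no infinities):

```python
# Normalize the legacy float8 spelling to the suffixed canonical name.
_FLOAT8_ALIASES = {
    "float8_e4m3": "float8_e4m3fn",  # old spelling used before this change
}

def canonical_float8(name: str) -> str:
    # Unknown or already-canonical names pass through unchanged.
    return _FLOAT8_ALIASES.get(name, name)

print(canonical_float8("float8_e4m3"))  # float8_e4m3fn
print(canonical_float8("float8_e5m2"))  # float8_e5m2
```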

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.
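A hedged sketch of the kind of predicate such type checks reduce to; the variant list is an assumption based on common float8 formats, not the exact codegen code:

```python
# Float8 variants a codegen type check may need to cover.
_FLOAT8_VARIANTS = frozenset({
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
})

def is_float8(dtype_name: str) -> bool:
    """True if the dtype string names any supported float8 variant."""
    return dtype_name in _FLOAT8_VARIANTS

print(is_float8("float8_e4m3fn"), is_float8("float16"))  # True False
```

Centralizing the check in one predicate keeps the conversion logic in the code generator consistent as new variants are added.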

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
Parent commit: 0c25c4f3
@@ -21,12 +21,12 @@ def tl_fused_chunk_bwd_kernel(
     H,
     DK,
     DV,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     scale: float = None,
 ) -> torch.Tensor:
     if scale is None:
         scale = DK**-0.5
-    accum_dtype = "float"
+    accum_dtype = T.float32
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
@@ -22,12 +22,12 @@ def tl_fused_chunk_fwd_kernel(
     H,
     DK,
     DV,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     scale: float = None,
 ) -> torch.Tensor:
     if scale is None:
         scale = DK**-0.5
-    accum_dtype = "float"
+    accum_dtype = T.float32
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
@@ -89,8 +89,8 @@ def chunk_scan_fwd(
     num_stages=2,
     threads=128,
 ):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     nchunks = T.ceildiv(seqlen, chunk_size)
     p = 1.44269504
@@ -55,8 +55,8 @@ def get_configs():
 def chunk_state_fwd(
     batch, seqlen, chunk_size, ngroups, nheads, headdim, dstate, block_M=64, block_N=64, block_K=64, num_stages=2, threads=128
 ):
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     nchunks = T.ceildiv(seqlen, chunk_size)
     p = 1.44269504
@@ -13,12 +13,12 @@ def chunk_retention_fwd_kernel(
     H,
     DK,
     DV,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
     scale: float = None,
 ) -> torch.Tensor:
     if scale is None:
         scale = DK**-0.5
-    accum_dtype = "float"
+    accum_dtype = T.float32
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
@@ -37,7 +37,7 @@ def chunk_retention_fwd_kernel(
     with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
         i_b = i_bh // H
         i_h = i_bh % H
-        log_decay = T.alloc_var("float32")
+        log_decay = T.alloc_var(T.float32)
         log_decay = T.log2(1 - T.exp2(-5.0 - 1.0 * i_h))  # Head-specific log decay
         q = T.alloc_shared([chunk_size, BK], dtype)
@@ -31,9 +31,9 @@ def _tl_vs_sparse_flashattn(batch, heads, seq_len, dim, vertical_size, slash_siz
     vertical_size_round, slash_size_round = tilelang.next_power_of_2(vertical_size), tilelang.next_power_of_2(slash_size)
-    dtype = "float16"
-    accum_dtype = "float"
-    int_dtype = "int32"
+    dtype = T.float16
+    accum_dtype = T.float32
+    int_dtype = T.int32
     def kernel_func(block_M, block_N, num_stages, threads):
         @T.macro
@@ -4,7 +4,7 @@ import tilelang.language as T
 def rms_norm_splitk(M, N, blk_m, blk_k):
-    dtype = "float"
+    dtype = T.float
     @T.prim_func
     def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
@@ -35,7 +35,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
 @tilelang.jit(out_idx=[-1], pass_configs={"tl.disable_tma_lower": True})
 def rms_norm(M, N, blk_m):
-    dtype = "float"
+    dtype = T.float
     @T.prim_func
     def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
@@ -5,7 +5,7 @@ import tilelang.language as T
 def rms_norm_splitk(M, N, blk_m, blk_k):
-    dtype = "float"
+    dtype = T.float
     @T.prim_func
     def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
@@ -35,7 +35,7 @@ def rms_norm_splitk(M, N, blk_m, blk_k):
 def rms_norm(M, N, blk_m):
-    dtype = "float"
+    dtype = T.float
     @T.prim_func
     def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
@@ -9,12 +9,12 @@ from typing import Callable
 def softmax_kernel(
     M,
     N,
-    dtype: str = "float16",
+    dtype: T.dtype = T.float16,
 ) -> "Callable":
     BN = min(tl.next_power_of_2(N), 8192)
     NN = tl.cdiv(N, BN)
-    accum_dtype = "float"
+    accum_dtype = T.float32
     scale = 1.44269504  # log2(e)
@@ -10,7 +10,7 @@ from typing import Literal, Callable
 from tilelang.intrinsics.utils import get_mma_micro_size
 from tilelang.tools import plot_layout
-def make_mma_load_base_layout(dtype: str = "float16",
+def make_mma_load_base_layout(dtype: str = T.float16,
                               matrix: Literal["A", "B"] = "A",
                               transposed: bool = False) -> T.Fragment:
     """
@@ -69,7 +69,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
     micro_size_s, _, micro_size_r = get_mma_micro_size(dtype)
     transform_func = transform_func
-    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
     def forward_thread(i: int, j: int) -> int:
         """
@@ -94,7 +94,7 @@ def make_mma_load_base_layout(dtype: str = "float16",
 # Create a 16×16 matrix layout for ldmatrix operations
-base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
+base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False)
 # Print the layout structure (optional for debugging)
 print(base_layout)
@@ -12,7 +12,7 @@ from tilelang.intrinsics.mfma_layout import (
 def make_mfma_load_base_layout(
-    dtype: str = "float16", matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False
+    dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", k_dim: int = 16, transposed: bool = False
 ) -> T.Fragment:
     """
     Create a layout function for storing MFMA results into a fragment buffer.
@@ -79,7 +79,7 @@ def make_mfma_load_base_layout(
     else:
         raise ValueError(f"Unsupported matrix {matrix}")
-    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
     def forward_thread(i: int, j: int) -> int:
         """
@@ -112,7 +112,7 @@ chunk = 2
 from tilelang.tools import plot_layout
 # ldmatrix layout 16x16
-base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
+base_layout = make_mfma_load_base_layout(dtype=T.float16, matrix="A", transposed=False)
 print(base_layout)
 plot_layout(base_layout, name="base_layout")
@@ -5,7 +5,7 @@ from tvm.tir import IndexMap
 from tilelang.intrinsics.utils import get_mma_micro_size
-def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:
+def make_mma_load_base_layout(dtype: T.dtype = T.float16, matrix: Literal["A", "B"] = "A", transposed: bool = False) -> T.Fragment:
     """
     Create a layout function for storing MMA results into a fragment buffer.
     This layout is used in conjunction with `inverse_mma_store_layout` to
@@ -74,7 +74,7 @@ def make_mma_load_base_layout(dtype: str = "float16", matrix: Literal["A", "B"]
     else:
         raise ValueError(f"Unsupported matrix {matrix}")
-    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
+    inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype=T.int32)
     def forward_thread(i: int, j: int) -> int:
         """
@@ -107,7 +107,7 @@ chunk = 2
 from tilelang.tools import plot_layout
 # ldmatrix layout 16x16
-base_layout = make_mma_load_base_layout(dtype="float16", matrix="A", transposed=False)
+base_layout = make_mma_load_base_layout(dtype=T.float16, matrix="A", transposed=False)
 print(base_layout)
 plot_layout(base_layout, name="base_layout")
@@ -6,7 +6,7 @@ import tilelang.language as T
 # target currently can be "cuda" or "hip" or "cpu".
 # if not specified, it will be inferred from the input tensors during compile time
 @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def matmul_relu_kernel(
         A: T.Tensor((M, K), dtype),
@@ -42,9 +42,9 @@ def blocksparse_flashattn(batch, heads, seq_q, seq_kv, dim, downsample_len, is_c
     kv_shape = [batch, heads, seq_kv, dim]
     block_mask_shape = [batch, heads, downsample_len, downsample_len]
-    dtype = "float16"
-    accum_dtype = "float"
-    block_mask_dtype = "int8"
+    dtype = T.float16
+    accum_dtype = T.float32
+    block_mask_dtype = T.int8
     def kernel_func(block_M, block_N, num_stages, threads):
         @T.macro
@@ -2,6 +2,7 @@ import torch
 import tilelang
 from tilelang.utils.sparse import compress_sm90
 from tilelang.layout import make_cutlass_metadata_layout
+from tilelang import language as T
 import tilelang.testing
@@ -24,8 +25,6 @@ def matmul_sp(
     A_shared_shape = (block_M, block_K // 2)
     B_shared_shape = (block_K, block_N)
-    import tilelang.language as T
     @T.prim_func
     def main(
         A_sparse: T.Tensor(A_sparse_shape, in_dtype),
@@ -40,8 +39,8 @@ def matmul_sp(
         C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
         T.annotate_layout(
             {
-                E: make_cutlass_metadata_layout(E, mma_dtype="float16", arch="9.0", block_k=block_K),
-                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype="float16", arch="9.0", block_k=block_K),
+                E: make_cutlass_metadata_layout(E, mma_dtype=T.float16, arch="9.0", block_k=block_K),
+                E_shared: make_cutlass_metadata_layout(E_shared, mma_dtype=T.float16, arch="9.0", block_k=block_K),
             }
         )
         T.clear(C_local)
@@ -111,7 +110,7 @@ def run_gemm_sp(
 def main():
-    run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128)
+    run_gemm_sp(512, 1024, 768, T.float16, T.float16, T.float32, 128, 128, 128, 2, 128)
 if __name__ == "__main__":
@@ -22,19 +22,19 @@ def tl_topk(
     blk_m,
     threads=128,
 ):
-    dtype = "float32"
+    dtype = T.float32
     @T.prim_func
     def topk_kernel(
         logits: T.Tensor([M, N], dtype),
         topk_gates: T.Tensor([M, topk], dtype),
-        topk_indices: T.Tensor([M, topk], "int32"),
+        topk_indices: T.Tensor([M, topk], T.int32),
     ):
         with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx:
             logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype)
             max_val = T.alloc_fragment([blk_m], dtype=dtype)
-            expand_max_idx = T.alloc_fragment([blk_m, N], "int32")
-            max_idx = T.alloc_fragment([blk_m], "int32")
+            expand_max_idx = T.alloc_fragment([blk_m, N], T.int32)
+            max_idx = T.alloc_fragment([blk_m], T.int32)
             T.copy(logits[bx * blk_m, 0], logits_frag)
@@ -10,7 +10,7 @@ import tilelang.language as T
         tilelang.PassConfigKey.TL_LAYOUT_VISUALIZATION_FORMATS: "svg",
     },
 )
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def gemm(
         A: T.Tensor((M, K), dtype),
@@ -10,8 +10,8 @@ import argparse
 @tilelang.jit(out_idx=[6])
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
     scale = (1.0 / (dim + pe_dim)) ** 0.5 * 1.44269504  # log2(e)
-    dtype = "float16"
-    accum_dtype = "float"
+    dtype = T.float16
+    accum_dtype = T.float32
     kv_group_num = heads // kv_head_num
     VALID_BLOCK_H = min(block_H, kv_group_num)
     assert kv_head_num == 1, "kv_head_num must be 1"
@@ -7,7 +7,7 @@ tilelang.disable_cache()
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     num_stages = 2
     mbarrier_list = [128, 128] * num_stages
@@ -5,7 +5,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
 @tilelang.jit(out_idx=[2])
-def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),