Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and DeepSeek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
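The canonical-dtype mapping described above might look like this minimal pure-Python sketch. The names `_CANONICAL_DISPLAY` and `canonicalize_dtype` are illustrative, not tilelang's actual API; the real implementation lives in `tilelang.language.v2` and may differ.

```python
# Hypothetical sketch of a canonical-dtype display mapping.
# Aliases users commonly write, mapped to their canonical display strings.
_CANONICAL_DISPLAY = {
    "float": "float32",              # bare "float" canonicalizes to float32
    "int": "int32",                  # bare "int" canonicalizes to int32
    "float8_e4m3": "float8_e4m3fn",  # legacy float8 spelling
}

def canonicalize_dtype(name: str) -> str:
    """Return the canonical display string for a dtype name."""
    return _CANONICAL_DISPLAY.get(name, name)
```

With such a mapping, dtype creation can accept the looser string inputs while storing and displaying one canonical form.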

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.
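The argument-parsing change can be sketched with standard `argparse`; the flag name and choices below are illustrative of the pattern (strings at the CLI boundary rather than `T.*` objects), not the benchmarks' exact interface.

```python
import argparse

# Dtypes are parsed as plain strings; conversion to the internal dtype
# objects happens later, inside the benchmark code.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dtype",
    type=str,
    default="float16",
    choices=["float16", "int8", "float"],  # string representations
)
args = parser.parse_args(["--dtype", "int8"])
```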

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
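The corrected call pattern can be illustrated with a stand-in dtype class; `DType` here is hypothetical and only mimics the shape of the real objects, which expose an `as_torch()` method.

```python
class DType:
    """Hypothetical stand-in for tilelang's dtype objects (real class differs)."""

    def __init__(self, name: str):
        self.name = name

    def __str__(self) -> str:
        return self.name

    def as_torch(self):
        """Resolve the matching torch dtype lazily, e.g. torch.float16."""
        import torch  # deferred so this sketch parses without torch installed
        return getattr(torch, self.name)

float16 = DType("float16")
# Call sites change from getattr(torch, str(dtype)) to dtype.as_torch():
# torch_dtype = float16.as_torch()
```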

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
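The rename can be sketched as a small normalizer over the float8 variant names that appear in this commit; `FLOAT8_VARIANTS` and `normalize_float8` are illustrative helpers, not tilelang's actual API.

```python
# The explicit float8 variant names used across the updated scripts.
FLOAT8_VARIANTS = {
    "float8_e4m3fn",
    "float8_e4m3fnuz",
    "float8_e5m2",
    "float8_e5m2fnuz",
}

def normalize_float8(name: str) -> str:
    """Map the legacy 'float8_e4m3' spelling to its explicit 'fn' variant."""
    if name == "float8_e4m3":
        return "float8_e4m3fn"
    assert name in FLOAT8_VARIANTS, f"unsupported float8 format: {name}"
    return name
```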

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
......@@ -39,8 +39,8 @@ def get_configs():
)
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
-dtype = "float16"
-accum_dtype = "float"
+dtype = T.float16
+accum_dtype = T.float32
@T.prim_func
def main(
......
......@@ -3,19 +3,20 @@ from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch
from tvm import te
+from tilelang.language import dtypes as T
def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch()
def gemm(M, N, K):
-A = te.placeholder((M, K), name="A", dtype="float16")
-B = te.placeholder((N, K), name="B", dtype="float16")
+A = te.placeholder((M, K), name="A", dtype=T.float16)
+B = te.placeholder((N, K), name="B", dtype=T.float16)
# Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name="k")
-C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
+C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C")
return A, B, C
......@@ -55,13 +56,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch()
def gemm(M, N, K):
-A = te.placeholder((M, K), name="A", dtype="float16")
-B = te.placeholder((N, K), name="B", dtype="float16")
+A = te.placeholder((M, K), name="A", dtype=T.float16)
+B = te.placeholder((N, K), name="B", dtype=T.float16)
# Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name="k")
-C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C")
+C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C")
return A, B, C
......
import tilelang.testing
from tilelang import carver
+from tilelang.language import dtypes as T
from tilelang.carver.arch import auto_infer_current_arch
from typing import List
-def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20):
+def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate(
structure=structure,
......@@ -20,12 +21,12 @@ def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[in
def test_general_reduction_recommend_hints():
-run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], "float16")
-run_general_reduction_recommend_hints("SS", [1024, 1024], "float16")
-run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16")
+run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], T.float16)
+run_general_reduction_recommend_hints("SS", [1024, 1024], T.float16)
+run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], T.float16)
-def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20):
+def run_elementwise_recommend_hints(shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20):
arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate(
shape=shape,
......@@ -40,18 +41,18 @@ def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float
def test_elementwise_recommend_hints():
-run_elementwise_recommend_hints([1024, 1024], "float16")
-run_elementwise_recommend_hints([1024], "float16")
-run_elementwise_recommend_hints([1024, 1024, 1024], "float16")
+run_elementwise_recommend_hints([1024, 1024], T.float16)
+run_elementwise_recommend_hints([1024], T.float16)
+run_elementwise_recommend_hints([1024, 1024, 1024], T.float16)
def run_matmul_recommend_hints(
M: int = 1024,
N: int = 1024,
K: int = 1024,
-in_dtype: str = "float16",
-out_dtype: str = "float16",
-accum_dtype: str = "float16",
+in_dtype: T.dtype = T.float16,
+out_dtype: T.dtype = T.float16,
+accum_dtype: T.dtype = T.float16,
):
arch = auto_infer_current_arch()
carve_template = carver.MatmulTemplate(
......@@ -71,13 +72,13 @@ def run_matmul_recommend_hints(
def test_matmul_recommend_hints():
-run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float16", "float16")
-run_matmul_recommend_hints(1024, 1024, 1024, "int8", "int32", "int32")
-run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16")
+run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float16, T.float16)
+run_matmul_recommend_hints(1024, 1024, 1024, T.int8, T.int32, T.int32)
+run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float32, T.float16)
def run_gemv_recommend_hints(
-N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16"
+N: int = 1024, K: int = 1024, in_dtype: T.dtype = T.float16, out_dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float16
):
arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate(
......@@ -96,9 +97,9 @@ def run_gemv_recommend_hints(
def test_gemv_recommend_hints():
-run_gemv_recommend_hints(1024, 1024, "float16", "float16", "float16")
-run_gemv_recommend_hints(1024, 1024, "int8", "int32", "int32")
-run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16")
+run_gemv_recommend_hints(1024, 1024, T.float16, T.float16, T.float16)
+run_gemv_recommend_hints(1024, 1024, T.int8, T.int32, T.int32)
+run_gemv_recommend_hints(1024, 1024, T.float16, T.float32, T.float16)
def run_fmha_recommend_hints(
......@@ -107,9 +108,9 @@ def run_fmha_recommend_hints(
seq_length: int = 512,
seq_kv_length: int = 512,
head_dim: int = 128,
-in_dtype: str = "float16",
-accum_dtype: str = "float16",
-out_dtype: str = "float16",
+in_dtype: T.dtype = T.float16,
+accum_dtype: T.dtype = T.float16,
+out_dtype: T.dtype = T.float16,
):
arch = auto_infer_current_arch()
carve_template = carver.FlashAttentionTemplate(
......@@ -133,8 +134,8 @@ def run_fmha_recommend_hints(
def test_fmha_recommend_hints():
-run_fmha_recommend_hints(4, 32, 512, 512, 128, "float16", "float16", "float16")
-run_fmha_recommend_hints(4, 32, 512, 512, 128, "int8", "int32", "int32")
+run_fmha_recommend_hints(4, 32, 512, 512, 128, T.float16, T.float16, T.float16)
+run_fmha_recommend_hints(4, 32, 512, 512, 128, T.int8, T.int32, T.int32)
if __name__ == "__main__":
......
......@@ -8,12 +8,12 @@ def _compile_kernel_without_inplace():
num_tokens = T.symbolic("num_tokens")
@T.prim_func
-def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
+def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]):
with T.Kernel(num_tokens, threads=32) as pid:
-read = T.alloc_var("int")
+read = T.alloc_var(T.int)
read = x[pid]
-write = T.alloc_var("int")
+write = T.alloc_var(T.int)
write = read * 2
x[pid] = write
......@@ -29,12 +29,12 @@ def _compile_kernel_with_inplace():
num_tokens = T.symbolic("num_tokens")
@T.prim_func
-def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]):
+def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]):
with T.Kernel(num_tokens, threads=32) as pid:
-read = T.alloc_var("int")
+read = T.alloc_var(T.int)
read = x[pid]
-write = T.alloc_var("int")
+write = T.alloc_var(T.int)
write = read * 2
x[pid] = write
......
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
+import torch
def matmul(
......@@ -22,8 +23,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
......@@ -93,8 +92,6 @@ def run_gemm(
profiler = kernel.get_profiler()
def ref_program(A, B):
-import torch
if trans_A:
A = A.T
if trans_B:
......@@ -114,9 +111,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
128,
256,
32,
......@@ -129,9 +126,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
128,
256,
32,
......
......@@ -5,7 +5,7 @@ import tilelang.language as T
import torch
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
num_stages = 0
@T.prim_func
......@@ -61,7 +61,7 @@ def test_matmul_codegen():
def test_matmul_compile():
-def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
# a simple kernel just for jit test
@T.prim_func
def matmul(
......@@ -103,7 +103,7 @@ def test_matmul_compile():
with tvm.target.Target("c"):
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes")
-in_dtype = "float16"
+in_dtype = T.float16
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype))
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype))
......
......@@ -5,7 +5,7 @@ import tilelang.testing
import tilelang.language as T
-def debug_print_buffer(M=16, N=16, dtype="float16"):
+def debug_print_buffer(M=16, N=16, dtype=T.float16):
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
......@@ -18,28 +18,28 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def test_debug_print_buffer():
-debug_print_buffer(dtype="bool")
-debug_print_buffer(dtype="int8")
-debug_print_buffer(dtype="int16")
-debug_print_buffer(dtype="int32")
-debug_print_buffer(dtype="int64")
-debug_print_buffer(dtype="uint8")
-debug_print_buffer(dtype="uint16")
-debug_print_buffer(dtype="uint32")
-debug_print_buffer(dtype="uint64")
-debug_print_buffer(dtype="float16")
-debug_print_buffer(dtype="float32")
-debug_print_buffer(dtype="float64")
-debug_print_buffer(dtype="bfloat16")
-debug_print_buffer(dtype="float8_e4m3")
-debug_print_buffer(dtype="float8_e4m3fn")
-debug_print_buffer(dtype="float8_e4m3fnuz")
-debug_print_buffer(dtype="float8_e5m2")
-debug_print_buffer(dtype="float8_e5m2fnuz")
+debug_print_buffer(dtype=T.bool)
+debug_print_buffer(dtype=T.int8)
+debug_print_buffer(dtype=T.int16)
+debug_print_buffer(dtype=T.int32)
+debug_print_buffer(dtype=T.int64)
+debug_print_buffer(dtype=T.uint8)
+debug_print_buffer(dtype=T.uint16)
+debug_print_buffer(dtype=T.uint32)
+debug_print_buffer(dtype=T.uint64)
+debug_print_buffer(dtype=T.float16)
+debug_print_buffer(dtype=T.float32)
+debug_print_buffer(dtype=T.float64)
+debug_print_buffer(dtype=T.bfloat16)
+debug_print_buffer(dtype=T.float8_e4m3fn)
+debug_print_buffer(dtype=T.float8_e4m3fn)
+debug_print_buffer(dtype=T.float8_e4m3fnuz)
+debug_print_buffer(dtype=T.float8_e5m2)
+debug_print_buffer(dtype=T.float8_e5m2fnuz)
def debug_print_buffer_conditional(M=16, N=16):
-dtype = "float16"
+dtype = T.float16
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
......@@ -59,7 +59,7 @@ def test_debug_print_buffer_conditional():
def debug_print_value_conditional(M=16, N=16):
-dtype = "float16"
+dtype = T.float16
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
......@@ -78,7 +78,7 @@ def test_debug_print_value_conditional():
def debug_print_register_files(M=16, N=16):
-dtype = "float16"
+dtype = T.float16
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
......@@ -97,7 +97,7 @@ def test_debug_print_register_files():
def debug_print_msg(M=16, N=16):
-dtype = "float16"
+dtype = T.float16
@T.prim_func
def program(Q: T.Tensor((M, N), dtype)):
......
......@@ -33,18 +33,18 @@ def tl_matmul_macro(
accum_dtype,
):
assert in_dtype in [
-"float16",
-"int8",
+T.float16,
+T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
-"float16",
-"float32",
-"int32",
+T.float16,
+T.float32,
+T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
-if out_dtype == "int32":
+if out_dtype == T.int32:
micro_size_k = 32
# This is a debug config
......@@ -52,7 +52,7 @@ def tl_matmul_macro(
block_col_warps = 1
warp_row_tiles = 16
warp_col_tiles = 16
-chunk = 32 if in_dtype == "float16" else 64
+chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......@@ -453,36 +453,36 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
def test_assert_tl_matmul_macro():
-assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16")
-assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16")
-assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16")
+assert_tl_matmul_macro_correctness(128, 128, 128, T.float16, T.float16, T.float16)
+assert_tl_matmul_macro_correctness(66, 128, 128, T.float16, T.float16, T.float16)
+assert_tl_matmul_macro_correctness(32, 128, 128, T.float16, T.float16, T.float16)
def test_assert_tl_matmul_block():
-assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
-assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
-assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
+assert_tl_matmul_block_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
+assert_tl_matmul_block_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
+assert_tl_matmul_block_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic():
-assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
-assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
-assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32)
+assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
+assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
+assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
-128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
+128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
-64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8
+64, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
)
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
-64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4
+64, 128, 60, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=4
)
# Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
-64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0
+64, 128, 64, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=0
)
......
......@@ -437,7 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
-assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32")
+assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, T.float16, T.float16, T.float32)
def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
......@@ -450,9 +450,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
assert_tl_matmul_block_dynamic_m(
......@@ -464,9 +464,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False},
)
......@@ -481,9 +481,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
)
assert_tl_matmul_block_dynamic_mn(
......@@ -495,9 +495,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False},
)
......@@ -512,9 +512,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4},
)
assert_tl_matmul_block_dynamic_mnk(
......@@ -526,9 +526,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
block_K,
False,
False,
-"float16",
-"float16",
-"float32",
+T.float16,
+T.float16,
+T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False},
)
......
......@@ -50,7 +50,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
-def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
+def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
......@@ -86,7 +86,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3
print(f"✓ {mathop_name} compilation and execution test passed")
-def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
+def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
......@@ -134,7 +134,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_non_fastmath_usage(source_fastmath, mathop_name)
# Test numerical correctness
-torch_dtype = getattr(torch, dtype)
+torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
......@@ -160,8 +160,8 @@ def run_abs_test():
@T.prim_func
def main(
-A: T.Tensor((M, N), "float32"),
-B: T.Tensor((M, N), "float32"),
+A: T.Tensor((M, N), T.float32),
+B: T.Tensor((M, N), T.float32),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
......@@ -189,7 +189,7 @@ def run_abs_test():
print("✓ abs numerical test passed")
-def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
+def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
......@@ -222,7 +222,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
-torch_dtype = getattr(torch, dtype)
+torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them
......@@ -272,7 +272,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath(name, func):
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
-run_single_arg_mathop_test(name, func, dtype="float32")
+run_single_arg_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed")
......@@ -286,7 +286,7 @@ def test_mathops_generate_no_fastmath(name, func):
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath(name, func):
"""Test all two-argument mathops"""
-run_two_arg_mathop_test(name, func, dtype="float32")
+run_two_arg_mathop_test(name, func, dtype=T.float32)
@tilelang.testing.requires_cuda
......@@ -311,7 +311,7 @@ def test_abs_maps_to_fabs():
@tilelang.testing.requires_cuda
def test_fastmath_versions(name, func):
"""Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
-run_fastmath_mathop_test(name, func, dtype="float32")
+run_fastmath_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed")
......
......@@ -14,9 +14,9 @@ def _cumsum_view_infer_layout(hidden):
num_tokens = T.dynamic("num_tokens")
@T.prim_func
-def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]):
+def buggy_kernel(x: T.Tensor[(num_tokens, hidden), T.float]):
with T.Kernel(num_tokens, threads=128) as pid:
-smem = T.alloc_shared((hidden,), dtype="float")
+smem = T.alloc_shared((hidden,), dtype=T.float32)
T.copy(x[pid, :], smem)
T.cumsum(T.view(smem, (1, hidden)), dim=1)
......
......@@ -33,7 +33,7 @@ def _fill_with_dynamic_region_kernel():
@T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _:
-a, b = T.alloc_var("int"), T.alloc_var("int")
+a, b = T.alloc_var(T.int), T.alloc_var(T.int)
T.fill(x[a:b], 0)
return buggy_kernel
......
......@@ -9,7 +9,7 @@ def test_int64_address():
S,
D,
pos_ty="int64",
-dtype="float32",
+dtype=T.float32,
):
@T.prim_func
def main(
......@@ -36,7 +36,7 @@ def test_int64_address():
pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64)
pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, "int64")
-kernel_int32 = set_cache_kernel(S, D, "int32")
+kernel_int32 = set_cache_kernel(S, D, T.int32)
kernel_int64(pos_int64, value, cache)
torch.testing.assert_close(cache, value)
kernel_int32(pos_int32, value, cache)
......
......@@ -9,7 +9,7 @@ def test_issue_1198():
[
32,
],
-"int32",
+T.int32,
),
):
pass
......
......@@ -4,10 +4,10 @@ import tilelang.testing
def _make_kernel(M, N):
-dtype = "bfloat16"
+dtype = T.bfloat16
@T.prim_func
-def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")):
+def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype)
......
......@@ -7,12 +7,12 @@ def test_issue_1237_dynamic_copy_extent_builds():
# The goal is to ensure T.copy correctly handles dynamic extents
# (e.g., src slice length vs. static dst buffer size) during prim_func building.
-length = T.symbolic("len", dtype="int32")
+length = T.symbolic("len", dtype=T.int32)
@T.prim_func
-def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]):  # noqa: F821
+def sample_kernel(global_tensor: T.Tensor[(length,), T.int32]):  # noqa: F821
with T.Kernel(1, threads=32):
-buffer_shared = T.alloc_shared((1024,), dtype="int32")
+buffer_shared = T.alloc_shared((1024,), dtype=T.int32)
T.copy(global_tensor[0:length], buffer_shared)
# Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
......
......@@ -5,7 +5,7 @@ import torch
@tilelang.jit
-def _tmp_var_kernel(N, block_N, dtype="float"):
+def _tmp_var_kernel(N, block_N, dtype=T.float32):
@T.prim_func
def kernel(
A: T.Tensor((N,), dtype),
......
......@@ -34,7 +34,7 @@ def _empty_with_dead_code_kernel():
num_tokens = T.dynamic("num_tokens")
@T.prim_func
-def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]):
+def buggy_kernel(x: T.Tensor[(num_tokens,), T.float32]):
with T.Kernel(num_tokens, threads=32) as pid:
y = x[pid]
......
......@@ -4,7 +4,7 @@ import tilelang.language as T
import torch
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
......@@ -8,10 +8,10 @@ import tilelang.language as T
def merge_if_test():
@T.prim_func
def main():
-A = T.alloc_fragment((1,), "float16")
-B = T.alloc_fragment((1,), "float16")
-C = T.alloc_fragment((1,), "float16")
-D = T.alloc_fragment((1,), "float16")
+A = T.alloc_fragment((1,), T.float16)
+B = T.alloc_fragment((1,), T.float16)
+C = T.alloc_fragment((1,), T.float16)
+D = T.alloc_fragment((1,), T.float16)
if A[0] == 0:
A[0] = 0
if B[0] == 0:
......