Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
...@@ -39,8 +39,8 @@ def get_configs(): ...@@ -39,8 +39,8 @@ def get_configs():
) )
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False): def matmul(M, N, K, block_M=128, block_N=128, block_K=32, num_stages=0, thread_num=128, enable_rasterization=False):
dtype = "float16" dtype = T.float16
accum_dtype = "float" accum_dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
......
...@@ -3,19 +3,20 @@ from tilelang import carver ...@@ -3,19 +3,20 @@ from tilelang import carver
from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge from tilelang.carver.roller import PrimFuncNode, OutputNode, Edge
from tilelang.carver.arch import auto_infer_current_arch from tilelang.carver.arch import auto_infer_current_arch
from tvm import te from tvm import te
from tilelang.language import dtypes as T
def run_general_matmul_emit_configs(M, N, K, topk: int = 20): def run_general_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
def gemm(M, N, K): def gemm(M, N, K):
A = te.placeholder((M, K), name="A", dtype="float16") A = te.placeholder((M, K), name="A", dtype=T.float16)
B = te.placeholder((N, K), name="B", dtype="float16") B = te.placeholder((N, K), name="B", dtype=T.float16)
# Describe the matrix multiplication in TE # Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name="k") k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C")
return A, B, C return A, B, C
...@@ -55,13 +56,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20): ...@@ -55,13 +56,13 @@ def run_general_matmul_matmul_emit_configs(M, N, K, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
def gemm(M, N, K): def gemm(M, N, K):
A = te.placeholder((M, K), name="A", dtype="float16") A = te.placeholder((M, K), name="A", dtype=T.float16)
B = te.placeholder((N, K), name="B", dtype="float16") B = te.placeholder((N, K), name="B", dtype=T.float16)
# Describe the matrix multiplication in TE # Describe the matrix multiplication in TE
k = te.reduce_axis((0, K), name="k") k = te.reduce_axis((0, K), name="k")
C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype("float16") * B[j, k].astype("float16"), axis=[k]), name="C") C = te.compute((M, N), lambda i, j: te.sum(A[i, k].astype(T.float16) * B[j, k].astype(T.float16), axis=[k]), name="C")
return A, B, C return A, B, C
......
import tilelang.testing import tilelang.testing
from tilelang import carver from tilelang import carver
from tilelang.language import dtypes as T
from tilelang.carver.arch import auto_infer_current_arch from tilelang.carver.arch import auto_infer_current_arch
from typing import List from typing import List
def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: str = "float16", topk: int = 20): def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.GeneralReductionTemplate( carve_template = carver.GeneralReductionTemplate(
structure=structure, structure=structure,
...@@ -20,12 +21,12 @@ def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[in ...@@ -20,12 +21,12 @@ def run_general_reduction_recommend_hints(structure: str = "SSR", shape: List[in
def test_general_reduction_recommend_hints(): def test_general_reduction_recommend_hints():
run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], "float16") run_general_reduction_recommend_hints("SSR", [1024, 1024, 1024], T.float16)
run_general_reduction_recommend_hints("SS", [1024, 1024], "float16") run_general_reduction_recommend_hints("SS", [1024, 1024], T.float16)
run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], "float16") run_general_reduction_recommend_hints("SRS", [1024, 1024, 1024], T.float16)
def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float16", topk: int = 20): def run_elementwise_recommend_hints(shape: List[int] = None, dtype: T.dtype = T.float16, topk: int = 20):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.ElementwiseTemplate( carve_template = carver.ElementwiseTemplate(
shape=shape, shape=shape,
...@@ -40,18 +41,18 @@ def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float ...@@ -40,18 +41,18 @@ def run_elementwise_recommend_hints(shape: List[int] = None, dtype: str = "float
def test_elementwise_recommend_hints(): def test_elementwise_recommend_hints():
run_elementwise_recommend_hints([1024, 1024], "float16") run_elementwise_recommend_hints([1024, 1024], T.float16)
run_elementwise_recommend_hints([1024], "float16") run_elementwise_recommend_hints([1024], T.float16)
run_elementwise_recommend_hints([1024, 1024, 1024], "float16") run_elementwise_recommend_hints([1024, 1024, 1024], T.float16)
def run_matmul_recommend_hints( def run_matmul_recommend_hints(
M: int = 1024, M: int = 1024,
N: int = 1024, N: int = 1024,
K: int = 1024, K: int = 1024,
in_dtype: str = "float16", in_dtype: T.dtype = T.float16,
out_dtype: str = "float16", out_dtype: T.dtype = T.float16,
accum_dtype: str = "float16", accum_dtype: T.dtype = T.float16,
): ):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.MatmulTemplate( carve_template = carver.MatmulTemplate(
...@@ -71,13 +72,13 @@ def run_matmul_recommend_hints( ...@@ -71,13 +72,13 @@ def run_matmul_recommend_hints(
def test_matmul_recommend_hints(): def test_matmul_recommend_hints():
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float16", "float16") run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float16, T.float16)
run_matmul_recommend_hints(1024, 1024, 1024, "int8", "int32", "int32") run_matmul_recommend_hints(1024, 1024, 1024, T.int8, T.int32, T.int32)
run_matmul_recommend_hints(1024, 1024, 1024, "float16", "float32", "float16") run_matmul_recommend_hints(1024, 1024, 1024, T.float16, T.float32, T.float16)
def run_gemv_recommend_hints( def run_gemv_recommend_hints(
N: int = 1024, K: int = 1024, in_dtype: str = "float16", out_dtype: str = "float16", accum_dtype: str = "float16" N: int = 1024, K: int = 1024, in_dtype: T.dtype = T.float16, out_dtype: T.dtype = T.float16, accum_dtype: T.dtype = T.float16
): ):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.GEMVTemplate( carve_template = carver.GEMVTemplate(
...@@ -96,9 +97,9 @@ def run_gemv_recommend_hints( ...@@ -96,9 +97,9 @@ def run_gemv_recommend_hints(
def test_gemv_recommend_hints(): def test_gemv_recommend_hints():
run_gemv_recommend_hints(1024, 1024, "float16", "float16", "float16") run_gemv_recommend_hints(1024, 1024, T.float16, T.float16, T.float16)
run_gemv_recommend_hints(1024, 1024, "int8", "int32", "int32") run_gemv_recommend_hints(1024, 1024, T.int8, T.int32, T.int32)
run_gemv_recommend_hints(1024, 1024, "float16", "float32", "float16") run_gemv_recommend_hints(1024, 1024, T.float16, T.float32, T.float16)
def run_fmha_recommend_hints( def run_fmha_recommend_hints(
...@@ -107,9 +108,9 @@ def run_fmha_recommend_hints( ...@@ -107,9 +108,9 @@ def run_fmha_recommend_hints(
seq_length: int = 512, seq_length: int = 512,
seq_kv_length: int = 512, seq_kv_length: int = 512,
head_dim: int = 128, head_dim: int = 128,
in_dtype: str = "float16", in_dtype: T.dtype = T.float16,
accum_dtype: str = "float16", accum_dtype: T.dtype = T.float16,
out_dtype: str = "float16", out_dtype: T.dtype = T.float16,
): ):
arch = auto_infer_current_arch() arch = auto_infer_current_arch()
carve_template = carver.FlashAttentionTemplate( carve_template = carver.FlashAttentionTemplate(
...@@ -133,8 +134,8 @@ def run_fmha_recommend_hints( ...@@ -133,8 +134,8 @@ def run_fmha_recommend_hints(
def test_fmha_recommend_hints(): def test_fmha_recommend_hints():
run_fmha_recommend_hints(4, 32, 512, 512, 128, "float16", "float16", "float16") run_fmha_recommend_hints(4, 32, 512, 512, 128, T.float16, T.float16, T.float16)
run_fmha_recommend_hints(4, 32, 512, 512, 128, "int8", "int32", "int32") run_fmha_recommend_hints(4, 32, 512, 512, 128, T.int8, T.int32, T.int32)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,12 +8,12 @@ def _compile_kernel_without_inplace(): ...@@ -8,12 +8,12 @@ def _compile_kernel_without_inplace():
num_tokens = T.symbolic("num_tokens") num_tokens = T.symbolic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]):
with T.Kernel(num_tokens, threads=32) as pid: with T.Kernel(num_tokens, threads=32) as pid:
read = T.alloc_var("int") read = T.alloc_var(T.int)
read = x[pid] read = x[pid]
write = T.alloc_var("int") write = T.alloc_var(T.int)
write = read * 2 write = read * 2
x[pid] = write x[pid] = write
...@@ -29,12 +29,12 @@ def _compile_kernel_with_inplace(): ...@@ -29,12 +29,12 @@ def _compile_kernel_with_inplace():
num_tokens = T.symbolic("num_tokens") num_tokens = T.symbolic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float"]): def buggy_kernel(x: T.Tensor[(num_tokens,), T.float]):
with T.Kernel(num_tokens, threads=32) as pid: with T.Kernel(num_tokens, threads=32) as pid:
read = T.alloc_var("int") read = T.alloc_var(T.int)
read = x[pid] read = x[pid]
write = T.alloc_var("int") write = T.alloc_var(T.int)
write = read * 2 write = read * 2
x[pid] = write x[pid] = write
......
from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
from tilelang import language as T
import torch
def matmul( def matmul(
...@@ -22,8 +23,6 @@ def matmul( ...@@ -22,8 +23,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
...@@ -93,8 +92,6 @@ def run_gemm( ...@@ -93,8 +92,6 @@ def run_gemm(
profiler = kernel.get_profiler() profiler = kernel.get_profiler()
def ref_program(A, B): def ref_program(A, B):
import torch
if trans_A: if trans_A:
A = A.T A = A.T
if trans_B: if trans_B:
...@@ -114,9 +111,9 @@ def test_gemm_f16f16f16_nn(): ...@@ -114,9 +111,9 @@ def test_gemm_f16f16f16_nn():
768, 768,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float16", T.float16,
128, 128,
256, 256,
32, 32,
...@@ -129,9 +126,9 @@ def test_gemm_f16f16f16_nn(): ...@@ -129,9 +126,9 @@ def test_gemm_f16f16f16_nn():
768, 768,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float16", T.float16,
128, 128,
256, 256,
32, 32,
......
...@@ -5,7 +5,7 @@ import tilelang.language as T ...@@ -5,7 +5,7 @@ import tilelang.language as T
import torch import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
num_stages = 0 num_stages = 0
@T.prim_func @T.prim_func
...@@ -61,7 +61,7 @@ def test_matmul_codegen(): ...@@ -61,7 +61,7 @@ def test_matmul_codegen():
def test_matmul_compile(): def test_matmul_compile():
def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul_jit_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
# a simple kernel just for jit test # a simple kernel just for jit test
@T.prim_func @T.prim_func
def matmul( def matmul(
...@@ -103,7 +103,7 @@ def test_matmul_compile(): ...@@ -103,7 +103,7 @@ def test_matmul_compile():
with tvm.target.Target("c"): with tvm.target.Target("c"):
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes") complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes")
in_dtype = "float16" in_dtype = T.float16
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)) A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype))
B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)) B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype))
......
...@@ -5,7 +5,7 @@ import tilelang.testing ...@@ -5,7 +5,7 @@ import tilelang.testing
import tilelang.language as T import tilelang.language as T
def debug_print_buffer(M=16, N=16, dtype="float16"): def debug_print_buffer(M=16, N=16, dtype=T.float16):
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz): with T.Kernel(4, 4, 2, threads=128 * 2) as (bx, by, bz):
...@@ -18,28 +18,28 @@ def debug_print_buffer(M=16, N=16, dtype="float16"): ...@@ -18,28 +18,28 @@ def debug_print_buffer(M=16, N=16, dtype="float16"):
def test_debug_print_buffer(): def test_debug_print_buffer():
debug_print_buffer(dtype="bool") debug_print_buffer(dtype=T.bool)
debug_print_buffer(dtype="int8") debug_print_buffer(dtype=T.int8)
debug_print_buffer(dtype="int16") debug_print_buffer(dtype=T.int16)
debug_print_buffer(dtype="int32") debug_print_buffer(dtype=T.int32)
debug_print_buffer(dtype="int64") debug_print_buffer(dtype=T.int64)
debug_print_buffer(dtype="uint8") debug_print_buffer(dtype=T.uint8)
debug_print_buffer(dtype="uint16") debug_print_buffer(dtype=T.uint16)
debug_print_buffer(dtype="uint32") debug_print_buffer(dtype=T.uint32)
debug_print_buffer(dtype="uint64") debug_print_buffer(dtype=T.uint64)
debug_print_buffer(dtype="float16") debug_print_buffer(dtype=T.float16)
debug_print_buffer(dtype="float32") debug_print_buffer(dtype=T.float32)
debug_print_buffer(dtype="float64") debug_print_buffer(dtype=T.float64)
debug_print_buffer(dtype="bfloat16") debug_print_buffer(dtype=T.bfloat16)
debug_print_buffer(dtype="float8_e4m3") debug_print_buffer(dtype=T.float8_e4m3fn)
debug_print_buffer(dtype="float8_e4m3fn") debug_print_buffer(dtype=T.float8_e4m3fn)
debug_print_buffer(dtype="float8_e4m3fnuz") debug_print_buffer(dtype=T.float8_e4m3fnuz)
debug_print_buffer(dtype="float8_e5m2") debug_print_buffer(dtype=T.float8_e5m2)
debug_print_buffer(dtype="float8_e5m2fnuz") debug_print_buffer(dtype=T.float8_e5m2fnuz)
def debug_print_buffer_conditional(M=16, N=16): def debug_print_buffer_conditional(M=16, N=16):
dtype = "float16" dtype = T.float16
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
...@@ -59,7 +59,7 @@ def test_debug_print_buffer_conditional(): ...@@ -59,7 +59,7 @@ def test_debug_print_buffer_conditional():
def debug_print_value_conditional(M=16, N=16): def debug_print_value_conditional(M=16, N=16):
dtype = "float16" dtype = T.float16
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
...@@ -78,7 +78,7 @@ def test_debug_print_value_conditional(): ...@@ -78,7 +78,7 @@ def test_debug_print_value_conditional():
def debug_print_register_files(M=16, N=16): def debug_print_register_files(M=16, N=16):
dtype = "float16" dtype = T.float16
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
...@@ -97,7 +97,7 @@ def test_debug_print_register_files(): ...@@ -97,7 +97,7 @@ def test_debug_print_register_files():
def debug_print_msg(M=16, N=16): def debug_print_msg(M=16, N=16):
dtype = "float16" dtype = T.float16
@T.prim_func @T.prim_func
def program(Q: T.Tensor((M, N), dtype)): def program(Q: T.Tensor((M, N), dtype)):
......
...@@ -33,18 +33,18 @@ def tl_matmul_macro( ...@@ -33,18 +33,18 @@ def tl_matmul_macro(
accum_dtype, accum_dtype,
): ):
assert in_dtype in [ assert in_dtype in [
"float16", T.float16,
"int8", T.int8,
], "Currently only float16 and int8 are supported" ], "Currently only float16 and int8 are supported"
assert out_dtype in [ assert out_dtype in [
"float16", T.float16,
"float32", T.float32,
"int32", T.int32,
], "Currently only float16, float32 and int32 are supported" ], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16 micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32": if out_dtype == T.int32:
micro_size_k = 32 micro_size_k = 32
# This is a debug config # This is a debug config
...@@ -52,7 +52,7 @@ def tl_matmul_macro( ...@@ -52,7 +52,7 @@ def tl_matmul_macro(
block_col_warps = 1 block_col_warps = 1
warp_row_tiles = 16 warp_row_tiles = 16
warp_col_tiles = 16 warp_col_tiles = 16
chunk = 32 if in_dtype == "float16" else 64 chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn" shared_scope = "shared.dyn"
# Pipeline Stage # Pipeline Stage
...@@ -453,36 +453,36 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( ...@@ -453,36 +453,36 @@ def assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
def test_assert_tl_matmul_macro(): def test_assert_tl_matmul_macro():
assert_tl_matmul_macro_correctness(128, 128, 128, "float16", "float16", "float16") assert_tl_matmul_macro_correctness(128, 128, 128, T.float16, T.float16, T.float16)
assert_tl_matmul_macro_correctness(66, 128, 128, "float16", "float16", "float16") assert_tl_matmul_macro_correctness(66, 128, 128, T.float16, T.float16, T.float16)
assert_tl_matmul_macro_correctness(32, 128, 128, "float16", "float16", "float16") assert_tl_matmul_macro_correctness(32, 128, 128, T.float16, T.float16, T.float16)
def test_assert_tl_matmul_block(): def test_assert_tl_matmul_block():
assert_tl_matmul_block_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
assert_tl_matmul_block_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
assert_tl_matmul_block_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic(): def test_assert_tl_matmul_block_all_dynamic():
assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(67, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32) assert_tl_matmul_block_all_dynamic_correctness(36, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32)
def test_assert_tl_matmul_block_all_dynamic_with_pass_config(): def test_assert_tl_matmul_block_all_dynamic_with_pass_config():
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
128, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 128, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
) )
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 128, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=8 64, 128, 128, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=8
) )
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 60, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=4 64, 128, 60, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=4
) )
# Tail split is enabled with dynamic alignment 0 # Tail split is enabled with dynamic alignment 0
assert_tl_matmul_block_all_dynamic_correctness_with_pass_config( assert_tl_matmul_block_all_dynamic_correctness_with_pass_config(
64, 128, 64, False, False, "float16", "float16", "float16", 64, 64, 32, dynamic_alignment=0 64, 128, 64, False, False, T.float16, T.float16, T.float16, 64, 64, 32, dynamic_alignment=0
) )
......
...@@ -437,7 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk( ...@@ -437,7 +437,7 @@ def assert_tl_matmul_block_dynamic_mnk(
def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K):
assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, "float16", "float16", "float32") assert_tl_matmul_block_static(M, N, K, block_M, block_N, block_K, False, False, T.float16, T.float16, T.float32)
def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
...@@ -450,9 +450,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -450,9 +450,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
) )
assert_tl_matmul_block_dynamic_m( assert_tl_matmul_block_dynamic_m(
...@@ -464,9 +464,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K): ...@@ -464,9 +464,9 @@ def run_assert_tl_matmul_block_dynamic_m(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False}, pass_configs={"tl.disable_dynamic_tail_split": False},
) )
...@@ -481,9 +481,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -481,9 +481,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8}, pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8},
) )
assert_tl_matmul_block_dynamic_mn( assert_tl_matmul_block_dynamic_mn(
...@@ -495,9 +495,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K): ...@@ -495,9 +495,9 @@ def run_assert_tl_matmul_block_dynamic_mn(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False}, pass_configs={"tl.disable_dynamic_tail_split": False},
) )
...@@ -512,9 +512,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -512,9 +512,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4}, pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 4},
) )
assert_tl_matmul_block_dynamic_mnk( assert_tl_matmul_block_dynamic_mnk(
...@@ -526,9 +526,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K): ...@@ -526,9 +526,9 @@ def run_assert_tl_matmul_block_dynamic_mnk(M, N, K, block_M, block_N, block_K):
block_K, block_K,
False, False,
False, False,
"float16", T.float16,
"float16", T.float16,
"float32", T.float32,
pass_configs={"tl.disable_dynamic_tail_split": False}, pass_configs={"tl.disable_dynamic_tail_split": False},
) )
......
...@@ -50,7 +50,7 @@ def check_non_fastmath_usage(source, mathop_name): ...@@ -50,7 +50,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False) check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test single-argument mathops. Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...@@ -86,7 +86,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3 ...@@ -86,7 +86,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3
print(f"✓ {mathop_name} compilation and execution test passed") print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test two-argument mathops to ensure they generate non-fastmath CUDA code. Test two-argument mathops to ensure they generate non-fastmath CUDA code.
""" """
...@@ -134,7 +134,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, ...@@ -134,7 +134,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_non_fastmath_usage(source_fastmath, mathop_name) check_non_fastmath_usage(source_fastmath, mathop_name)
# Test numerical correctness # Test numerical correctness
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype) a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
b = torch.randn(M, N, device="cuda", dtype=torch_dtype) b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
...@@ -160,8 +160,8 @@ def run_abs_test(): ...@@ -160,8 +160,8 @@ def run_abs_test():
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), "float32"), A: T.Tensor((M, N), T.float32),
B: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), T.float32),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -189,7 +189,7 @@ def run_abs_test(): ...@@ -189,7 +189,7 @@ def run_abs_test():
print("✓ abs numerical test passed") print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
""" """
...@@ -222,7 +222,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, ...@@ -222,7 +222,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness # Test numerical correctness
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype) a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them # Ensure positive values for functions that need them
...@@ -272,7 +272,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, ...@@ -272,7 +272,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath(name, func): def test_mathops_generate_no_fastmath(name, func):
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)""" """Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
run_single_arg_mathop_test(name, func, dtype="float32") run_single_arg_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed") print(f"✓ {name} test passed")
...@@ -286,7 +286,7 @@ def test_mathops_generate_no_fastmath(name, func): ...@@ -286,7 +286,7 @@ def test_mathops_generate_no_fastmath(name, func):
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath(name, func): def test_two_arg_mathops_fastmath(name, func):
"""Test all two-argument mathops""" """Test all two-argument mathops"""
run_two_arg_mathop_test(name, func, dtype="float32") run_two_arg_mathop_test(name, func, dtype=T.float32)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
...@@ -311,7 +311,7 @@ def test_abs_maps_to_fabs(): ...@@ -311,7 +311,7 @@ def test_abs_maps_to_fabs():
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
def test_fastmath_versions(name, func): def test_fastmath_versions(name, func):
"""Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code""" """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
run_fastmath_mathop_test(name, func, dtype="float32") run_fastmath_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed") print(f"✓ {name} test passed")
......
...@@ -14,9 +14,9 @@ def _cumsum_view_infer_layout(hidden): ...@@ -14,9 +14,9 @@ def _cumsum_view_infer_layout(hidden):
num_tokens = T.dynamic("num_tokens") num_tokens = T.dynamic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens, hidden), "float"]): def buggy_kernel(x: T.Tensor[(num_tokens, hidden), T.float]):
with T.Kernel(num_tokens, threads=128) as pid: with T.Kernel(num_tokens, threads=128) as pid:
smem = T.alloc_shared((hidden,), dtype="float") smem = T.alloc_shared((hidden,), dtype=T.float32)
T.copy(x[pid, :], smem) T.copy(x[pid, :], smem)
T.cumsum(T.view(smem, (1, hidden)), dim=1) T.cumsum(T.view(smem, (1, hidden)), dim=1)
......
...@@ -33,7 +33,7 @@ def _fill_with_dynamic_region_kernel(): ...@@ -33,7 +33,7 @@ def _fill_with_dynamic_region_kernel():
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821 def buggy_kernel(x: T.Tensor[(num_tokens,), "int64"]): # noqa: F821
with T.Kernel(num_tokens, threads=128) as _: with T.Kernel(num_tokens, threads=128) as _:
a, b = T.alloc_var("int"), T.alloc_var("int") a, b = T.alloc_var(T.int), T.alloc_var(T.int)
T.fill(x[a:b], 0) T.fill(x[a:b], 0)
return buggy_kernel return buggy_kernel
......
...@@ -9,7 +9,7 @@ def test_int64_address(): ...@@ -9,7 +9,7 @@ def test_int64_address():
S, S,
D, D,
pos_ty="int64", pos_ty="int64",
dtype="float32", dtype=T.float32,
): ):
@T.prim_func @T.prim_func
def main( def main(
...@@ -36,7 +36,7 @@ def test_int64_address(): ...@@ -36,7 +36,7 @@ def test_int64_address():
pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64) pos_int64 = torch.arange(S, device="cuda", dtype=torch.int64)
pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32) pos_int32 = torch.arange(S, device="cuda", dtype=torch.int32)
kernel_int64 = set_cache_kernel(S, D, "int64") kernel_int64 = set_cache_kernel(S, D, "int64")
kernel_int32 = set_cache_kernel(S, D, "int32") kernel_int32 = set_cache_kernel(S, D, T.int32)
kernel_int64(pos_int64, value, cache) kernel_int64(pos_int64, value, cache)
torch.testing.assert_close(cache, value) torch.testing.assert_close(cache, value)
kernel_int32(pos_int32, value, cache) kernel_int32(pos_int32, value, cache)
......
...@@ -9,7 +9,7 @@ def test_issue_1198(): ...@@ -9,7 +9,7 @@ def test_issue_1198():
[ [
32, 32,
], ],
"int32", T.int32,
), ),
): ):
pass pass
......
...@@ -4,10 +4,10 @@ import tilelang.testing ...@@ -4,10 +4,10 @@ import tilelang.testing
def _make_kernel(M, N): def _make_kernel(M, N):
dtype = "bfloat16" dtype = T.bfloat16
@T.prim_func @T.prim_func
def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), "int32")): def fwd_main(KV: T.Tensor((M, N), dtype), ids: T.Tensor((4,), T.int32)):
with T.Kernel(4, threads=1): with T.Kernel(4, threads=1):
A = T.alloc_shared([N], dtype) A = T.alloc_shared([N], dtype)
B = T.alloc_shared([N], dtype) B = T.alloc_shared([N], dtype)
......
...@@ -7,12 +7,12 @@ def test_issue_1237_dynamic_copy_extent_builds(): ...@@ -7,12 +7,12 @@ def test_issue_1237_dynamic_copy_extent_builds():
# The goal is to ensure T.copy correctly handles dynamic extents # The goal is to ensure T.copy correctly handles dynamic extents
# (e.g., src slice length vs. static dst buffer size) during prim_func building. # (e.g., src slice length vs. static dst buffer size) during prim_func building.
length = T.symbolic("len", dtype="int32") length = T.symbolic("len", dtype=T.int32)
@T.prim_func @T.prim_func
def sample_kernel(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821 def sample_kernel(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821
with T.Kernel(1, threads=32): with T.Kernel(1, threads=32):
buffer_shared = T.alloc_shared((1024,), dtype="int32") buffer_shared = T.alloc_shared((1024,), dtype=T.int32)
T.copy(global_tensor[0:length], buffer_shared) T.copy(global_tensor[0:length], buffer_shared)
# Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute. # Building the prim_func is sufficient to exercise the bug path; no need to JIT/execute.
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
@tilelang.jit @tilelang.jit
def _tmp_var_kernel(N, block_N, dtype="float"): def _tmp_var_kernel(N, block_N, dtype=T.float32):
@T.prim_func @T.prim_func
def kernel( def kernel(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
......
...@@ -34,7 +34,7 @@ def _empty_with_dead_code_kernel(): ...@@ -34,7 +34,7 @@ def _empty_with_dead_code_kernel():
num_tokens = T.dynamic("num_tokens") num_tokens = T.dynamic("num_tokens")
@T.prim_func @T.prim_func
def buggy_kernel(x: T.Tensor[(num_tokens,), "float32"]): def buggy_kernel(x: T.Tensor[(num_tokens,), T.float32]):
with T.Kernel(num_tokens, threads=32) as pid: with T.Kernel(num_tokens, threads=32) as pid:
y = x[pid] y = x[pid]
......
...@@ -4,7 +4,7 @@ import tilelang.language as T ...@@ -4,7 +4,7 @@ import tilelang.language as T
import torch import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
......
...@@ -8,10 +8,10 @@ import tilelang.language as T ...@@ -8,10 +8,10 @@ import tilelang.language as T
def merge_if_test(): def merge_if_test():
@T.prim_func @T.prim_func
def main(): def main():
A = T.alloc_fragment((1,), "float16") A = T.alloc_fragment((1,), T.float16)
B = T.alloc_fragment((1,), "float16") B = T.alloc_fragment((1,), T.float16)
C = T.alloc_fragment((1,), "float16") C = T.alloc_fragment((1,), T.float16)
D = T.alloc_fragment((1,), "float16") D = T.alloc_fragment((1,), T.float16)
if A[0] == 0: if A[0] == 0:
A[0] = 0 A[0] = 0
if B[0] == 0: if B[0] == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment