Unverified commit c750fb8a, authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
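The canonical-to-display mapping described above can be sketched as follows. This is a minimal illustration under assumed names (`_CANONICAL_TO_DISPLAY`, `_ALIASES`, `make_dtype` are all made up for demonstration), not the actual tilelang implementation:

```python
# Hypothetical sketch of a canonical-dtype display mapping; all names here
# are illustrative, not tilelang's real API.

# Canonical dtype keys mapped to the string shown to users.
_CANONICAL_TO_DISPLAY = {
    "float16": "float16",
    "float32": "float32",
    "bfloat16": "bfloat16",
    "int8": "int8",
    "int32": "int32",
}

# Common aliases normalized to canonical keys, so string inputs like
# "float" resolve intuitively.
_ALIASES = {"float": "float32", "half": "float16", "int": "int32"}


def make_dtype(name: str) -> str:
    """Resolve a user-supplied dtype string to its canonical display string."""
    key = _ALIASES.get(name, name)
    if key not in _CANONICAL_TO_DISPLAY:
        raise ValueError(f"invalid dtype string: {name!r}")
    return _CANONICAL_TO_DISPLAY[key]
```

With a table like this, the error message can name the exact offending string, which is the kind of clearer feedback the commit describes.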

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
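The fix above replaces an accidental double-dot call site (`dtype..as_torch()`) with `dtype.as_torch()`. A rough sketch of what such a conversion method might look like, using an illustrative `DType` class and a lookup table that are assumptions for demonstration (it returns torch dtype *names* as strings so the sketch runs without torch installed):

```python
# Illustrative DType wrapper with an as_torch() conversion method.
# The class and lookup table are assumptions, not tilelang's real API.
_TORCH_NAMES = {
    "float16": "torch.float16",
    "float32": "torch.float32",
    "bfloat16": "torch.bfloat16",
    "int8": "torch.int8",
    "int32": "torch.int32",
}


class DType:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Correct call form: dtype.as_torch() -- a single attribute access.
        return _TORCH_NAMES[self.name]
```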

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.
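As a rough illustration of the kind of check involved, the float8 variants can be grouped into one set so conversion logic tests membership once. The variant names below follow the commit description (`float8_e4m3fn`, `float8_e5m2`), and the helper is a hypothetical sketch rather than the actual codegen code:

```python
# Hedged sketch: group float8 variants so type-conversion checks in a code
# generator can test membership once instead of comparing each variant.
FLOAT8_VARIANTS = frozenset({"float8_e4m3fn", "float8_e5m2"})


def is_float8(dtype_name: str) -> bool:
    """True if the dtype string names a recognized float8 variant."""
    return dtype_name in FLOAT8_VARIANTS
```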

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
 import tilelang.testing
+from tilelang import language as T
 def alloc_var(
@@ -6,8 +7,6 @@ def alloc_var(
     block_N,
     dtype,
 ):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -38,7 +37,7 @@ def run_alloc_var(
 def test_alloc_var():
-    run_alloc_var(1024, 128, "float16")
+    run_alloc_var(1024, 128, T.float16)
 def alloc_var_add(
@@ -78,7 +77,7 @@ def run_alloc_var_add(
 def test_alloc_var_add():
-    run_alloc_var_add(1024, 128, "float16")
+    run_alloc_var_add(1024, 128, T.float16)
 def alloc_var_with_initializer(
@@ -117,7 +116,7 @@ def run_alloc_var_with_initializer(
 def test_alloc_var_with_initializer():
-    run_alloc_var_with_initializer(256, 64, "int32", 5)
+    run_alloc_var_with_initializer(256, 64, T.int32, 5)
 def alloc_multi_vars_with_initializer(
@@ -156,7 +155,7 @@ def run_alloc_multi_vars_with_initializer(
 def test_alloc_multi_vars_with_initializer():
-    run_alloc_multi_vars_with_initializer(256, 64, "int32")
+    run_alloc_multi_vars_with_initializer(256, 64, T.int32)
 if __name__ == "__main__":
...
@@ -6,7 +6,7 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
+def tilelang_copy(M, N, block_M, block_N, dtype=T.float16, pad_value=0):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -26,7 +26,7 @@ def tilelang_copy(M, N, block_M, block_N, dtype="float16", pad_value=0):
     return main
-def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16", pad_value=0):
+def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16, pad_value=0):
     program = tilelang_copy(M, N, block_M, block_N, dtype, pad_value=pad_value)
     kernel = tilelang.compile(
         program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
...
@@ -31,8 +31,8 @@ def blocksparse_matmul_global(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@@ -75,8 +75,8 @@ def blocksparse_matmul_shared(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@@ -124,8 +124,8 @@ def blocksparse_matmul_local(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
...
@@ -9,7 +9,7 @@ def test_assume_remove_boundary_check():
     N = T.dynamic("N")
     @T.prim_func
-    def main(A: T.Tensor((N,), "float32"), l: T.int32, r: T.int32):
+    def main(A: T.Tensor((N,), T.float32), l: T.int32, r: T.int32):
         with T.Kernel(1, threads=32) as _:
             for i in T.serial(r - l + 1):
                 T.assume(l + i >= 0 and l + i < N)
@@ -31,8 +31,8 @@ def test_assume_enable_vectorization():
     @T.prim_func
     def main(
-        A: T.Tensor((M, N), "float32"),
-        B: T.Tensor((M, N), "float32"),
+        A: T.Tensor((M, N), T.float32),
+        B: T.Tensor((M, N), T.float32),
     ):
         with T.Kernel(1, threads=32) as _:
             tid = T.get_thread_binding()
@@ -60,8 +60,8 @@ def test_assume_complex_indexing():
     @T.prim_func
     def main(
-        A: T.Tensor((M, N), "float32"),
-        B: T.Tensor((M, N), "float32"),
+        A: T.Tensor((M, N), T.float32),
+        B: T.Tensor((M, N), T.float32),
     ):
         with T.Kernel(1, threads=32) as _:
             tid = T.get_thread_binding()
...
@@ -3,7 +3,7 @@ import tilelang.language as T
 @tilelang.jit
-def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
+def atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
@@ -17,7 +17,7 @@ def atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
     return atomic_add
-def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
+def run_atomic_add(K, M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -36,7 +36,7 @@ def run_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
-def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
+def tile_atomic_add_program(K, M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
@@ -49,7 +49,7 @@ def tile_atomic_add_program(K, M, N, block_M, block_N, dtype="float"):
     return atomic_add
-def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
+def run_tile_atomic_add(K, M, N, block_M, block_N, dtype=T.float32):
     kernel = tile_atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
     print(kernel.get_kernel_source())
     import torch
@@ -71,7 +71,7 @@ def run_tile_atomic_add(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
-def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
+def atomic_max_program(K, M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_max(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
@@ -85,7 +85,7 @@ def atomic_max_program(K, M, N, block_M, block_N, dtype="float"):
     return atomic_max
-def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
+def run_atomic_max(K, M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_max_program(K, M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -104,7 +104,7 @@ def run_atomic_max(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
-def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
+def atomic_min_program(K, M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_min(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
@@ -118,7 +118,7 @@ def atomic_min_program(K, M, N, block_M, block_N, dtype="float"):
     return atomic_min
-def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
+def run_atomic_min(K, M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_min_program(K, M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -137,7 +137,7 @@ def run_atomic_min(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
-def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
+def atomic_load_store_program(M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_load_store(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
@@ -151,7 +151,7 @@ def atomic_load_store_program(M, N, block_M, block_N, dtype="float"):
     return atomic_load_store
-def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
+def run_atomic_load_store(M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_load_store_program(M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -162,7 +162,7 @@ def run_atomic_load_store(M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
-def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
+def atomic_memory_order_program(K, M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_with_memory_order(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
@@ -176,7 +176,7 @@ def atomic_memory_order_program(K, M, N, block_M, block_N, dtype="float"):
     return atomic_with_memory_order
-def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
+def run_atomic_memory_order(K, M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_memory_order_program(K, M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -197,7 +197,7 @@ def run_atomic_memory_order(K, M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_addx2_program(M, N, block_M, block_N):
     @T.prim_func
-    def atomic_addx2(A: T.Tensor((M, N), "float16"), B: T.Tensor((M, N), "float16")):
+    def atomic_addx2(A: T.Tensor((M, N), T.float16), B: T.Tensor((M, N), T.float16)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N // 2):
                 idx_i = bx * block_M + i
@@ -248,7 +248,7 @@ def test_atomic_addx2():
 @tilelang.jit
-def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"):
+def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_different_orders(
         A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), C: T.Tensor((M, N), dtype), D: T.Tensor((M, N), dtype)
@@ -266,7 +266,7 @@ def atomic_different_memory_orders_program(M, N, block_M, block_N, dtype="float"
     return atomic_different_orders
-def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
+def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_different_memory_orders_program(M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -285,7 +285,7 @@ def run_atomic_different_memory_orders(M, N, block_M, block_N, dtype="float32"):
 @tilelang.jit
 def atomic_addx4_program(M, N, block_M, block_N):
     @T.prim_func
-    def atomic_addx4(A: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), "float32")):
+    def atomic_addx4(A: T.Tensor((M, N), T.float32), B: T.Tensor((M, N), T.float32)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
             for i, j in T.Parallel(block_M, block_N // 4):
                 idx_i = bx * block_M + i
@@ -315,7 +315,7 @@ def run_atomic_addx4(M, N, block_M, block_N):
 @tilelang.jit
-def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
+def atomic_return_prev_program(M, N, block_M, block_N, dtype=T.float32):
     @T.prim_func
     def atomic_with_return_prev(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype), old_vals: T.Tensor((M, N), dtype)):
         with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=32) as (bx, by):
@@ -328,7 +328,7 @@ def atomic_return_prev_program(M, N, block_M, block_N, dtype="float"):
     return atomic_with_return_prev
-def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"):
+def run_atomic_return_prev(M, N, block_M, block_N, dtype=T.float32):
     kernel = atomic_return_prev_program(M, N, block_M, block_N, dtype=dtype)
     import torch
@@ -344,9 +344,9 @@ def run_atomic_return_prev(M, N, block_M, block_N, dtype="float32"):
 def test_atomic_different_memory_orders():
-    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float")
-    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="float16")
-    run_atomic_different_memory_orders(32, 32, 8, 8, dtype="bfloat16")
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float32)
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float16)
+    run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.bfloat16)
 def test_atomic_addx4():
...
@@ -6,7 +6,7 @@ import torch
 @tilelang.jit(out_idx=[-1])
 def _ceildiv_kernel(a: int, b: int):
     @T.prim_func
-    def ceildiv_kernel(A: T.Tensor((1,), "int32")):
+    def ceildiv_kernel(A: T.Tensor((1,), T.int32)):
         with T.Kernel(1, threads=1) as _:
             A[0] = T.ceildiv(T.int32(a), T.int32(b))
@@ -30,7 +30,7 @@ def test_ceildiv():
 @tilelang.jit
 def _ceildiv_kernel_dyn(b: int):
     @T.prim_func
-    def ceildiv_kernel(A: T.Tensor((1,), "int32"), a: T.int32):
+    def ceildiv_kernel(A: T.Tensor((1,), T.int32), a: T.int32):
         with T.Kernel(1, threads=1) as _:
             A[0] = T.ceildiv(T.int32(a), T.int32(b))
...
@@ -10,7 +10,7 @@ import torch
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     },
 )
-def chain_equal(N, block_size, dtype="float32"):
+def chain_equal(N, block_size, dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -25,7 +25,7 @@ def chain_equal(N, block_size, dtype="float32"):
     return main
-def run_chain_equal(N=128, block_size=64, dtype="float32"):
+def run_chain_equal(N=128, block_size=64, dtype=T.float32):
     kernel = chain_equal(N, block_size, dtype)
     A = torch.zeros((N,), dtype=torch.float32, device="cuda")
     B = torch.zeros((N,), dtype=torch.float32, device="cuda")
...
 import tilelang.testing
-from tilelang.utils.tensor import map_torch_type
+from tilelang import language as T
 def clamp_within_bounds(
@@ -91,7 +91,7 @@ def run_clamp_value_range(
     import torch
     # Convert string dtype to torch.dtype
-    torch_dtype = map_torch_type(dtype)
+    torch_dtype = dtype.as_torch()
     def ref_program(A):
         min_val = torch.min(A) * 0.5
@@ -107,10 +107,10 @@ def run_clamp_value_range(
 def test_clamp():
     # clamp tests for float16 and float32
-    run_clamp(1024, 128, "float16", -0.05, 0.05)
-    run_clamp(1024, 128, "float32", -0.06, 0.05)
-    run_clamp_value_range(1024, 128, "float16")
-    run_clamp_value_range(1024, 128, "float32")
+    run_clamp(1024, 128, T.float16, -0.05, 0.05)
+    run_clamp(1024, 128, T.float32, -0.06, 0.05)
+    run_clamp_value_range(1024, 128, T.float16)
+    run_clamp_value_range(1024, 128, T.float32)
 if __name__ == "__main__":
...
@@ -4,7 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
@@ -39,7 +39,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
     return main
-def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
     kernel = tilelang.compile(program, out_idx=[2], target="cuda", pass_configs={"tl.disable_tma_lower": True})
     import torch
...
@@ -6,7 +6,7 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):
+def tilelang_composable_copy(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -25,7 +25,7 @@ def tilelang_composable_copy(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
+def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
     program = tilelang_composable_copy(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program,
@@ -44,7 +44,7 @@ def run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128, dtype
 def test_tilelang_copy():
     run_tilelang_composable_copy(M=1024, N=1024, block_M=128, block_N=128)
     run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576)
-    run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float")
+    run_tilelang_composable_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32)
 if __name__ == "__main__":
...
...@@ -8,7 +8,7 @@ print(torch.__version__) ...@@ -8,7 +8,7 @@ print(torch.__version__)
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, src_dtype="float16", dst_dtype="float16"): def tilelang_copy(M, N, block_M, block_N, src_dtype=T.float16, dst_dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), src_dtype), A: T.Tensor((M, N), src_dtype),
...@@ -24,7 +24,7 @@ def tilelang_copy(M, N, block_M, block_N, src_dtype="float16", dst_dtype="float1 ...@@ -24,7 +24,7 @@ def tilelang_copy(M, N, block_M, block_N, src_dtype="float16", dst_dtype="float1
return main return main
def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype) program = tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program,
...@@ -42,10 +42,10 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16") ...@@ -42,10 +42,10 @@ def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16")
def test_tilelang_copy(): def test_tilelang_copy():
run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128) run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128)
run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576) run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576)
run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype="float") run_tilelang_copy(M=1024, N=576, block_M=32, block_N=576, dtype=T.float32)
def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.StridedTensor((M, N), (NN, 1), dtype), A: T.StridedTensor((M, N), (NN, 1), dtype),
...@@ -59,7 +59,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"): ...@@ -59,7 +59,7 @@ def tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_with_stride(M=1024, N=1024, NN=2048, block_M=128, block_N=128, dtype=T.float16):
if isinstance(NN, int): if isinstance(NN, int):
assert NN > N, "NN must be greater than N" assert NN > N, "NN must be greater than N"
program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype) program = tilelang_copy_with_stride(M, N, NN, block_M, block_N, dtype)
...@@ -84,21 +84,21 @@ def test_tilelang_copy_with_stride(): ...@@ -84,21 +84,21 @@ def test_tilelang_copy_with_stride():
run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128) run_tilelang_copy_with_stride(M=1024, N=1024, NN=T.dynamic("NN"), block_M=128, block_N=128)
def tilelang_copy_bufferload(num_tokens, dtype="float16"): def tilelang_copy_bufferload(num_tokens, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
indices: T.Tensor((num_tokens,), "int32"), indices: T.Tensor((num_tokens,), T.int32),
x: T.Tensor((num_tokens,), dtype), x: T.Tensor((num_tokens,), dtype),
): ):
with T.Kernel(num_tokens, threads=32) as pid: with T.Kernel(num_tokens, threads=32) as pid:
idx = T.alloc_local([1], "int32") idx = T.alloc_local([1], T.int32)
T.copy(indices[pid], idx[0]) T.copy(indices[pid], idx[0])
x[idx[0]] = x[idx[0]] + 1 x[idx[0]] = x[idx[0]] + 1
return main return main
def run_tilelang_copy_bufferload(num_tokens=128, dtype="float16"): def run_tilelang_copy_bufferload(num_tokens=128, dtype=T.float16):
program = tilelang_copy_bufferload(num_tokens, dtype) program = tilelang_copy_bufferload(num_tokens, dtype)
# test compilation only # test compilation only
tilelang.compile( tilelang.compile(
...@@ -112,7 +112,7 @@ def test_tilelang_copy_bufferload(): ...@@ -112,7 +112,7 @@ def test_tilelang_copy_bufferload():
run_tilelang_copy_bufferload(num_tokens=128) run_tilelang_copy_bufferload(num_tokens=128)
def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -126,7 +126,7 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float ...@@ -126,7 +126,7 @@ def tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype="float
return main return main
def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype) program = tilelang_copy_buffer_load_with_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program,
...@@ -143,7 +143,7 @@ def test_tilelang_copy_buffer_load_with_parallel(): ...@@ -143,7 +143,7 @@ def test_tilelang_copy_buffer_load_with_parallel():
run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128) run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)
def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu"): def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu):
program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program,
...@@ -159,10 +159,10 @@ def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dty ...@@ -159,10 +159,10 @@ def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dty
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0) @tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp8_e8m0(): def test_tilelang_copy_fp8_e8m0():
run_tilelang_copy_fp8_e8m0(src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu") run_tilelang_copy_fp8_e8m0(src_dtype=T.float8_e8m0fnu, dst_dtype=T.float8_e8m0fnu)
def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn"): def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn):
program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype) program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, program,
...@@ -179,9 +179,9 @@ def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="f ...@@ -179,9 +179,9 @@ def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="f
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0) @tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp4(): def test_tilelang_copy_fp4():
run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn") run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float4_e2m1fn)
run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float16") run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.float16)
run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="bfloat16") run_tilelang_copy_fp4(src_dtype=T.float4_e2m1fn, dst_dtype=T.bfloat16)
if __name__ == "__main__": if __name__ == "__main__":
......
@@ -2,11 +2,10 @@ from tilelang import tvm as tvm
import tilelang.testing
import tilelang as tl
import torch
+import tilelang.language as T
-def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32):
-import tilelang.language as T
@T.prim_func
def cumsum(
A: T.Tensor((M, N), dtype),
@@ -23,7 +22,7 @@ def cumsum_smem_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float3
return cumsum
-def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32):
import tilelang.language as T
@T.prim_func
@@ -44,7 +43,7 @@ def cumsum_fragment_test(M, N, block_M, block_N, dim=0, reverse=False, dtype="fl
return cumsum
-def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", scope="smem"):
+def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32, scope="smem"):
if scope == "smem":
program = cumsum_smem_test(M, N, block_M, block_N, dim, reverse, dtype)
elif scope == "fragment":
@@ -74,7 +73,7 @@ def run_cumsum(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32", sc
torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
-def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
+def cumsum_smem_test_1d(N, block_N, reverse=False, dtype=T.float32):
import tilelang.language as T
@T.prim_func
@@ -92,7 +91,7 @@ def cumsum_smem_test_1d(N, block_N, reverse=False, dtype="float32"):
return cumsum
-def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
+def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype=T.float32):
import tilelang.language as T
@T.prim_func
@@ -112,7 +111,7 @@ def cumsum_fragment_test_1d(N, block_N, reverse=False, dtype="float32"):
return cumsum
-def run_cumsum_1d(N, block_N, reverse=False, dtype="float32", scope="smem"):
+def run_cumsum_1d(N, block_N, reverse=False, dtype=T.float32, scope="smem"):
if scope == "smem":
program = cumsum_smem_test_1d(N, block_N, reverse, dtype)
elif scope == "fragment":
@@ -150,8 +149,8 @@ def test_cumsum_smem():
run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True)
# Test different dtypes
-run_cumsum(256, 256, 128, 128, dtype="float32")
+run_cumsum(256, 256, 128, 128, dtype=T.float32)
-run_cumsum(256, 256, 128, 128, dtype="float32")
+run_cumsum(256, 256, 128, 128, dtype=T.float32)
def test_cumsum_fragment():
@@ -160,8 +159,8 @@ def test_cumsum_fragment():
run_cumsum(1024, 1024, 128, 128, dim=1, reverse=True, scope="fragment")
# Test different dtypes
-run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
+run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment")
-run_cumsum(256, 256, 128, 128, dtype="float32", scope="fragment")
+run_cumsum(256, 256, 128, 128, dtype=T.float32, scope="fragment")
def test_cumsum_smem_1d():
@@ -174,7 +173,7 @@ def test_cumsum_fragment_1d():
run_cumsum_1d(1024, 128, reverse=True, scope="fragment")
-def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype="float32"):
+def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype=T.float32):
"""Test cumsum with buffer region (slice) as input."""
import tilelang.language as T
@@ -198,7 +197,7 @@ def cumsum_region_test_1d(N, chunk_size, reverse=False, dtype="float32"):
return cumsum_region
-def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype="float32"):
+def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype=T.float32):
"""Run test for cumsum with region input."""
program = cumsum_region_test_1d(N, chunk_size, reverse, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
@@ -224,7 +223,7 @@ def run_cumsum_region_1d(N, chunk_size, reverse=False, dtype="float32"):
torch.testing.assert_close(tilelang_res, ref_res, atol=1e-3, rtol=1e-3)
-def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32):
"""Test cumsum with buffer region (slice) as input in 2D."""
import tilelang.language as T
@@ -253,7 +252,7 @@ def cumsum_region_test_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="f
return cumsum_region
-def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype="float32"):
+def run_cumsum_region_2d(M, N, block_M, block_N, dim=0, reverse=False, dtype=T.float32):
"""Run test for cumsum with 2D region input."""
program = cumsum_region_test_2d(M, N, block_M, block_N, dim, reverse, dtype)
jit_kernel = tl.compile(program, out_idx=-1)
...
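For reference, the semantics the cumsum tests above check against torch can be written out in plain Python (a simplified model only; the real kernels compute the prefix sum tile-by-tile in shared memory or fragments):

```python
def cumsum_1d(xs, reverse=False):
    """Inclusive prefix sum over a 1D sequence.

    With reverse=True the accumulation runs from the last element
    toward the first, matching the kernels' `reverse` flag.
    """
    out = [0.0] * len(xs)
    acc = 0.0
    # Traverse back-to-front for the reverse variant, front-to-back otherwise.
    order = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    for i in order:
        acc += xs[i]
        out[i] = acc
    return out
```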
@@ -303,8 +303,8 @@ def test_serial_for_with_step():
assert torch.all(res == ref), f"Expected {ref}, but got {res}"
assert isinstance(T.serial(1, 10, 1), IRBuilderFrame)
-assert isinstance(T.serial(1, 10, IntImm("int32", 1)), IRBuilderFrame)
+assert isinstance(T.serial(1, 10, IntImm(T.int32, 1)), IRBuilderFrame)
-assert not isinstance(T.serial(1, 10, Var("tmp", "int32")), IRBuilderFrame)
+assert not isinstance(T.serial(1, 10, Var("tmp", T.int32)), IRBuilderFrame)
assert not isinstance(T.serial(10, -1, -1), IRBuilderFrame)
@@ -433,7 +433,7 @@ def test_frame_inside_macro():
idx_out: T.Tensor[(32,), T.int32],
):
with T.Kernel(num_blocks, threads=32) as block_idx:  # noqa: F841
-fragment = T.alloc_fragment(32, "int32")
+fragment = T.alloc_fragment(32, T.int32)
T.copy(idx_out, fragment)
for i in T.Parallel(32):
@@ -458,10 +458,10 @@ def test_buffer_slice_step():
def test_boolop():
-a = Var("a", "int32")
+a = Var("a", T.int32)
-b = Var("b", "int32")
+b = Var("b", T.int32)
-c = Var("c", "int32")
+c = Var("c", T.int32)
-d = Var("d", "int32")
+d = Var("d", T.int32)
@T.macro
def cond():
...
@@ -24,7 +24,7 @@ def _resolve_warps_per_group(warps_per_group: Optional[int]) -> int:
@tilelang.jit(out_idx=[-1])
def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
-def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
+def laneid_kernel(A: T.Tensor((num_threads,), T.int32)):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_lane_idx(warp_size)
@@ -35,7 +35,7 @@ def _get_laneid_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
-def warp_idx_sync_kernel(A: T.Tensor((num_threads,), "int32")):
+def warp_idx_sync_kernel(A: T.Tensor((num_threads,), T.int32)):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx_sync(warp_size)
@@ -46,7 +46,7 @@ def _get_warp_idx_sync_kernel(num_threads: int = 128, warp_size: Optional[int] =
@tilelang.jit(out_idx=[-1])
def _get_warp_idx_kernel(num_threads: int = 128, warp_size: Optional[int] = None):
@T.prim_func
-def warp_idx_kernel(A: T.Tensor((num_threads,), "int32")):
+def warp_idx_kernel(A: T.Tensor((num_threads,), T.int32)):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_idx(warp_size)
@@ -61,7 +61,7 @@ def _get_warp_group_idx_kernel(
warps_per_group: Optional[int] = None,
):
@T.prim_func
-def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
+def warp_group_idx_kernel(A: T.Tensor((num_threads,), T.int32)):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
A[tx] = T.get_warp_group_idx(warp_size, warps_per_group)
@@ -72,7 +72,7 @@ def _get_warp_group_idx_kernel(
@tilelang.jit(out_idx=[-1])
def _shuffle_elect_kernel(num_threads: int = 128, thread_extent: int = 64):
@T.prim_func
-def shuffle_elect_kernel(A: T.Tensor((num_threads,), "int32")):
+def shuffle_elect_kernel(A: T.Tensor((num_threads,), T.int32)):
with T.Kernel(1, threads=num_threads) as _:
tx = T.get_thread_binding()
elected = T.shuffle_elect(thread_extent)
...
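The lane/warp/warp-group kernels above each write one index per thread; the underlying arithmetic is a simple decomposition of the flat thread id. A plain-Python model of that decomposition (the defaults `warp_size=32` and `warps_per_group=4` are illustrative assumptions, not values taken from this diff):

```python
def thread_indices(tx: int, warp_size: int = 32, warps_per_group: int = 4):
    """Decompose a flat thread id into (lane, warp, warp_group) indices."""
    lane = tx % warp_size                  # position within the warp
    warp = tx // warp_size                 # which warp the thread belongs to
    warp_group = warp // warps_per_group   # which group of warps
    return lane, warp, warp_group
```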
@@ -7,7 +7,7 @@ import tilelang.testing
@tilelang.jit(
out_idx=[1],
)
-def tilelang_if_range(M, N, block_M, block_N, dtype="float16"):
+def tilelang_if_range(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func
def main(
A: T.Tensor((M, N), dtype),
@@ -27,7 +27,7 @@ def tilelang_if_range(M, N, block_M, block_N, dtype="float16"):
return main
-def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype="float16"):
+def run_tilelang_if_range(M=128, N=128, block_M=32, block_N=32, dtype=T.float16):
kernel = tilelang_if_range(M, N, block_M, block_N, dtype)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a)
...
@@ -22,10 +22,10 @@ def _test_infinity(dtype: str):
@tilelang.testing.requires_cuda
def test_infinity():
-_test_infinity("float16")
+_test_infinity(T.float16)
-_test_infinity("bfloat16")
+_test_infinity(T.bfloat16)
-_test_infinity("float32")
+_test_infinity(T.float32)
-_test_infinity("float64")
+_test_infinity(T.float64)
if __name__ == "__main__":
...
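`_test_infinity` is called above with dtype objects, but its body is elided in the diff. As a standalone illustration of the kind of property such a test can verify, the Python stdlib `struct` module's binary16 format character (`"e"`) shows that IEEE-754 infinity and float16's largest finite value (65504) survive a pack/unpack round trip:

```python
import math
import struct


def half_roundtrip(x: float) -> float:
    """Pack a float to IEEE-754 binary16 and unpack it back to a float."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


# +inf is representable in binary16 and round-trips exactly,
# as does the largest finite float16 value.
assert math.isinf(half_roundtrip(math.inf))
assert half_roundtrip(65504.0) == 65504.0
```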
@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit
-def fill_symbolic(value: float, dtype="bfloat16"):
+def fill_symbolic(value: float, dtype=T.bfloat16):
n = T.symbolic("n", "int64")
block_n = 512
@@ -33,7 +33,7 @@ def test_fill_symbolic():
@tilelang.jit
-def fill_static(n: int, value: float, dtype="bfloat16"):
+def fill_static(n: int, value: float, dtype=T.bfloat16):
block_n = 512
@T.prim_func
...
@@ -9,8 +9,8 @@ def test_language_ldg_codegen():
@T.prim_func
def main(
-x: T.Tensor((N,), "float32"),
+x: T.Tensor((N,), T.float32),
-y: T.Tensor((N,), "float32"),
+y: T.Tensor((N,), T.float32),
):
with T.Kernel(N, threads=32) as pid:
# Explicitly request read-only cache load for x[pid]
...
@@ -60,8 +60,8 @@ def test_jit2_gemm_annot():
)
for in_dtype, out_dtype in prod:
-in_dtype = in_dtype.torch()
+in_dtype = in_dtype.as_torch()
-out_dtype = out_dtype.torch()
+out_dtype = out_dtype.as_torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B)
@@ -97,8 +97,8 @@ def test_jit2_gemm_ptr():
]
)
for in_dtype, out_dtype in prod:
-in_dtype = in_dtype.torch()
+in_dtype = in_dtype.as_torch()
-out_dtype = out_dtype.torch()
+out_dtype = out_dtype.as_torch()
A = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
B = torch.randn(1024, 1024, dtype=in_dtype, device="cuda")
C_ref = out_dtype(A @ B)
@@ -326,8 +326,8 @@ def test_jit2_return():
def test_jit2_deepseek_deepgemm():
@tilelang.lazy_jit
def deep_gemm(
-A: T.Tensor[[int, int], T.float8_e4m3],
+A: T.Tensor[[int, int], T.float8_e4m3fn],
-B: T.Tensor[[int, int], T.float8_e4m3],
+B: T.Tensor[[int, int], T.float8_e4m3fn],
scales_a: T.Tensor[[int, int], T.float32],
scales_b: T.Tensor[[int, int], T.float32],
out_dtype: T.dtype = T.bfloat16,
...
@@ -6,7 +6,7 @@ from tilelang import language as T
def test_let_vectorize_load():
@T.prim_func
def main(A_ptr: T.handle):
-A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
...