Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -26,7 +26,7 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"): ...@@ -26,7 +26,7 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
...@@ -42,7 +42,7 @@ def test_tilelang_copy_mask_parallel(): ...@@ -42,7 +42,7 @@ def test_tilelang_copy_mask_parallel():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -62,7 +62,7 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"): ...@@ -62,7 +62,7 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
...@@ -78,7 +78,7 @@ def test_tilelang_copy_mask_copy(): ...@@ -78,7 +78,7 @@ def test_tilelang_copy_mask_copy():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -99,7 +99,7 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"): ...@@ -99,7 +99,7 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
...@@ -115,7 +115,7 @@ def test_tilelang_copy_mask_parallel_range(): ...@@ -115,7 +115,7 @@ def test_tilelang_copy_mask_parallel_range():
# add decorator @tilelang.jit if you want to return a torch function # add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit # @tilelang.jit
def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -135,7 +135,7 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"): ...@@ -135,7 +135,7 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"): def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype) program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
kernel = tilelang.compile( kernel = tilelang.compile(
program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True} program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
......
from tilelang import tvm from tilelang import tvm
import tilelang as tl import tilelang as tl
import tilelang.testing import tilelang.testing
from tvm.script import tir as T import tilelang.language as T
@T.prim_func @T.prim_func
def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): def negative_index_before(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
T.func_attr({"tir.noalias": True}) T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(-1)] B[0] = A[T.int32(-1)]
@T.prim_func @T.prim_func
def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): def negative_index_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
T.func_attr({"tir.noalias": True}) T.func_attr({"tir.noalias": True})
B[0] = A[T.int32(15)] B[0] = A[T.int32(15)]
@T.prim_func @T.prim_func
def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): def negative_index_loop_before(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)):
T.func_attr({"tir.noalias": True}) T.func_attr({"tir.noalias": True})
for i in T.serial(4): for i in T.serial(4):
B[i] = A[-i - 1] B[i] = A[-i - 1]
@T.prim_func @T.prim_func
def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")): def negative_index_loop_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)):
T.func_attr({"tir.noalias": True}) T.func_attr({"tir.noalias": True})
for i in T.serial(4): for i in T.serial(4):
B[i] = A[15 - i] B[i] = A[15 - i]
@T.prim_func @T.prim_func
def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")): def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
T.func_attr({"tir.noalias": True}) T.func_attr({"tir.noalias": True})
for i in T.serial(16): for i in T.serial(16):
B[i] = A[shift + i] B[i] = A[shift + i]
......
...@@ -8,7 +8,7 @@ tilelang.testing.set_random_seed() ...@@ -8,7 +8,7 @@ tilelang.testing.set_random_seed()
@tilelang.jit(out_idx=[1]) @tilelang.jit(out_idx=[1])
def parallel_elementwise_static(length=256, dtype="float32"): def parallel_elementwise_static(length=256, dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((length,), dtype), A: T.Tensor((length,), dtype),
...@@ -22,7 +22,7 @@ def parallel_elementwise_static(length=256, dtype="float32"): ...@@ -22,7 +22,7 @@ def parallel_elementwise_static(length=256, dtype="float32"):
@tilelang.jit(out_idx=[1]) @tilelang.jit(out_idx=[1])
def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"): def parallel_elementwise_dynamic(max_len=512, threads=256, dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((max_len,), dtype), A: T.Tensor((max_len,), dtype),
......
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang.language as T
def matmul( def matmul(
...@@ -23,8 +24,6 @@ def matmul( ...@@ -23,8 +24,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
...@@ -63,9 +62,9 @@ def run_gemm( ...@@ -63,9 +62,9 @@ def run_gemm(
block_K = 32 block_K = 32
trans_A = False trans_A = False
trans_B = False trans_B = False
in_dtype = "float16" in_dtype = T.float16
out_dtype = "float16" out_dtype = T.float16
dtypeAccum = "float32" dtypeAccum = T.float32
num_threads = 128 num_threads = 128
program = matmul( program = matmul(
M, M,
...@@ -101,7 +100,7 @@ def run_gemm( ...@@ -101,7 +100,7 @@ def run_gemm(
A = A.T A = A.T
if trans_B: if trans_B:
B = B.T B = B.T
if in_dtype == "float32": if in_dtype == T.float32:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas # float32 automatically, -0x1000 meas
A = (A.view(torch.int32) - 0x1000).view(torch.float32) A = (A.view(torch.int32) - 0x1000).view(torch.float32)
...@@ -127,7 +126,7 @@ def test_pipeline_order_stage(): ...@@ -127,7 +126,7 @@ def test_pipeline_order_stage():
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
}, },
) )
def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"): def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype=T.float16, accum_dtype=T.float32):
block_mask_shape = (M // block_M, N // block_N, K // block_K) block_mask_shape = (M // block_M, N // block_N, K // block_K)
import tilelang.language as T import tilelang.language as T
......
...@@ -6,7 +6,7 @@ import tilelang.language as T ...@@ -6,7 +6,7 @@ import tilelang.language as T
from tilelang.utils import map_torch_type from tilelang.utils import map_torch_type
def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
a_ptr: T.ptr, a_ptr: T.ptr,
...@@ -39,7 +39,7 @@ def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype ...@@ -39,7 +39,7 @@ def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype
return main return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype) program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
jit_kernel = tl.compile(program, target="cuda", execution_backend="cython") jit_kernel = tl.compile(program, target="cuda", execution_backend="cython")
......
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang as tl import tilelang as tl
import tilelang.language as T
tilelang.testing.set_random_seed() tilelang.testing.set_random_seed()
def _make_shared_reduce(M, N, dtype, reduce_cb): def _make_shared_reduce(M, N, dtype, reduce_cb):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -30,7 +29,7 @@ def _run_program(program, ref_program, atol=1e-2, rtol=1e-2): ...@@ -30,7 +29,7 @@ def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
profiler.assert_allclose(ref_program, atol=atol, rtol=rtol) profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)
def reduce_max_test(M, N, dtype="float16"): def reduce_max_test(M, N, dtype=T.float16):
import tilelang.language as T import tilelang.language as T
@T.prim_func @T.prim_func
...@@ -49,7 +48,7 @@ def reduce_max_test(M, N, dtype="float16"): ...@@ -49,7 +48,7 @@ def reduce_max_test(M, N, dtype="float16"):
return main return main
def reduce_sum_test(M, N, dtype="float32"): def reduce_sum_test(M, N, dtype=T.float32):
import tilelang.language as T import tilelang.language as T
@T.prim_func @T.prim_func
...@@ -68,27 +67,27 @@ def reduce_sum_test(M, N, dtype="float32"): ...@@ -68,27 +67,27 @@ def reduce_sum_test(M, N, dtype="float32"):
return main return main
def reduce_sum_ss(M, N, dtype="float32"): def reduce_sum_ss(M, N, dtype=T.float32):
return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1)) return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))
def reduce_max_ss(M, N, dtype="float32"): def reduce_max_ss(M, N, dtype=T.float32):
return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1)) return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))
def reduce_min_ss(M, N, dtype="float32"): def reduce_min_ss(M, N, dtype=T.float32):
return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1)) return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1))
def reduce_abssum_ss(M, N, dtype="float32"): def reduce_abssum_ss(M, N, dtype=T.float32):
return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1)) return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1))
def reduce_absmax_ss(M, N, dtype="float32"): def reduce_absmax_ss(M, N, dtype=T.float32):
return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1)) return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1))
def run_reduce_sum(M, N, dtype="float32", mode="rr"): def run_reduce_sum(M, N, dtype=T.float32, mode="rr"):
if mode == "rr": if mode == "rr":
program = reduce_sum_test(M, N, dtype) program = reduce_sum_test(M, N, dtype)
elif mode == "ss": elif mode == "ss":
...@@ -98,12 +97,12 @@ def run_reduce_sum(M, N, dtype="float32", mode="rr"): ...@@ -98,12 +97,12 @@ def run_reduce_sum(M, N, dtype="float32", mode="rr"):
_run_program(program, lambda A: A.sum(dim=1)) _run_program(program, lambda A: A.sum(dim=1))
def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"): def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32):
program = program_builder(M, N, dtype) program = program_builder(M, N, dtype)
_run_program(program, ref_program) _run_program(program, ref_program)
def run_reduce_max(M, N, dtype="float16"): def run_reduce_max(M, N, dtype=T.float16):
program = reduce_max_test(M, N, dtype) program = reduce_max_test(M, N, dtype)
_run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2) _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2)
...@@ -119,28 +118,28 @@ def test_reduce_sum_shared(): ...@@ -119,28 +118,28 @@ def test_reduce_sum_shared():
def test_reduce_max(): def test_reduce_max():
run_reduce_max(256, 256, "float16") run_reduce_max(256, 256, T.float16)
run_reduce_max(512, 128, "float16") run_reduce_max(512, 128, T.float16)
run_reduce_max(256, 256, "float32") run_reduce_max(256, 256, T.float32)
def test_reduce_max_shared(): def test_reduce_max_shared():
run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32") run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, T.float32)
def test_reduce_min_shared(): def test_reduce_min_shared():
run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, "float32") run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, T.float32)
def test_reduce_abssum_shared(): def test_reduce_abssum_shared():
run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, "float32") run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, T.float32)
def test_reduce_absmax_shared(): def test_reduce_absmax_shared():
run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, "float32") run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, T.float32)
def reduce_sum_test_clear(M, N, dtype="float32"): def reduce_sum_test_clear(M, N, dtype=T.float32):
import tilelang.language as T import tilelang.language as T
@T.prim_func @T.prim_func
...@@ -160,7 +159,7 @@ def reduce_sum_test_clear(M, N, dtype="float32"): ...@@ -160,7 +159,7 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
return main return main
def run_reduce_sum_clear(M, N, dtype="float32"): def run_reduce_sum_clear(M, N, dtype=T.float32):
program = reduce_sum_test_clear(M, N, dtype) program = reduce_sum_test_clear(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1) jit_kernel = tl.compile(program, out_idx=-1)
...@@ -176,12 +175,12 @@ def run_reduce_sum_clear(M, N, dtype="float32"): ...@@ -176,12 +175,12 @@ def run_reduce_sum_clear(M, N, dtype="float32"):
def test_reduce_sum_clear(): def test_reduce_sum_clear():
run_reduce_sum_clear(256, 256, "float32") run_reduce_sum_clear(256, 256, T.float32)
run_reduce_sum_clear(512, 128, "float32") run_reduce_sum_clear(512, 128, T.float32)
run_reduce_sum_clear(128, 512, "float32") run_reduce_sum_clear(128, 512, T.float32)
def reduce_max_test_clear(M, N, dtype="float16"): def reduce_max_test_clear(M, N, dtype=T.float16):
import tilelang.language as T import tilelang.language as T
@T.prim_func @T.prim_func
...@@ -201,7 +200,7 @@ def reduce_max_test_clear(M, N, dtype="float16"): ...@@ -201,7 +200,7 @@ def reduce_max_test_clear(M, N, dtype="float16"):
return main return main
def run_reduce_max_clear(M, N, dtype="float16"): def run_reduce_max_clear(M, N, dtype=T.float16):
program = reduce_max_test_clear(M, N, dtype) program = reduce_max_test_clear(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1) jit_kernel = tl.compile(program, out_idx=-1)
...@@ -217,7 +216,7 @@ def run_reduce_max_clear(M, N, dtype="float16"): ...@@ -217,7 +216,7 @@ def run_reduce_max_clear(M, N, dtype="float16"):
def test_reduce_max_clear(): def test_reduce_max_clear():
run_reduce_max_clear(256, 256, "float16") run_reduce_max_clear(256, 256, T.float16)
if __name__ == "__main__": if __name__ == "__main__":
......
from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang as tl import tilelang as tl
from tilelang import language as T
import torch import torch
import pytest import pytest
def reshape_test(N, M, dtype): def reshape_test(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
...@@ -42,13 +40,11 @@ def run_reshape(N, M, dtype): ...@@ -42,13 +40,11 @@ def run_reshape(N, M, dtype):
def test_reshape_smem(): def test_reshape_smem():
# Test reshape # Test reshape
run_reshape(1024, 32, "float32") run_reshape(1024, 32, T.float32)
run_reshape(2048, 64, "float16") run_reshape(2048, 64, T.float16)
def reshape_test_smem_1d_2_2d(N, M, dtype): def reshape_test_smem_1d_2_2d(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
...@@ -86,13 +82,11 @@ def run_reshape_smem_1d_2_2d(N, M, dtype): ...@@ -86,13 +82,11 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
def test_reshape_smem_1d_2_2d(): def test_reshape_smem_1d_2_2d():
run_reshape_smem_1d_2_2d(1024, 32, "float32") run_reshape_smem_1d_2_2d(1024, 32, T.float32)
run_reshape_smem_1d_2_2d(2048, 64, "float16") run_reshape_smem_1d_2_2d(2048, 64, T.float16)
def reshape_test_smem_2d_2_1d(N, M, dtype): def reshape_test_smem_2d_2_1d(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N // M, M), dtype), A: T.Tensor((N // M, M), dtype),
...@@ -130,13 +124,11 @@ def run_reshape_smem_2d_2_1d(N, M, dtype): ...@@ -130,13 +124,11 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
def test_reshape_smem_2d_2_1d(): def test_reshape_smem_2d_2_1d():
run_reshape_smem_2d_2_1d(1024, 32, "float32") run_reshape_smem_2d_2_1d(1024, 32, T.float32)
run_reshape_smem_2d_2_1d(2048, 64, "float16") run_reshape_smem_2d_2_1d(2048, 64, T.float16)
def reshape_fragment_test(N, M, dtype): def reshape_fragment_test(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N // M, M), dtype), A: T.Tensor((N // M, M), dtype),
...@@ -175,12 +167,11 @@ def run_reshape_fragment(N, M, dtype): ...@@ -175,12 +167,11 @@ def run_reshape_fragment(N, M, dtype):
def test_reshape_fragment(): def test_reshape_fragment():
run_reshape_fragment(1024, 32, "float32") run_reshape_fragment(1024, 32, T.float32)
run_reshape_fragment(2048, 64, "float16") run_reshape_fragment(2048, 64, T.float16)
def reshape_layout_transform_shared(N, M, dtype): def reshape_layout_transform_shared(N, M, dtype):
import tilelang.language as T
from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout
@T.prim_func @T.prim_func
...@@ -222,13 +213,11 @@ def run_reshape_layout_transform_shared(N, M, dtype): ...@@ -222,13 +213,11 @@ def run_reshape_layout_transform_shared(N, M, dtype):
def test_reshape_layout_transform_shared(): def test_reshape_layout_transform_shared():
run_reshape_layout_transform_shared(1024, 32, "float32") run_reshape_layout_transform_shared(1024, 32, T.float32)
run_reshape_layout_transform_shared(2048, 64, "float16") run_reshape_layout_transform_shared(2048, 64, T.float16)
def reduce_after_reshape_test(N, M, dtype): def reduce_after_reshape_test(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
...@@ -267,13 +256,11 @@ def run_reduce_after_reshape(N, M, dtype): ...@@ -267,13 +256,11 @@ def run_reduce_after_reshape(N, M, dtype):
def test_reduce_after_reshape(): def test_reduce_after_reshape():
run_reduce_after_reshape(1024, 32, "float32") run_reduce_after_reshape(1024, 32, T.float32)
run_reduce_after_reshape(2048, 64, "float16") run_reduce_after_reshape(2048, 64, T.float16)
def reshape_shape_mismatch_test(N, M, dtype): def reshape_shape_mismatch_test(N, M, dtype):
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((N,), dtype), A: T.Tensor((N,), dtype),
...@@ -288,7 +275,7 @@ def reshape_shape_mismatch_test(N, M, dtype): ...@@ -288,7 +275,7 @@ def reshape_shape_mismatch_test(N, M, dtype):
def test_reshape_shape_mismatch(): def test_reshape_shape_mismatch():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
reshape_shape_mismatch_test(1024, 32, "float32") reshape_shape_mismatch_test(1024, 32, T.float32)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -7,7 +7,7 @@ import tilelang.testing ...@@ -7,7 +7,7 @@ import tilelang.testing
@tilelang.jit( @tilelang.jit(
out_idx=[1], out_idx=[1],
) )
def tilelang_ternary(M, N, block_M, block_N, dtype="float16"): def tilelang_ternary(M, N, block_M, block_N, dtype=T.float16):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), dtype), A: T.Tensor((M, N), dtype),
...@@ -21,7 +21,7 @@ def tilelang_ternary(M, N, block_M, block_N, dtype="float16"): ...@@ -21,7 +21,7 @@ def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
return main return main
def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype="float16"): def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype=T.float16):
kernel = tilelang_ternary(M, N, block_M, block_N, dtype) kernel = tilelang_ternary(M, N, block_M, block_N, dtype)
a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype)) a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
b = kernel(a) b = kernel(a)
......
...@@ -34,7 +34,7 @@ def run_elementwise_add(M, N): ...@@ -34,7 +34,7 @@ def run_elementwise_add(M, N):
# Default config # Default config
block_M, block_N = 128, 128 block_M, block_N = 128, 128
config = {"block_M": block_M, "block_N": block_N, "threads": 128} config = {"block_M": block_M, "block_N": block_N, "threads": 128}
kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
out = kernel(a, b) out = kernel(a, b)
torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
......
...@@ -6,7 +6,7 @@ from tilelang import language as T ...@@ -6,7 +6,7 @@ from tilelang import language as T
def test_unroll_with_step(): def test_unroll_with_step():
@T.prim_func @T.prim_func
def main(A_ptr: T.handle): def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
...@@ -20,7 +20,7 @@ def test_unroll_with_step(): ...@@ -20,7 +20,7 @@ def test_unroll_with_step():
def test_unroll_with_unroll_factor(): def test_unroll_with_unroll_factor():
@T.prim_func @T.prim_func
def main(A_ptr: T.handle): def main(A_ptr: T.handle):
A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16) A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)
for _blockIdx in T.thread_binding(1, thread="blockIdx.x"): for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
for _threadIdx in T.thread_binding(128, thread="threadIdx.x"): for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
......
...@@ -7,12 +7,12 @@ def test_var_assign() -> None: ...@@ -7,12 +7,12 @@ def test_var_assign() -> None:
@tilelang.jit(out_idx=-1) @tilelang.jit(out_idx=-1)
def jit_kernel(): def jit_kernel():
@T.prim_func @T.prim_func
def test_var_assign(A: T.Tensor((2,), "int32")): def test_var_assign(A: T.Tensor((2,), T.int32)):
with T.Kernel(1) as _: with T.Kernel(1) as _:
a = T.alloc_var("int32", init=1) a = T.alloc_var(T.int32, init=1)
b = T.alloc_var("int32", init=a) # b gets value of a b = T.alloc_var(T.int32, init=a) # b gets value of a
a = 2 a = 2
d = T.alloc_var("int32", init=a) # c gets new value of a d = T.alloc_var(T.int32, init=a) # c gets new value of a
A[0] = b A[0] = b
A[1] = d A[1] = d
......
...@@ -7,8 +7,8 @@ import tilelang.language as T ...@@ -7,8 +7,8 @@ import tilelang.language as T
def vectorize_test(N, M, stride_A, stride_B): def vectorize_test(N, M, stride_A, stride_B):
@T.prim_func @T.prim_func
def main( def main(
A: T.StridedTensor[(N, M), (1, stride_A), "float32"], # noqa: F821 A: T.StridedTensor[(N, M), (1, stride_A), T.float32], # noqa: F821
B: T.StridedTensor[(N, M), (1, stride_B), "float32"], # noqa: F821 B: T.StridedTensor[(N, M), (1, stride_B), T.float32], # noqa: F821
): ):
with T.Kernel(M // 128, threads=128) as (bx): with T.Kernel(M // 128, threads=128) as (bx):
tx = T.get_thread_binding(0) tx = T.get_thread_binding(0)
...@@ -60,9 +60,9 @@ def test_vectorize(): ...@@ -60,9 +60,9 @@ def test_vectorize():
def vectorize_test_invariant_index(N, M, K): def vectorize_test_invariant_index(N, M, K):
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor[(N, M), "float32"], # noqa: F821 A: T.Tensor[(N, M), T.float32], # noqa: F821
B: T.Tensor[(N, M), "float32"], # noqa: F821 B: T.Tensor[(N, M), T.float32], # noqa: F821
C: T.Tensor[(N, M // K), "float32"], # noqa: F821 C: T.Tensor[(N, M // K), T.float32], # noqa: F821
): ):
with T.Kernel(N // 128, threads=128) as (bx): with T.Kernel(N // 128, threads=128) as (bx):
tx = T.get_thread_binding(0) tx = T.get_thread_binding(0)
......
...@@ -4,11 +4,11 @@ import tilelang.testing ...@@ -4,11 +4,11 @@ import tilelang.testing
import tilelang.language as T import tilelang.language as T
str2dtype = { str2dtype = {
"float32": torch.float32, T.float32: torch.float32,
"float16": torch.float16, T.float16: torch.float16,
"bfloat16": torch.bfloat16, T.bfloat16: torch.bfloat16,
"float8_e4m3": torch.float8_e4m3fn, T.float8_e4m3fn: torch.float8_e4m3fn,
"float8_e5m2": torch.float8_e5m2, T.float8_e5m2: torch.float8_e5m2,
} }
...@@ -81,22 +81,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str, ...@@ -81,22 +81,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes", "src_dtype, dst_dtype, check_str, lanes",
[ [
("float32", "float16", "__float22half2_rn", 2), (T.float32, T.float16, "__float22half2_rn", 2),
("float32", "float16", "__float22half2_rn", 4), (T.float32, T.float16, "__float22half2_rn", 4),
("float16", "float32", "__half22float2", 2), (T.float16, T.float32, "__half22float2", 2),
("float16", "float32", "__half22float2", 4), (T.float16, T.float32, "__half22float2", 4),
("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2), (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2),
("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4), (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4),
("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2), (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2),
("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4), (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4),
("float32", "bfloat16", "__float22bfloat162_rn", 2), (T.float32, T.bfloat16, "__float22bfloat162_rn", 2),
("float32", "bfloat16", "__float22bfloat162_rn", 4), (T.float32, T.bfloat16, "__float22bfloat162_rn", 4),
("bfloat16", "float32", "__bfloat1622float2", 2), (T.bfloat16, T.float32, "__bfloat1622float2", 2),
("bfloat16", "float32", "__bfloat1622float2", 4), (T.bfloat16, T.float32, "__bfloat1622float2", 4),
("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4), (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
], ],
) )
def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes): def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
......
import tilelang.language as T
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import tilelang as tl import tilelang as tl
...@@ -5,8 +6,6 @@ import pytest ...@@ -5,8 +6,6 @@ import pytest
def view_test(N, M, dtype, new_dtype=None): def view_test(N, M, dtype, new_dtype=None):
import tilelang.language as T
new_shape = [N // M, M] new_shape = [N // M, M]
if new_dtype: if new_dtype:
from tvm import DataType from tvm import DataType
...@@ -37,9 +36,7 @@ def run_view(N, M, dtype, new_dtype=None): ...@@ -37,9 +36,7 @@ def run_view(N, M, dtype, new_dtype=None):
def ref_program(A): def ref_program(A):
if new_dtype: if new_dtype:
from tilelang.utils.tensor import map_torch_type torch_dtype = T.dtype(new_dtype).as_torch()
torch_dtype = map_torch_type(new_dtype)
return A.view(N // M, M).view(dtype=torch_dtype) return A.view(N // M, M).view(dtype=torch_dtype)
return A.view(N // M, M) return A.view(N // M, M)
...@@ -48,17 +45,15 @@ def run_view(N, M, dtype, new_dtype=None): ...@@ -48,17 +45,15 @@ def run_view(N, M, dtype, new_dtype=None):
def test_reshape_view(): def test_reshape_view():
# Test view with same dtype # Test view with same dtype
run_view(1024, 32, "float32") run_view(1024, 32, T.float32)
run_view(2048, 64, "float16") run_view(2048, 64, T.float16)
# Test view with dtype conversion # Test view with dtype conversion
run_view(1024, 32, "float32", "float16") run_view(1024, 32, T.float32, T.float16)
run_view(2048, 64, "float16", "float32") run_view(2048, 64, T.float16, T.float32)
def view_shape_mismatch_test(N, M, dtype, new_dtype=None): def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
import tilelang.language as T
new_shape = [N // M, M + 1] new_shape = [N // M, M + 1]
if new_dtype: if new_dtype:
from tvm import DataType from tvm import DataType
...@@ -84,7 +79,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None): ...@@ -84,7 +79,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
def test_view_shape_mismatch(): def test_view_shape_mismatch():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
view_shape_mismatch_test(1024, 32, "float32") view_shape_mismatch_test(1024, 32, T.float32)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -33,7 +33,7 @@ def get_kernel(reduce_op: str, dtype: str): ...@@ -33,7 +33,7 @@ def get_kernel(reduce_op: str, dtype: str):
def test_warp_reduce_sum(): def test_warp_reduce_sum():
a = torch.randn((32,), dtype=torch.float32, device="cuda") a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("sum", "float32") kernel = get_kernel("sum", T.float32)
ref = torch.full_like(a, a.sum()) ref = torch.full_like(a, a.sum())
kernel(a) kernel(a)
torch.testing.assert_close(a, ref) torch.testing.assert_close(a, ref)
...@@ -41,7 +41,7 @@ def test_warp_reduce_sum(): ...@@ -41,7 +41,7 @@ def test_warp_reduce_sum():
def test_warp_reduce_max(): def test_warp_reduce_max():
a = torch.randn((32,), dtype=torch.float32, device="cuda") a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("max", "float32") kernel = get_kernel("max", T.float32)
print(kernel.get_kernel_source()) print(kernel.get_kernel_source())
ref = torch.full_like(a, a.max()) ref = torch.full_like(a, a.max())
kernel(a) kernel(a)
...@@ -50,7 +50,7 @@ def test_warp_reduce_max(): ...@@ -50,7 +50,7 @@ def test_warp_reduce_max():
def test_warp_reduce_min(): def test_warp_reduce_min():
a = torch.randn((32,), dtype=torch.float32, device="cuda") a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("min", "float32") kernel = get_kernel("min", T.float32)
ref = torch.full_like(a, a.min()) ref = torch.full_like(a, a.min())
kernel(a) kernel(a)
torch.testing.assert_close(a, ref) torch.testing.assert_close(a, ref)
...@@ -58,7 +58,7 @@ def test_warp_reduce_min(): ...@@ -58,7 +58,7 @@ def test_warp_reduce_min():
def test_warp_reduce_bitand(): def test_warp_reduce_bitand():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitand", "int32") kernel = get_kernel("bitand", T.int32)
ref_val = a[0] ref_val = a[0]
for i in range(1, a.shape[0]): for i in range(1, a.shape[0]):
ref_val = ref_val & a[i] ref_val = ref_val & a[i]
...@@ -69,7 +69,7 @@ def test_warp_reduce_bitand(): ...@@ -69,7 +69,7 @@ def test_warp_reduce_bitand():
def test_warp_reduce_bitor(): def test_warp_reduce_bitor():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda") a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitor", "int32") kernel = get_kernel("bitor", T.int32)
ref_val = a[0] ref_val = a[0]
for i in range(1, a.shape[0]): for i in range(1, a.shape[0]):
ref_val = ref_val | a[i] ref_val = ref_val | a[i]
......
...@@ -14,8 +14,8 @@ VEC_SIZE = 32 ...@@ -14,8 +14,8 @@ VEC_SIZE = 32
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func @T.prim_func
def main( def main(
a: T.Buffer((B, M, N), "bfloat16"), a: T.Buffer((B, M, N), T.bfloat16),
a_out: T.Buffer((B, M, N), "float32"), a_out: T.Buffer((B, M, N), T.float32),
): ):
with T.Kernel( with T.Kernel(
T.ceildiv(M, BLOCK_MN), T.ceildiv(M, BLOCK_MN),
...@@ -23,7 +23,7 @@ def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int): ...@@ -23,7 +23,7 @@ def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
B, B,
threads=128, threads=128,
) as (pid_m, pid_n, pid_b): ) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32") a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), T.float32)
offs_m = pid_m * BLOCK_MN offs_m = pid_m * BLOCK_MN
offs_n = pid_n * BLOCK_K offs_n = pid_n * BLOCK_K
......
...@@ -21,15 +21,15 @@ def bitwise_reduce( ...@@ -21,15 +21,15 @@ def bitwise_reduce(
): ):
@T.prim_func @T.prim_func
def reduce_func( def reduce_func(
A: T.Tensor((M, N), "int32"), A: T.Tensor((M, N), T.int32),
B: T.Tensor((M), "int32"), B: T.Tensor((M), T.int32),
Output: T.Tensor((M), "int32"), Output: T.Tensor((M), T.int32),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32") A_shared = T.alloc_shared((block_M, block_N), T.int32)
A_fragment = T.alloc_fragment((block_M, block_N), "int32") A_fragment = T.alloc_fragment((block_M, block_N), T.int32)
B_shared = T.alloc_shared((block_M,), "int32") B_shared = T.alloc_shared((block_M,), T.int32)
B_fragment = T.alloc_fragment((block_M), "int32") B_fragment = T.alloc_fragment((block_M), T.int32)
T.copy(A[by * block_M, bx * block_N], A_shared) T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment) T.copy(A_shared, A_fragment)
T.copy(B[by * block_M], B_shared) T.copy(B[by * block_M], B_shared)
......
...@@ -49,7 +49,7 @@ def check_non_fastmath_usage(source, mathop_name): ...@@ -49,7 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False) check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test single-argument mathops. Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath) T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
...@@ -85,7 +85,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3 ...@@ -85,7 +85,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3
print(f"✓ {mathop_name} compilation and execution test passed") print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test two-argument mathops to ensure they generate non-fastmath CUDA code. Test two-argument mathops to ensure they generate non-fastmath CUDA code.
""" """
...@@ -133,7 +133,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, ...@@ -133,7 +133,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_non_fastmath_usage(source_fastmath, mathop_name) check_non_fastmath_usage(source_fastmath, mathop_name)
# Test numerical correctness # Test numerical correctness
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype) a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
b = torch.randn(M, N, device="cuda", dtype=torch_dtype) b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
...@@ -159,8 +159,8 @@ def run_abs_test(): ...@@ -159,8 +159,8 @@ def run_abs_test():
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N), "float32"), A: T.Tensor((M, N), T.float32),
B: T.Tensor((M, N), "float32"), B: T.Tensor((M, N), T.float32),
): ):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N): for i, j in T.Parallel(block_M, block_N):
...@@ -188,7 +188,7 @@ def run_abs_test(): ...@@ -188,7 +188,7 @@ def run_abs_test():
print("✓ abs numerical test passed") print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix). Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
""" """
...@@ -221,7 +221,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, ...@@ -221,7 +221,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True) check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness # Test numerical correctness
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype) a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them # Ensure positive values for functions that need them
...@@ -273,7 +273,7 @@ def test_mathops_generate_no_fastmath(): ...@@ -273,7 +273,7 @@ def test_mathops_generate_no_fastmath():
] ]
for name, func in single_arg_mathops: for name, func in single_arg_mathops:
run_single_arg_mathop_test(name, func, dtype="float32") run_single_arg_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed") print(f"✓ {name} test passed")
...@@ -287,7 +287,7 @@ def test_two_arg_mathops_fastmath(): ...@@ -287,7 +287,7 @@ def test_two_arg_mathops_fastmath():
] ]
for name, func in two_arg_mathops: for name, func in two_arg_mathops:
run_two_arg_mathop_test(name, func, dtype="float32") run_two_arg_mathop_test(name, func, dtype=T.float32)
@tilelang.testing.requires_cuda @tilelang.testing.requires_cuda
...@@ -312,7 +312,7 @@ def test_fastmath_versions(): ...@@ -312,7 +312,7 @@ def test_fastmath_versions():
] ]
for name, func in fastmath_mathops: for name, func in fastmath_mathops:
run_fastmath_mathop_test(name, func, dtype="float32") run_fastmath_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed") print(f"✓ {name} test passed")
......
...@@ -5,7 +5,7 @@ import tilelang.testing ...@@ -5,7 +5,7 @@ import tilelang.testing
import pytest import pytest
def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"): def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
""" """
Test IEEE-compliant math operations with specified rounding modes. Test IEEE-compliant math operations with specified rounding modes.
""" """
...@@ -75,7 +75,7 @@ def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=12 ...@@ -75,7 +75,7 @@ def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=12
print(f"✓ {mathop_name} compilation test passed") print(f"✓ {mathop_name} compilation test passed")
# Test numerical execution # Test numerical execution
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype) a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
if num_inputs >= 2: if num_inputs >= 2:
...@@ -186,8 +186,8 @@ def test_ieee_frsqrt_rn_only(): ...@@ -186,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((128, 128), "float32"), A: T.Tensor((128, 128), T.float32),
B: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), T.float32),
): ):
with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by): with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
for i, j in T.Parallel(32, 32): for i, j in T.Parallel(32, 32):
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
@tilelang.jit(execution_backend="torch") @tilelang.jit(execution_backend="torch")
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float32, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def gemm( def gemm(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
...@@ -39,13 +39,13 @@ def assert_gemm( ...@@ -39,13 +39,13 @@ def assert_gemm(
block_M, block_M,
block_N, block_N,
block_K, block_K,
dtype="float32", dtype=T.float32,
accum_dtype="float", accum_dtype=T.float32,
atol=1e-8, atol=1e-8,
): ):
jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype) jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
torch_dtype = getattr(torch, dtype) torch_dtype = dtype.as_torch()
a, b = None, None a, b = None, None
if "int" in dtype: if "int" in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps") a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
...@@ -69,12 +69,12 @@ def test_gemm_float32(): ...@@ -69,12 +69,12 @@ def test_gemm_float32():
@tilelang.testing.requires_metal @tilelang.testing.requires_metal
def test_gemm_float16(): def test_gemm_float16():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1) assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.float16, atol=1)
@tilelang.testing.requires_metal @tilelang.testing.requires_metal
def test_gemm_int32(): def test_gemm_int32():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1) assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.int32, atol=1)
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment