Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
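The motivation for float32 accumulators can be seen in a small pure-Python sketch, using `struct`'s IEEE half-precision format as a stand-in for fp16 hardware accumulation (the numbers are illustrative, not taken from the examples themselves):

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE half precision via struct's 'e' format.
    return struct.unpack('e', struct.pack('e', x))[0]

acc16 = 0.0  # fp16 accumulator
acc32 = 0.0  # float32-style accumulator (Python float is wide enough here)
for _ in range(4096):
    acc16 = to_fp16(acc16 + 1.0)  # stalls at 2048: 2049 rounds back to 2048
    acc32 += 1.0

print(acc16, acc32)  # 2048.0 4096.0
```

Past 2048 the fp16 accumulator's unit in the last place exceeds 1, so every further `+ 1.0` rounds away; accumulating in float32 avoids this entirely.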

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
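A minimal sketch of such a canonical-name mapping (the table entries and helper name here are hypothetical, not tilelang's actual dtypes-module implementation):

```python
# Hypothetical canonical-name table; entries are illustrative only.
_CANONICAL_NAMES = {
    "float": "float32",            # bare "float" displays as float32
    "half": "float16",
    "float8_e4m3": "float8_e4m3fn",
}

def canonical_dtype_name(name: str) -> str:
    """Normalize a user-supplied dtype string to its canonical display string."""
    return _CANONICAL_NAMES.get(name, name)

print(canonical_dtype_name("float"))  # float32
print(canonical_dtype_name("int32"))  # int32
```

The point of the mapping is that string inputs with several accepted spellings all resolve to one display form before comparison or error reporting.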

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.
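The CLI pattern described here — passing dtypes as plain strings and converting to framework types later — might look like this sketch (the flag name and choices mirror the commit text but are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
# dtypes arrive as strings on the command line ...
parser.add_argument("--dtype", default="float16",
                    choices=["float16", "int8", "float"])
args = parser.parse_args(["--dtype", "int8"])
# ... and are mapped to framework types only inside the benchmark body
print(args.dtype)  # int8
```

Keeping the argparse layer string-typed avoids importing the DSL's type objects just to parse flags.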

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected a typo in the data type conversion call, replacing the invalid `dtype..as_torch()` with `dtype.as_torch()`, across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
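The corrected call site follows the `T.dtype(name).as_torch()` pattern seen later in the diff; here is a minimal pure-Python stand-in (the class and its lookup table are illustrative, not tilelang's implementation, and strings stand in for real `torch.dtype` objects to keep the sketch dependency-free):

```python
# Illustrative stand-in for a dtype wrapper exposing an as_torch() method.
_TORCH_NAMES = {"float16": "torch.float16", "float32": "torch.float32"}

class Dtype:
    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # Real code would return a torch.dtype; a string keeps this runnable anywhere.
        return _TORCH_NAMES[self.name]

print(Dtype("float16").as_torch())  # torch.float16
```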

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
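The rename is mechanical but word-boundary sensitive, since occurrences already spelled `float8_e4m3fn` must not be touched; a sketch of such a rewrite (the helper name is illustrative):

```python
import re

def rename_float8(source: str) -> str:
    # \b stops the pattern from re-matching inside "float8_e4m3fn",
    # so running the rewrite twice is harmless.
    return re.sub(r"float8_e4m3\b", "float8_e4m3fn", source)

print(rename_float8("x = T.float8_e4m3"))    # x = T.float8_e4m3fn
print(rename_float8("y = T.float8_e4m3fn"))  # y = T.float8_e4m3fn
```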

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -5,7 +5,7 @@ import torch
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
+def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -26,7 +26,7 @@ def tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
+def run_tilelang_copy_mask_parallel(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
     program = tilelang_copy_mask_parallel(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
@@ -42,7 +42,7 @@ def test_tilelang_copy_mask_parallel():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
+def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -62,7 +62,7 @@ def tilelang_copy_mask_copy(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
+def run_tilelang_copy_mask_copy(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
     program = tilelang_copy_mask_copy(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
@@ -78,7 +78,7 @@ def test_tilelang_copy_mask_copy():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
+def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -99,7 +99,7 @@ def tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
+def run_tilelang_copy_mask_parallel_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
     program = tilelang_copy_mask_parallel_range(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
@@ -115,7 +115,7 @@ def test_tilelang_copy_mask_parallel_range():
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
+def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -135,7 +135,7 @@ def tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
+def run_tilelang_copy_mask_copy_range(M=1024, N=1024, block_M=128, block_N=128, dtype=T.float16):
     program = tilelang_copy_mask_copy_range(M, N, block_M, block_N, dtype)
     kernel = tilelang.compile(
         program, out_idx=[1], target="cuda", pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True}
......
 from tilelang import tvm
 import tilelang as tl
 import tilelang.testing
-from tvm.script import tir as T
+import tilelang.language as T
 @T.prim_func
-def negative_index_before(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+def negative_index_before(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
     T.func_attr({"tir.noalias": True})
     B[0] = A[T.int32(-1)]
 @T.prim_func
-def negative_index_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+def negative_index_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
     T.func_attr({"tir.noalias": True})
     B[0] = A[T.int32(15)]
 @T.prim_func
-def negative_index_loop_before(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
+def negative_index_loop_before(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)):
     T.func_attr({"tir.noalias": True})
     for i in T.serial(4):
         B[i] = A[-i - 1]
 @T.prim_func
-def negative_index_loop_expected(A: T.Buffer((16,), "float32"), B: T.Buffer((4,), "float32")):
+def negative_index_loop_expected(A: T.Buffer((16,), T.float32), B: T.Buffer((4,), T.float32)):
     T.func_attr({"tir.noalias": True})
     for i in T.serial(4):
         B[i] = A[15 - i]
 @T.prim_func
-def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
+def negative_index_symbolic_before(shift: T.int32, A: T.Buffer((16,), T.float32), B: T.Buffer((16,), T.float32)):
     T.func_attr({"tir.noalias": True})
     for i in T.serial(16):
         B[i] = A[shift + i]
......
@@ -8,7 +8,7 @@ tilelang.testing.set_random_seed()
 @tilelang.jit(out_idx=[1])
-def parallel_elementwise_static(length=256, dtype="float32"):
+def parallel_elementwise_static(length=256, dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((length,), dtype),
@@ -22,7 +22,7 @@ def parallel_elementwise_static(length=256, dtype="float32"):
 @tilelang.jit(out_idx=[1])
-def parallel_elementwise_dynamic(max_len=512, threads=256, dtype="float32"):
+def parallel_elementwise_dynamic(max_len=512, threads=256, dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((max_len,), dtype),
......
 from tilelang import tvm as tvm
 import tilelang.testing
+import tilelang.language as T
 def matmul(
@@ -23,8 +24,6 @@ def matmul(
     A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor(A_shape, in_dtype),
@@ -63,9 +62,9 @@ def run_gemm(
     block_K = 32
     trans_A = False
     trans_B = False
-    in_dtype = "float16"
-    out_dtype = "float16"
-    dtypeAccum = "float32"
+    in_dtype = T.float16
+    out_dtype = T.float16
+    dtypeAccum = T.float32
     num_threads = 128
     program = matmul(
         M,
@@ -101,7 +100,7 @@ def run_gemm(
         A = A.T
     if trans_B:
         B = B.T
-    if in_dtype == "float32":
+    if in_dtype == T.float32:
         # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
         # float32 automatically, -0x1000 means
         A = (A.view(torch.int32) - 0x1000).view(torch.float32)
@@ -127,7 +126,7 @@ def test_pipeline_order_stage():
         tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     },
 )
-def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype="float16", accum_dtype="float"):
+def blocksparse_matmul(M, N, K, block_M, block_N, block_K, num_stages, dtype=T.float16, accum_dtype=T.float32):
     block_mask_shape = (M // block_M, N // block_N, K // block_K)
     import tilelang.language as T
......
@@ -6,7 +6,7 @@ import tilelang.language as T
 from tilelang.utils import map_torch_type
-def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         a_ptr: T.ptr,
@@ -39,7 +39,7 @@ def matmul_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype
     return main
-def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     program = matmul_test(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
     jit_kernel = tl.compile(program, target="cuda", execution_backend="cython")
......
 from tilelang import tvm as tvm
 import tilelang.testing
 import tilelang as tl
+import tilelang.language as T
 tilelang.testing.set_random_seed()
 def _make_shared_reduce(M, N, dtype, reduce_cb):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -30,7 +29,7 @@ def _run_program(program, ref_program, atol=1e-2, rtol=1e-2):
     profiler.assert_allclose(ref_program, atol=atol, rtol=rtol)
-def reduce_max_test(M, N, dtype="float16"):
+def reduce_max_test(M, N, dtype=T.float16):
     import tilelang.language as T
     @T.prim_func
@@ -49,7 +48,7 @@ def reduce_max_test(M, N, dtype="float16"):
     return main
-def reduce_sum_test(M, N, dtype="float32"):
+def reduce_sum_test(M, N, dtype=T.float32):
     import tilelang.language as T
     @T.prim_func
@@ -68,27 +67,27 @@ def reduce_sum_test(M, N, dtype="float32"):
     return main
-def reduce_sum_ss(M, N, dtype="float32"):
+def reduce_sum_ss(M, N, dtype=T.float32):
     return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_sum(src, dst, dim=1))
-def reduce_max_ss(M, N, dtype="float32"):
+def reduce_max_ss(M, N, dtype=T.float32):
     return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_max(src, dst, dim=1))
-def reduce_min_ss(M, N, dtype="float32"):
+def reduce_min_ss(M, N, dtype=T.float32):
     return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_min(src, dst, dim=1))
-def reduce_abssum_ss(M, N, dtype="float32"):
+def reduce_abssum_ss(M, N, dtype=T.float32):
     return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_abssum(src, dst, dim=1))
-def reduce_absmax_ss(M, N, dtype="float32"):
+def reduce_absmax_ss(M, N, dtype=T.float32):
     return _make_shared_reduce(M, N, dtype, lambda T, src, dst: T.reduce_absmax(src, dst, dim=1))
-def run_reduce_sum(M, N, dtype="float32", mode="rr"):
+def run_reduce_sum(M, N, dtype=T.float32, mode="rr"):
     if mode == "rr":
         program = reduce_sum_test(M, N, dtype)
     elif mode == "ss":
@@ -98,12 +97,12 @@ def run_reduce_sum(M, N, dtype="float32", mode="rr"):
     _run_program(program, lambda A: A.sum(dim=1))
-def run_shared_reduce(program_builder, ref_program, M, N, dtype="float32"):
+def run_shared_reduce(program_builder, ref_program, M, N, dtype=T.float32):
     program = program_builder(M, N, dtype)
     _run_program(program, ref_program)
-def run_reduce_max(M, N, dtype="float16"):
+def run_reduce_max(M, N, dtype=T.float16):
     program = reduce_max_test(M, N, dtype)
     _run_program(program, lambda A: A.max(dim=1).values, atol=1e-2, rtol=1e-2)
@@ -119,28 +118,28 @@ def test_reduce_sum_shared():
 def test_reduce_max():
-    run_reduce_max(256, 256, "float16")
-    run_reduce_max(512, 128, "float16")
-    run_reduce_max(256, 256, "float32")
+    run_reduce_max(256, 256, T.float16)
+    run_reduce_max(512, 128, T.float16)
+    run_reduce_max(256, 256, T.float32)
 def test_reduce_max_shared():
-    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, "float32")
+    run_shared_reduce(reduce_max_ss, lambda A: A.max(dim=1).values, 64, 64, T.float32)
 def test_reduce_min_shared():
-    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, "float32")
+    run_shared_reduce(reduce_min_ss, lambda A: A.min(dim=1).values, 64, 64, T.float32)
 def test_reduce_abssum_shared():
-    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, "float32")
+    run_shared_reduce(reduce_abssum_ss, lambda A: A.abs().sum(dim=1), 64, 64, T.float32)
 def test_reduce_absmax_shared():
-    run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, "float32")
+    run_shared_reduce(reduce_absmax_ss, lambda A: A.abs().max(dim=1).values, 64, 64, T.float32)
-def reduce_sum_test_clear(M, N, dtype="float32"):
+def reduce_sum_test_clear(M, N, dtype=T.float32):
     import tilelang.language as T
     @T.prim_func
@@ -160,7 +159,7 @@ def reduce_sum_test_clear(M, N, dtype="float32"):
     return main
-def run_reduce_sum_clear(M, N, dtype="float32"):
+def run_reduce_sum_clear(M, N, dtype=T.float32):
     program = reduce_sum_test_clear(M, N, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
@@ -176,12 +175,12 @@ def run_reduce_sum_clear(M, N, dtype="float32"):
 def test_reduce_sum_clear():
-    run_reduce_sum_clear(256, 256, "float32")
-    run_reduce_sum_clear(512, 128, "float32")
-    run_reduce_sum_clear(128, 512, "float32")
+    run_reduce_sum_clear(256, 256, T.float32)
+    run_reduce_sum_clear(512, 128, T.float32)
+    run_reduce_sum_clear(128, 512, T.float32)
-def reduce_max_test_clear(M, N, dtype="float16"):
+def reduce_max_test_clear(M, N, dtype=T.float16):
     import tilelang.language as T
     @T.prim_func
@@ -201,7 +200,7 @@ def reduce_max_test_clear(M, N, dtype="float16"):
     return main
-def run_reduce_max_clear(M, N, dtype="float16"):
+def run_reduce_max_clear(M, N, dtype=T.float16):
     program = reduce_max_test_clear(M, N, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
@@ -217,7 +216,7 @@ def run_reduce_max_clear(M, N, dtype="float16"):
 def test_reduce_max_clear():
-    run_reduce_max_clear(256, 256, "float16")
+    run_reduce_max_clear(256, 256, T.float16)
if __name__ == "__main__":
......
 from tilelang import tvm as tvm
 import tilelang.testing
 import tilelang as tl
+from tilelang import language as T
 import torch
 import pytest
 def reshape_test(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -42,13 +40,11 @@ def run_reshape(N, M, dtype):
 def test_reshape_smem():
     # Test reshape
-    run_reshape(1024, 32, "float32")
-    run_reshape(2048, 64, "float16")
+    run_reshape(1024, 32, T.float32)
+    run_reshape(2048, 64, T.float16)
 def reshape_test_smem_1d_2_2d(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -86,13 +82,11 @@ def run_reshape_smem_1d_2_2d(N, M, dtype):
 def test_reshape_smem_1d_2_2d():
-    run_reshape_smem_1d_2_2d(1024, 32, "float32")
-    run_reshape_smem_1d_2_2d(2048, 64, "float16")
+    run_reshape_smem_1d_2_2d(1024, 32, T.float32)
+    run_reshape_smem_1d_2_2d(2048, 64, T.float16)
 def reshape_test_smem_2d_2_1d(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N // M, M), dtype),
@@ -130,13 +124,11 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
 def test_reshape_smem_2d_2_1d():
-    run_reshape_smem_2d_2_1d(1024, 32, "float32")
-    run_reshape_smem_2d_2_1d(2048, 64, "float16")
+    run_reshape_smem_2d_2_1d(1024, 32, T.float32)
+    run_reshape_smem_2d_2_1d(2048, 64, T.float16)
 def reshape_fragment_test(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N // M, M), dtype),
@@ -175,12 +167,11 @@ def run_reshape_fragment(N, M, dtype):
 def test_reshape_fragment():
-    run_reshape_fragment(1024, 32, "float32")
-    run_reshape_fragment(2048, 64, "float16")
+    run_reshape_fragment(1024, 32, T.float32)
+    run_reshape_fragment(2048, 64, T.float16)
 def reshape_layout_transform_shared(N, M, dtype):
-    import tilelang.language as T
     from tilelang.intrinsics.mma_layout import make_mma_swizzle_layout
     @T.prim_func
@@ -222,13 +213,11 @@ def run_reshape_layout_transform_shared(N, M, dtype):
 def test_reshape_layout_transform_shared():
-    run_reshape_layout_transform_shared(1024, 32, "float32")
-    run_reshape_layout_transform_shared(2048, 64, "float16")
+    run_reshape_layout_transform_shared(1024, 32, T.float32)
+    run_reshape_layout_transform_shared(2048, 64, T.float16)
 def reduce_after_reshape_test(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -267,13 +256,11 @@ def run_reduce_after_reshape(N, M, dtype):
 def test_reduce_after_reshape():
-    run_reduce_after_reshape(1024, 32, "float32")
-    run_reduce_after_reshape(2048, 64, "float16")
+    run_reduce_after_reshape(1024, 32, T.float32)
+    run_reduce_after_reshape(2048, 64, T.float16)
 def reshape_shape_mismatch_test(N, M, dtype):
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor((N,), dtype),
@@ -288,7 +275,7 @@ def reshape_shape_mismatch_test(N, M, dtype):
 def test_reshape_shape_mismatch():
     with pytest.raises(AssertionError):
-        reshape_shape_mismatch_test(1024, 32, "float32")
+        reshape_shape_mismatch_test(1024, 32, T.float32)
if __name__ == "__main__":
......
@@ -7,7 +7,7 @@ import tilelang.testing
 @tilelang.jit(
     out_idx=[1],
 )
-def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
+def tilelang_ternary(M, N, block_M, block_N, dtype=T.float16):
     @T.prim_func
     def main(
         A: T.Tensor((M, N), dtype),
@@ -21,7 +21,7 @@ def tilelang_ternary(M, N, block_M, block_N, dtype="float16"):
     return main
-def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype="float16"):
+def run_tilelang_ternary(M=128, N=128, block_M=32, block_N=32, dtype=T.float16):
     kernel = tilelang_ternary(M, N, block_M, block_N, dtype)
     a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
     b = kernel(a)
......
@@ -34,7 +34,7 @@ def run_elementwise_add(M, N):
     # Default config
     block_M, block_N = 128, 128
     config = {"block_M": block_M, "block_N": block_N, "threads": 128}
-    kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
+    kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32)
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
......
@@ -6,7 +6,7 @@ from tilelang import language as T
 def test_unroll_with_step():
     @T.prim_func
     def main(A_ptr: T.handle):
-        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+        A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)
         for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
             for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
@@ -20,7 +20,7 @@ def test_unroll_with_step():
 def test_unroll_with_unroll_factor():
     @T.prim_func
     def main(A_ptr: T.handle):
-        A = T.match_buffer(A_ptr, (16, 16), dtype="float32", align=16)
+        A = T.match_buffer(A_ptr, (16, 16), dtype=T.float32, align=16)
         for _blockIdx in T.thread_binding(1, thread="blockIdx.x"):
             for _threadIdx in T.thread_binding(128, thread="threadIdx.x"):
......
@@ -7,12 +7,12 @@ def test_var_assign() -> None:
     @tilelang.jit(out_idx=-1)
     def jit_kernel():
         @T.prim_func
-        def test_var_assign(A: T.Tensor((2,), "int32")):
+        def test_var_assign(A: T.Tensor((2,), T.int32)):
             with T.Kernel(1) as _:
-                a = T.alloc_var("int32", init=1)
-                b = T.alloc_var("int32", init=a)  # b gets value of a
+                a = T.alloc_var(T.int32, init=1)
+                b = T.alloc_var(T.int32, init=a)  # b gets value of a
                 a = 2
-                d = T.alloc_var("int32", init=a)  # c gets new value of a
+                d = T.alloc_var(T.int32, init=a)  # c gets new value of a
                 A[0] = b
                 A[1] = d
......
@@ -7,8 +7,8 @@ import tilelang.language as T
 def vectorize_test(N, M, stride_A, stride_B):
     @T.prim_func
     def main(
-        A: T.StridedTensor[(N, M), (1, stride_A), "float32"],  # noqa: F821
-        B: T.StridedTensor[(N, M), (1, stride_B), "float32"],  # noqa: F821
+        A: T.StridedTensor[(N, M), (1, stride_A), T.float32],  # noqa: F821
+        B: T.StridedTensor[(N, M), (1, stride_B), T.float32],  # noqa: F821
     ):
         with T.Kernel(M // 128, threads=128) as (bx):
             tx = T.get_thread_binding(0)
@@ -60,9 +60,9 @@ def test_vectorize():
 def vectorize_test_invariant_index(N, M, K):
     @T.prim_func
     def main(
-        A: T.Tensor[(N, M), "float32"],  # noqa: F821
-        B: T.Tensor[(N, M), "float32"],  # noqa: F821
-        C: T.Tensor[(N, M // K), "float32"],  # noqa: F821
+        A: T.Tensor[(N, M), T.float32],  # noqa: F821
+        B: T.Tensor[(N, M), T.float32],  # noqa: F821
+        C: T.Tensor[(N, M // K), T.float32],  # noqa: F821
     ):
         with T.Kernel(N // 128, threads=128) as (bx):
             tx = T.get_thread_binding(0)
......
@@ -4,11 +4,11 @@ import tilelang.testing
 import tilelang.language as T
 str2dtype = {
-    "float32": torch.float32,
-    "float16": torch.float16,
-    "bfloat16": torch.bfloat16,
-    "float8_e4m3": torch.float8_e4m3fn,
-    "float8_e5m2": torch.float8_e5m2,
+    T.float32: torch.float32,
+    T.float16: torch.float16,
+    T.bfloat16: torch.bfloat16,
+    T.float8_e4m3fn: torch.float8_e4m3fn,
+    T.float8_e5m2: torch.float8_e5m2,
 }
@@ -81,22 +81,22 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
 @pytest.mark.parametrize(
     "src_dtype, dst_dtype, check_str, lanes",
     [
-        ("float32", "float16", "__float22half2_rn", 2),
-        ("float32", "float16", "__float22half2_rn", 4),
-        ("float16", "float32", "__half22float2", 2),
-        ("float16", "float32", "__half22float2", 4),
-        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2),
-        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4),
-        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2),
-        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4),
-        ("float32", "bfloat16", "__float22bfloat162_rn", 2),
-        ("float32", "bfloat16", "__float22bfloat162_rn", 4),
-        ("bfloat16", "float32", "__bfloat1622float2", 2),
-        ("bfloat16", "float32", "__bfloat1622float2", 4),
-        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2),
-        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4),
-        ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2),
-        ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4),
+        (T.float32, T.float16, "__float22half2_rn", 2),
+        (T.float32, T.float16, "__float22half2_rn", 4),
+        (T.float16, T.float32, "__half22float2", 2),
+        (T.float16, T.float32, "__half22float2", 4),
+        (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 2),
+        (T.float32, T.float8_e4m3fn, "__nv_cvt_float2_to_fp8x2", 4),
+        (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 2),
+        (T.float32, T.float8_e5m2, "__nv_cvt_float2_to_fp8x2", 4),
+        (T.float32, T.bfloat16, "__float22bfloat162_rn", 2),
+        (T.float32, T.bfloat16, "__float22bfloat162_rn", 4),
+        (T.bfloat16, T.float32, "__bfloat1622float2", 2),
+        (T.bfloat16, T.float32, "__bfloat1622float2", 4),
+        (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
+        (T.float8_e4m3fn, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
+        (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 2),
+        (T.float8_e5m2, T.float32, "__tl_cvt_fp8x2_to_float2", 4),
     ],
 )
 def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
......
+import tilelang.language as T
 from tilelang import tvm as tvm
 import tilelang.testing
 import tilelang as tl
@@ -5,8 +6,6 @@ import pytest
 def view_test(N, M, dtype, new_dtype=None):
-    import tilelang.language as T
     new_shape = [N // M, M]
     if new_dtype:
         from tvm import DataType
@@ -37,9 +36,7 @@ def run_view(N, M, dtype, new_dtype=None):
     def ref_program(A):
         if new_dtype:
-            from tilelang.utils.tensor import map_torch_type
-            torch_dtype = map_torch_type(new_dtype)
+            torch_dtype = T.dtype(new_dtype).as_torch()
             return A.view(N // M, M).view(dtype=torch_dtype)
         return A.view(N // M, M)
@@ -48,17 +45,15 @@ def run_view(N, M, dtype, new_dtype=None):
 def test_reshape_view():
     # Test view with same dtype
-    run_view(1024, 32, "float32")
-    run_view(2048, 64, "float16")
+    run_view(1024, 32, T.float32)
+    run_view(2048, 64, T.float16)
     # Test view with dtype conversion
-    run_view(1024, 32, "float32", "float16")
-    run_view(2048, 64, "float16", "float32")
+    run_view(1024, 32, T.float32, T.float16)
+    run_view(2048, 64, T.float16, T.float32)
 def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
-    import tilelang.language as T
     new_shape = [N // M, M + 1]
     if new_dtype:
         from tvm import DataType
@@ -84,7 +79,7 @@ def view_shape_mismatch_test(N, M, dtype, new_dtype=None):
 def test_view_shape_mismatch():
     with pytest.raises(AssertionError):
-        view_shape_mismatch_test(1024, 32, "float32")
+        view_shape_mismatch_test(1024, 32, T.float32)
if __name__ == "__main__":
......
......@@ -33,7 +33,7 @@ def get_kernel(reduce_op: str, dtype: str):
def test_warp_reduce_sum():
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("sum", "float32")
kernel = get_kernel("sum", T.float32)
ref = torch.full_like(a, a.sum())
kernel(a)
torch.testing.assert_close(a, ref)
......@@ -41,7 +41,7 @@ def test_warp_reduce_sum():
def test_warp_reduce_max():
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("max", "float32")
kernel = get_kernel("max", T.float32)
print(kernel.get_kernel_source())
ref = torch.full_like(a, a.max())
kernel(a)
......@@ -50,7 +50,7 @@ def test_warp_reduce_max():
def test_warp_reduce_min():
a = torch.randn((32,), dtype=torch.float32, device="cuda")
kernel = get_kernel("min", "float32")
kernel = get_kernel("min", T.float32)
ref = torch.full_like(a, a.min())
kernel(a)
torch.testing.assert_close(a, ref)
......@@ -58,7 +58,7 @@ def test_warp_reduce_min():
def test_warp_reduce_bitand():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitand", "int32")
kernel = get_kernel("bitand", T.int32)
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val & a[i]
......@@ -69,7 +69,7 @@ def test_warp_reduce_bitand():
def test_warp_reduce_bitor():
a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device="cuda")
kernel = get_kernel("bitor", "int32")
kernel = get_kernel("bitor", T.int32)
ref_val = a[0]
for i in range(1, a.shape[0]):
ref_val = ref_val | a[i]
......
......@@ -14,8 +14,8 @@ VEC_SIZE = 32
def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
@T.prim_func
def main(
a: T.Buffer((B, M, N), "bfloat16"),
a_out: T.Buffer((B, M, N), "float32"),
a: T.Buffer((B, M, N), T.bfloat16),
a_out: T.Buffer((B, M, N), T.float32),
):
with T.Kernel(
T.ceildiv(M, BLOCK_MN),
......@@ -23,7 +23,7 @@ def fused_index_kernel(B: int, M: int, N: int, BLOCK_MN: int, BLOCK_K: int):
B,
threads=128,
) as (pid_m, pid_n, pid_b):
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), "float32")
a_fp32_local = T.alloc_fragment((BLOCK_MN * BLOCK_K // VEC_SIZE, VEC_SIZE), T.float32)
offs_m = pid_m * BLOCK_MN
offs_n = pid_n * BLOCK_K
......
......@@ -21,15 +21,15 @@ def bitwise_reduce(
):
@T.prim_func
def reduce_func(
A: T.Tensor((M, N), "int32"),
B: T.Tensor((M), "int32"),
Output: T.Tensor((M), "int32"),
A: T.Tensor((M, N), T.int32),
B: T.Tensor((M), T.int32),
Output: T.Tensor((M), T.int32),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
A_shared = T.alloc_shared((block_M, block_N), "int32")
A_fragment = T.alloc_fragment((block_M, block_N), "int32")
B_shared = T.alloc_shared((block_M,), "int32")
B_fragment = T.alloc_fragment((block_M), "int32")
A_shared = T.alloc_shared((block_M, block_N), T.int32)
A_fragment = T.alloc_fragment((block_M, block_N), T.int32)
B_shared = T.alloc_shared((block_M,), T.int32)
B_fragment = T.alloc_fragment((block_M), T.int32)
T.copy(A[by * block_M, bx * block_N], A_shared)
T.copy(A_shared, A_fragment)
T.copy(B[by * block_M], B_shared)
......
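The row-wise bitwise reduction exercised by `bitwise_reduce` above can be modeled in plain Python (a reference model of the kernel's semantics, not its tiling or shared-memory staging):

```python
from functools import reduce
import operator

# Each row of the int32 input is folded with bitwise AND, mirroring
# what the kernel accumulates per row of A into Output.
rows = [[0b1101, 0b1011, 0b1001], [0b0110, 0b0111, 0b0100]]
out = [reduce(operator.and_, row) for row in rows]
assert out == [0b1001, 0b0100]
```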
......@@ -49,7 +49,7 @@ def check_non_fastmath_usage(source, mathop_name):
check_fastmath_usage(source, mathop_name, expect_fastmath=False)
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test single-argument mathops.
T.exp should generate expf (non-fastmath), T.__exp should generate __expf (fastmath)
......@@ -85,7 +85,7 @@ def run_single_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=3
print(f"✓ {mathop_name} compilation and execution test passed")
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test two-argument mathops to ensure they generate non-fastmath CUDA code.
"""
......@@ -133,7 +133,7 @@ def run_two_arg_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_non_fastmath_usage(source_fastmath, mathop_name)
# Test numerical correctness
torch_dtype = getattr(torch, dtype)
torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
b = torch.randn(M, N, device="cuda", dtype=torch_dtype)
......@@ -159,8 +159,8 @@ def run_abs_test():
@T.prim_func
def main(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), T.float32),
B: T.Tensor((M, N), T.float32),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
for i, j in T.Parallel(block_M, block_N):
......@@ -188,7 +188,7 @@ def run_abs_test():
print("✓ abs numerical test passed")
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype="float32"):
def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test fastmath mathops to ensure they generate fastmath CUDA code (with __ prefix).
"""
......@@ -221,7 +221,7 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
check_fastmath_usage(source_fastmath, cuda_mathop_name, expect_fastmath=True)
# Test numerical correctness
torch_dtype = getattr(torch, dtype)
torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
# Ensure positive values for functions that need them
......@@ -273,7 +273,7 @@ def test_mathops_generate_no_fastmath():
]
for name, func in single_arg_mathops:
run_single_arg_mathop_test(name, func, dtype="float32")
run_single_arg_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed")
......@@ -287,7 +287,7 @@ def test_two_arg_mathops_fastmath():
]
for name, func in two_arg_mathops:
run_two_arg_mathop_test(name, func, dtype="float32")
run_two_arg_mathop_test(name, func, dtype=T.float32)
@tilelang.testing.requires_cuda
......@@ -312,7 +312,7 @@ def test_fastmath_versions():
]
for name, func in fastmath_mathops:
run_fastmath_mathop_test(name, func, dtype="float32")
run_fastmath_mathop_test(name, func, dtype=T.float32)
print(f"✓ {name} test passed")
......
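The recurring change in these hunks swaps `getattr(torch, dtype)` on a dtype string for `dtype.as_torch()` on a dtype object. A hypothetical sketch of such a wrapper (the `DType` name and lazy torch lookup here are illustrative assumptions, not tilelang's implementation):

```python
class DType:
    """Illustrative dtype object that knows its torch counterpart."""

    def __init__(self, name: str):
        self.name = name

    def as_torch(self):
        # Resolve lazily so the sketch does not require torch at import
        # time; `getattr(torch, "float32")` is what the older
        # string-based test code did explicitly.
        import torch  # assumed available in the real test environment
        return getattr(torch, self.name)

    def __str__(self):
        return self.name

float32 = DType("float32")
assert str(float32) == "float32"
```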
......@@ -5,7 +5,7 @@ import tilelang.testing
import pytest
def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype="float32"):
def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=128, block_M=32, block_N=32, dtype=T.float32):
"""
Test IEEE-compliant math operations with specified rounding modes.
"""
......@@ -75,7 +75,7 @@ def run_ieee_math_test(mathop_name, mathop_func, rounding_mode="rn", M=128, N=12
print(f"✓ {mathop_name} compilation test passed")
# Test numerical execution
torch_dtype = getattr(torch, dtype)
torch_dtype = dtype.as_torch()
a = torch.randn(M, N, device="cuda", dtype=torch_dtype)
if num_inputs >= 2:
......@@ -186,8 +186,8 @@ def test_ieee_frsqrt_rn_only():
@T.prim_func
def main(
A: T.Tensor((128, 128), "float32"),
B: T.Tensor((128, 128), "float32"),
A: T.Tensor((128, 128), T.float32),
B: T.Tensor((128, 128), T.float32),
):
with T.Kernel(T.ceildiv(128, 32), T.ceildiv(128, 32), threads=128) as (bx, by):
for i, j in T.Parallel(32, 32):
......
......@@ -6,7 +6,7 @@ import torch
@tilelang.jit(execution_backend="torch")
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float32, accum_dtype=T.float32):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
......@@ -39,13 +39,13 @@ def assert_gemm(
block_M,
block_N,
block_K,
dtype="float32",
accum_dtype="float",
dtype=T.float32,
accum_dtype=T.float32,
atol=1e-8,
):
jit_kernel = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
torch_dtype = getattr(torch, dtype)
torch_dtype = dtype.as_torch()
a, b = None, None
if "int" in dtype:
a = torch.randint(100, (M, K), dtype=torch_dtype, device="mps")
......@@ -69,12 +69,12 @@ def test_gemm_float32():
@tilelang.testing.requires_metal
def test_gemm_float16():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="float16", atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.float16, atol=1)
@tilelang.testing.requires_metal
def test_gemm_int32():
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype="int32", atol=1)
assert_gemm(1024, 1024, 1024, 16, 16, 16, dtype=T.int32, atol=1)
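As a semantic reference for the GEMM these Metal tests exercise: `C[i][j]` accumulates `A[i][k] * B[k][j]` in the accumulation dtype (float32 above). A tiny pure-Python model of that contract, ignoring the kernel's blocking:

```python
def ref_gemm(A, B):
    # Naive triple loop: mirrors the kernel's result, not its tiling.
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for i in range(M):
        for k in range(K):
            for j in range(N):
                C[i][j] += A[i][k] * B[k][j]
    return C

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
assert ref_gemm(A, B) == [[19.0, 22.0], [43.0, 50.0]]
```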
if __name__ == "__main__":
......