Unverified Commit c750fb8a authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
...@@ -3,7 +3,7 @@ import tilelang.language as T ...@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1]) @tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def gemm( def gemm(
A: T.Tensor((M, K), dtype), A: T.Tensor((M, K), dtype),
......
...@@ -9,14 +9,14 @@ import tilelang.testing ...@@ -9,14 +9,14 @@ import tilelang.testing
@tilelang.jit @tilelang.jit
def dynamic_smem_kernel(): def dynamic_smem_kernel():
# Symbolic length to drive dynamic shared memory allocation # Symbolic length to drive dynamic shared memory allocation
length = T.symbolic("len", dtype="int32") # noqa: F821 length = T.symbolic("len", dtype=T.int32) # noqa: F821
@T.prim_func @T.prim_func
def main(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821 def main(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821
# Launch a simple kernel that copies from global memory into shared memory # Launch a simple kernel that copies from global memory into shared memory
# using a dynamically-sized allocation. No writes back to global_tensor. # using a dynamically-sized allocation. No writes back to global_tensor.
with T.Kernel(1, threads=32) as _: with T.Kernel(1, threads=32) as _:
buffer_shared = T.alloc_shared((length,), dtype="int32") # noqa: F821 buffer_shared = T.alloc_shared((length,), dtype=T.int32) # noqa: F821
T.copy(buffer_shared, global_tensor) T.copy(buffer_shared, global_tensor)
return main return main
......
import tilelang.language as T
from tilelang import tvm as tvm from tilelang import tvm as tvm
import tilelang.testing import tilelang.testing
import pytest import pytest
...@@ -23,8 +24,6 @@ def matmul( ...@@ -23,8 +24,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K) A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
...@@ -112,20 +111,20 @@ def run_gemm_ss( ...@@ -112,20 +111,20 @@ def run_gemm_ss(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
], ],
) )
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -153,8 +152,6 @@ def matmul_rs( ...@@ -153,8 +152,6 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
...@@ -247,20 +244,20 @@ def run_gemm_rs( ...@@ -247,20 +244,20 @@ def run_gemm_rs(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
], ],
) )
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -288,8 +285,6 @@ def matmul_sr( ...@@ -288,8 +285,6 @@ def matmul_sr(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor(A_shape, in_dtype), A: T.Tensor(A_shape, in_dtype),
...@@ -381,20 +376,20 @@ def run_gemm_sr( ...@@ -381,20 +376,20 @@ def run_gemm_sr(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128), (128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 32, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
], ],
) )
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -519,22 +514,22 @@ def run_gemm_rr( ...@@ -519,22 +514,22 @@ def run_gemm_rr(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128), (128, 8, 128, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128), (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 32, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128), (128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
], ],
) )
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......
...@@ -2,6 +2,7 @@ import pytest ...@@ -2,6 +2,7 @@ import pytest
import torch import torch
import tilelang import tilelang
import tilelang.testing import tilelang.testing
import tilelang.language as T
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_cutlass_metadata_layout from tilelang.layout import make_cutlass_metadata_layout
...@@ -44,14 +45,12 @@ def matmul_sp_sm90( ...@@ -44,14 +45,12 @@ def matmul_sp_sm90(
trans_A, trans_A,
trans_B, trans_B,
): ):
E_factor = 4 if in_dtype == "float32" else 8 E_factor = 4 if in_dtype == T.float32 else 8
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K) B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype), A_sparse: T.Tensor(A_sparse_shape, in_dtype),
...@@ -104,15 +103,13 @@ def matmul_sp_sm80( ...@@ -104,15 +103,13 @@ def matmul_sp_sm80(
trans_B, trans_B,
): ):
is_8_bit = "8" in in_dtype is_8_bit = "8" in in_dtype
metadata_dtype = "int32" if is_8_bit else "int16" metadata_dtype = T.int32 if is_8_bit else T.int16
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype] E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M) A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K) B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K) B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype), A_sparse: T.Tensor(A_sparse_shape, in_dtype),
...@@ -312,19 +309,18 @@ def run_gemm_sp_sm80( ...@@ -312,19 +309,18 @@ def run_gemm_sp_sm80(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[ [
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 0, 256, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True), (512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
], ],
) )
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
...@@ -337,21 +333,20 @@ def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_ ...@@ -337,21 +333,20 @@ def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B", "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[ [
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 32, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True), (512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 1, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False), (512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 3, 128, False, False),
(512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 1, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True), (512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
], ],
) )
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B): def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
......
...@@ -7,6 +7,7 @@ from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmi ...@@ -7,6 +7,7 @@ from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmi
import tilelang.testing import tilelang.testing
import torch import torch
import tilelang.language as T
def matmul( def matmul(
...@@ -31,8 +32,6 @@ def matmul( ...@@ -31,8 +32,6 @@ def matmul(
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M) A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N) B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func @T.prim_func
def main( def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype), A_sparse: T.Tensor(A_sparse_shape, in_dtype),
...@@ -83,7 +82,7 @@ def run_gemm_ss( ...@@ -83,7 +82,7 @@ def run_gemm_ss(
num_stages=3, num_stages=3,
num_threads=128, num_threads=128,
): ):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16" metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul( program = matmul(
M, M,
N, N,
...@@ -157,17 +156,17 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype): ...@@ -157,17 +156,17 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128), (128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, True, T.int8, T.int32, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), (128, 128, 128, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
], ],
) )
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -252,7 +251,7 @@ def run_gemm_rs( ...@@ -252,7 +251,7 @@ def run_gemm_rs(
num_stages=3, num_stages=3,
num_threads=128, num_threads=128,
): ):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16" metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_rs( program = matmul_rs(
M, M,
N, N,
...@@ -308,16 +307,16 @@ def run_gemm_rs( ...@@ -308,16 +307,16 @@ def run_gemm_rs(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128), (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
], ],
) )
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -402,7 +401,7 @@ def run_gemm_sr( ...@@ -402,7 +401,7 @@ def run_gemm_sr(
num_stages=3, num_stages=3,
num_threads=128, num_threads=128,
): ):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16" metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_sr( program = matmul_sr(
M, M,
N, N,
...@@ -458,16 +457,16 @@ def run_gemm_sr( ...@@ -458,16 +457,16 @@ def run_gemm_sr(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128), (128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
], ],
) )
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
...@@ -556,7 +555,7 @@ def run_gemm_rr( ...@@ -556,7 +555,7 @@ def run_gemm_rr(
num_stages=3, num_stages=3,
num_threads=128, num_threads=128,
): ):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16" metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_rr( program = matmul_rr(
M, M,
N, N,
...@@ -612,18 +611,18 @@ def run_gemm_rr( ...@@ -612,18 +611,18 @@ def run_gemm_rr(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads", "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[ [
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128), (512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128), (128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128), (128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128), (128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
], ],
) )
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads): def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......
...@@ -6,8 +6,8 @@ from tilelang.jit.adapter.utils import match_declare_kernel ...@@ -6,8 +6,8 @@ from tilelang.jit.adapter.utils import match_declare_kernel
def _simple_add_kernel(): def _simple_add_kernel():
@T.prim_func @T.prim_func
def main( def main(
x: T.Tensor((128,), "float32"), x: T.Tensor((128,), T.float32),
y: T.Tensor((128,), "float32"), y: T.Tensor((128,), T.float32),
): ):
# One-dimensional kernel; writes y from x without modifying x # One-dimensional kernel; writes y from x without modifying x
with T.Kernel(128, threads=32) as pid: with T.Kernel(128, threads=32) as pid:
......
...@@ -16,13 +16,13 @@ def _check(original, transformed): ...@@ -16,13 +16,13 @@ def _check(original, transformed):
def test_trival_pipeline(): def test_trival_pipeline():
@T.prim_func @T.prim_func
def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")): def before(A: T.Tensor((16, 1), T.float32), C: T.Tensor((16, 1), T.float32)):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"): for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}): for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}):
with T.block(): with T.block():
T.reads(A[tx, i]) T.reads(A[tx, i])
T.writes(C[tx, i]) T.writes(C[tx, i])
B = T.alloc_buffer((16, 1), dtype="float32", scope="shared") B = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared")
with T.block(): with T.block():
T.reads(A[tx, i]) T.reads(A[tx, i])
T.writes(B[tx, 0]) T.writes(B[tx, 0])
......
...@@ -22,11 +22,11 @@ def _check(original, transformed): ...@@ -22,11 +22,11 @@ def _check(original, transformed):
def test_cluster_planning(): def test_cluster_planning():
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): def before(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)):
with T.Kernel(8, 8, threads=128) as (bx, by): with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16") A_shared = T.alloc_shared((128, 32), T.float16)
B_shared = T.alloc_shared((32, 128), "float16") B_shared = T.alloc_shared((32, 128), T.float16)
C_local = T.alloc_fragment((128, 128), "float32") C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local) T.clear(C_local)
...@@ -39,12 +39,12 @@ def test_cluster_planning(): ...@@ -39,12 +39,12 @@ def test_cluster_planning():
T.copy(C_local, C[by * 128, bx * 128]) T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")): def after(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)):
T.func_attr({"clusterIdx.y": T.int32(2)}) T.func_attr({"clusterIdx.y": T.int32(2)})
with T.Kernel(8, 8, threads=128) as (bx, by): with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16") A_shared = T.alloc_shared((128, 32), T.float16)
B_shared = T.alloc_shared((32, 128), "float16") B_shared = T.alloc_shared((32, 128), T.float16)
C_local = T.alloc_fragment((128, 128), "float32") C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local) T.clear(C_local)
......
...@@ -19,8 +19,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal) ...@@ -19,8 +19,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
shape = [batch, heads, seq_len, dim] shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len] block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "bfloat16" dtype = T.bfloat16
accum_dtype = "float" accum_dtype = T.float32
block_mask_dtype = "bool" block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads): def kernel_func(block_M, block_N, num_stages, threads):
......
...@@ -25,8 +25,8 @@ def test_lower_fence_proxy(): ...@@ -25,8 +25,8 @@ def test_lower_fence_proxy():
@T.prim_func @T.prim_func
def before(): def before():
with T.Kernel(8): with T.Kernel(8):
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local") C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16): for i in T.unroll(16):
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
...@@ -34,16 +34,16 @@ def test_lower_fence_proxy(): ...@@ -34,16 +34,16 @@ def test_lower_fence_proxy():
"handle", "handle",
tir.op.Op.get("tl.tl_gemm"), tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
@T.prim_func @T.prim_func
def after(): def after():
with T.Kernel(8): with T.Kernel(8):
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn") A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn") B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local") C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16): for i in T.unroll(16):
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2) C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
...@@ -52,9 +52,9 @@ def test_lower_fence_proxy(): ...@@ -52,9 +52,9 @@ def test_lower_fence_proxy():
"handle", "handle",
tir.op.Op.get("tl.tl_gemm"), tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>", "tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
_check(before, after) _check(before, after)
...@@ -64,8 +64,8 @@ def test_async_to_generic_no_double_fence(): ...@@ -64,8 +64,8 @@ def test_async_to_generic_no_double_fence():
@T.prim_func @T.prim_func
def before(): def before():
with T.Kernel(8): with T.Kernel(8):
A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn") A_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn")
B_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn") B_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn")
T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16) T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16)
T.fence_proxy_async() T.fence_proxy_async()
T.call_extern("handle", "generic_op") T.call_extern("handle", "generic_op")
...@@ -129,7 +129,7 @@ def test_tma_store_sync_injection(): ...@@ -129,7 +129,7 @@ def test_tma_store_sync_injection():
@T.prim_func @T.prim_func
def before(): def before():
with T.Kernel(8): with T.Kernel(8):
A_global = T.decl_buffer((128,), "float16", scope="global") A_global = T.decl_buffer((128,), T.float16, scope="global")
T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data)) T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data))
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
...@@ -159,14 +159,14 @@ def test_wgmma_marked_async(): ...@@ -159,14 +159,14 @@ def test_wgmma_marked_async():
@T.prim_func @T.prim_func
def before(): def before():
with T.Kernel(1): with T.Kernel(1):
A_shared = T.decl_buffer((1,), "float16", scope="shared") A_shared = T.decl_buffer((1,), T.float16, scope="shared")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma") desc_a = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma") desc_b = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), "float16", scope="local") C_local = T.decl_buffer((32,), T.float16, scope="local")
A_shared[0] = T.float16(0) A_shared[0] = T.float16(0)
T.warpgroup_arrive() T.warpgroup_arrive()
T.ptx_wgmma_ss( T.ptx_wgmma_ss(
"float16", T.float16,
"m64n64k16", "m64n64k16",
T.bool(True), T.bool(True),
T.bool(True), T.bool(True),
......
...@@ -9,7 +9,7 @@ def test_inject_set_max_nreg(): ...@@ -9,7 +9,7 @@ def test_inject_set_max_nreg():
"""Test the InjectSetMaxNReg pass""" """Test the InjectSetMaxNReg pass"""
@T.prim_func @T.prim_func
def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16")): def before(A: T.Tensor((512, 512), T.float16), B: T.Tensor((512, 512), T.float16)):
bx = T.launch_thread("blockIdx.x", 8) bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8) by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128) v = T.launch_thread("threadIdx.x", 128)
...@@ -22,8 +22,8 @@ def test_inject_set_max_nreg(): ...@@ -22,8 +22,8 @@ def test_inject_set_max_nreg():
T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24 T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24
T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240 T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local") C_local = T.alloc_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128) T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
...@@ -37,7 +37,7 @@ def test_inject_set_max_nreg(): ...@@ -37,7 +37,7 @@ def test_inject_set_max_nreg():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3), T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, k * 32,
by * 64, by * 64,
) )
...@@ -49,9 +49,9 @@ def test_inject_set_max_nreg(): ...@@ -49,9 +49,9 @@ def test_inject_set_max_nreg():
T.call_extern( T.call_extern(
"handle", "handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)])) T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
...@@ -86,7 +86,7 @@ def test_inject_set_max_nreg_no_set_max_nreg(): ...@@ -86,7 +86,7 @@ def test_inject_set_max_nreg_no_set_max_nreg():
"""Test the InjectSetMaxNReg pass with no_set_max_nreg""" """Test the InjectSetMaxNReg pass with no_set_max_nreg"""
@T.prim_func @T.prim_func
def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")): def before_no_set_max_nreg(A: T.Tensor((512, 512), T.float16)):
bx = T.launch_thread("blockIdx.x", 8) bx = T.launch_thread("blockIdx.x", 8)
v = T.launch_thread("threadIdx.x", 128) v = T.launch_thread("threadIdx.x", 128)
......
...@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto")) ...@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype", "block_M, block_N, block_K, threads, vec_load_b, dtype",
[ [
(64, 64, 32, 128, 8, "float16"), (64, 64, 32, 128, 8, T.float16),
], ],
) )
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
...@@ -102,4 +102,4 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): ...@@ -102,4 +102,4 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
if __name__ == "__main__": if __name__ == "__main__":
# tilelang.testing.main() # tilelang.testing.main()
test_loop_tail_split(64, 64, 32, 128, 8, "float16") test_loop_tail_split(64, 64, 32, 128, 8, T.float16)
...@@ -19,15 +19,15 @@ def test_buffer_load_negative_index_legalized(): ...@@ -19,15 +19,15 @@ def test_buffer_load_negative_index_legalized():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
value = A[-1] value = A[-1]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
value = A[1023] # A[-1] becomes A[1023] value = A[1023] # A[-1] becomes A[1023]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
_check(before, after) _check(before, after)
...@@ -39,15 +39,15 @@ def test_buffer_load_mixed_negative_positive_indices(): ...@@ -39,15 +39,15 @@ def test_buffer_load_mixed_negative_positive_indices():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 512), "float32")): def before(A: T.Tensor((1024, 512), T.float32)):
value = A[-1, 10] value = A[-1, 10]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 512), "float32")): def after(A: T.Tensor((1024, 512), T.float32)):
value = A[1023, 10] # A[-1, 10] becomes A[1023, 10] value = A[1023, 10] # A[-1, 10] becomes A[1023, 10]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
_check(before, after) _check(before, after)
...@@ -59,15 +59,15 @@ def test_buffer_load_multiple_negative_indices(): ...@@ -59,15 +59,15 @@ def test_buffer_load_multiple_negative_indices():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")): def before(A: T.Tensor((1024, 512, 256), T.float32)):
value = A[-1, -2, -3] value = A[-1, -2, -3]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")): def after(A: T.Tensor((1024, 512, 256), T.float32)):
value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253 value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
_check(before, after) _check(before, after)
...@@ -79,15 +79,15 @@ def test_buffer_load_negative_index_in_expression(): ...@@ -79,15 +79,15 @@ def test_buffer_load_negative_index_in_expression():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
B = T.alloc_buffer((1024,), "float32") B = T.alloc_buffer((1024,), T.float32)
for i in T.serial(1, 1024): for i in T.serial(1, 1024):
value = A[-i] value = A[-i]
B[-i] = value B[-i] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
B = T.alloc_buffer((1024,), "float32") B = T.alloc_buffer((1024,), T.float32)
for i in T.serial(1, 1024): for i in T.serial(1, 1024):
value = A[1024 - i] value = A[1024 - i]
B[1024 - i] = value B[1024 - i] = value
...@@ -101,16 +101,16 @@ def test_buffer_load_non_negative_index_unchanged(): ...@@ -101,16 +101,16 @@ def test_buffer_load_non_negative_index_unchanged():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
value = A[0] value = A[0]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
# No changes expected for non-negative indices # No changes expected for non-negative indices
value = A[0] value = A[0]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
_check(before, after) _check(before, after)
...@@ -123,18 +123,18 @@ def test_buffer_load_unknown_sign_index_warning(): ...@@ -123,18 +123,18 @@ def test_buffer_load_unknown_sign_index_warning():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
i = T.Var("i", "int32") i = T.Var("i", T.int32)
value = A[i] value = A[i]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
i = T.Var("i", "int32") i = T.Var("i", T.int32)
# Unknown sign indices should remain unchanged # Unknown sign indices should remain unchanged
value = A[i] value = A[i]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = value B[0] = value
_check(before, after) _check(before, after)
...@@ -146,18 +146,18 @@ def test_buffer_load_vector_index_negative_broadcast(): ...@@ -146,18 +146,18 @@ def test_buffer_load_vector_index_negative_broadcast():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
vec = T.Broadcast(-1, 4) vec = T.Broadcast(-1, 4)
value = A[vec] value = A[vec]
B = T.alloc_buffer((4,), "float32") B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value B[T.Ramp(0, 1, 4)] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be delimed by Simplify. # vec is unused and can be delimed by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841 vec = T.Broadcast(-1, 4) # noqa: F841
value = A[T.Broadcast(1023, 4)] value = A[T.Broadcast(1023, 4)]
B = T.alloc_buffer((4,), "float32") B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value B[T.Ramp(0, 1, 4)] = value
_check(before, after) _check(before, after)
...@@ -169,18 +169,18 @@ def test_buffer_load_vector_index_negative_ramp(): ...@@ -169,18 +169,18 @@ def test_buffer_load_vector_index_negative_ramp():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
value = A[vec] value = A[vec]
B = T.alloc_buffer((4,), "float32") B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value B[T.Ramp(0, 1, 4)] = value
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be delimed by Simplify. # vec is unused and can be delimed by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841 vec = T.Ramp(-4, 1, 4) # noqa: F841
value = A[T.Ramp(1020, 1, 4)] value = A[T.Ramp(1020, 1, 4)]
B = T.alloc_buffer((4,), "float32") B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value B[T.Ramp(0, 1, 4)] = value
_check(before, after) _check(before, after)
...@@ -192,17 +192,17 @@ def test_buffer_load_nested_buffer_loads(): ...@@ -192,17 +192,17 @@ def test_buffer_load_nested_buffer_loads():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 512), "float32")): def before(A: T.Tensor((1024, 512), T.float32)):
inner_val = A[-1, 10] inner_val = A[-1, 10]
outer_val = A[inner_val.astype("int32"), -2] outer_val = A[inner_val.astype(T.int32), -2]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = outer_val B[0] = outer_val
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 512), "float32")): def after(A: T.Tensor((1024, 512), T.float32)):
inner_val = A[1023, 10] inner_val = A[1023, 10]
outer_val = A[inner_val.astype("int32"), 510] outer_val = A[inner_val.astype(T.int32), 510]
B = T.alloc_buffer((1,), "float32") B = T.alloc_buffer((1,), T.float32)
B[0] = outer_val B[0] = outer_val
_check(before, after) _check(before, after)
...@@ -214,11 +214,11 @@ def test_buffer_store_negative_index(): ...@@ -214,11 +214,11 @@ def test_buffer_store_negative_index():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
A[-1] = 42.0 A[-1] = 42.0
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
A[1023] = 42.0 A[1023] = 42.0
_check(before, after) _check(before, after)
...@@ -230,11 +230,11 @@ def test_buffer_store_mixed_negative_positive_indices(): ...@@ -230,11 +230,11 @@ def test_buffer_store_mixed_negative_positive_indices():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 512), "float32")): def before(A: T.Tensor((1024, 512), T.float32)):
A[-1, 10] = 42.0 A[-1, 10] = 42.0
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 512), "float32")): def after(A: T.Tensor((1024, 512), T.float32)):
A[1023, 10] = 42.0 A[1023, 10] = 42.0
_check(before, after) _check(before, after)
...@@ -246,11 +246,11 @@ def test_buffer_store_multiple_negative_indices(): ...@@ -246,11 +246,11 @@ def test_buffer_store_multiple_negative_indices():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")): def before(A: T.Tensor((1024, 512, 256), T.float32)):
A[-1, -2, -3] = 42.0 A[-1, -2, -3] = 42.0
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")): def after(A: T.Tensor((1024, 512, 256), T.float32)):
A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253 A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253
_check(before, after) _check(before, after)
...@@ -262,12 +262,12 @@ def test_buffer_store_negative_index_in_expression(): ...@@ -262,12 +262,12 @@ def test_buffer_store_negative_index_in_expression():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
for i in T.serial(1, 1024): for i in T.serial(1, 1024):
A[-i] = i * 2.0 A[-i] = i * 2.0
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
for i in T.serial(1, 1024): for i in T.serial(1, 1024):
A[1024 - i] = i * 2.0 A[1024 - i] = i * 2.0
...@@ -280,13 +280,13 @@ def test_buffer_store_vector_index_negative_broadcast(): ...@@ -280,13 +280,13 @@ def test_buffer_store_vector_index_negative_broadcast():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
vec = T.Broadcast(-1, 4) vec = T.Broadcast(-1, 4)
values = T.Broadcast(42.0, 4) values = T.Broadcast(42.0, 4)
A[vec] = values A[vec] = values
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be delimed by Simplify. # vec is unused and can be delimed by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841 vec = T.Broadcast(-1, 4) # noqa: F841
values = T.Broadcast(42.0, 4) values = T.Broadcast(42.0, 4)
...@@ -301,13 +301,13 @@ def test_buffer_store_vector_index_negative_ramp(): ...@@ -301,13 +301,13 @@ def test_buffer_store_vector_index_negative_ramp():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32")): def before(A: T.Tensor((1024,), T.float32)):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1] vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0] values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0]
A[vec] = values A[vec] = values
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32")): def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be delimed by Simplify. # vec is unused and can be delimed by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841 vec = T.Ramp(-4, 1, 4) # noqa: F841
values = T.Ramp(0.0, 1.0, 4) values = T.Ramp(0.0, 1.0, 4)
...@@ -322,14 +322,14 @@ def test_buffer_store_nested_in_condition(): ...@@ -322,14 +322,14 @@ def test_buffer_store_nested_in_condition():
""" """
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024,), "float32"), flag: T.int32): def before(A: T.Tensor((1024,), T.float32), flag: T.int32):
if flag > 0: if flag > 0:
A[-1] = 42.0 A[-1] = 42.0
else: else:
A[-2] = 24.0 A[-2] = 24.0
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024,), "float32"), flag: T.int32): def after(A: T.Tensor((1024,), T.float32), flag: T.int32):
if flag > 0: if flag > 0:
A[1023] = 42.0 A[1023] = 42.0
else: else:
......
...@@ -5,7 +5,7 @@ import tilelang.testing ...@@ -5,7 +5,7 @@ import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32" dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
...@@ -41,39 +41,8 @@ def assert_vectorize_access(M: int = 64, N: int = 64): ...@@ -41,39 +41,8 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body) tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
# def issue_1013_buggy_kernel():
# # NOTE: This kernel is mainly to test some corner cases in boundary check
# num_tokens = T.dynamic('num_tokens')
# num_threads = 128
# @T.prim_func
# def main(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += x[idx] == 2
# # NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# # and the padding value is not used. However, the current prover cannot handle this case.
# # So for now the expected kernel is a if-else statement to check the boundary.
# @T.prim_func
# def expected(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += T.Cast("int32",
# value=T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
# return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32" dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
...@@ -115,7 +84,7 @@ def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64): ...@@ -115,7 +84,7 @@ def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64):
def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2): def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32" dtype = T.float32
@T.prim_func @T.prim_func
def main( def main(
...@@ -152,13 +121,6 @@ def test_vectorize_access(): ...@@ -152,13 +121,6 @@ def test_vectorize_access():
assert_vectorize_access(64, 64) assert_vectorize_access(64, 64)
# def test_issue_1013():
# func, expected = issue_1013_buggy_kernel()
# mod = tvm.IRModule({func.attrs["global_symbol"]: func})
# transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
# tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access_with_atmoic_add(): def test_vectorize_access_with_atmoic_add():
assert_vectorize_access_with_atmoic_add(64, 64) assert_vectorize_access_with_atmoic_add(64, 64)
......
...@@ -5,12 +5,12 @@ import tilelang.testing ...@@ -5,12 +5,12 @@ import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64): def vectorize_access_legalize(M: int = 64, N: int = 64):
dtype = "float32" dtype = T.float32
vec_len = 8 vec_len = 8
@T.prim_func @T.prim_func
def main( def main(
A: T.Tensor((M, N, vec_len), dtype="float32"), A: T.Tensor((M, N, vec_len), dtype=T.float32),
): ):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
...@@ -21,7 +21,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64): ...@@ -21,7 +21,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
@T.prim_func @T.prim_func
def expected( def expected(
A: T.Tensor((M, N, vec_len), dtype="float32"), A: T.Tensor((M, N, vec_len), dtype=T.float32),
): ):
with T.Kernel(1, 1, threads=M) as (bx, by): with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype) A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
......
...@@ -13,7 +13,7 @@ def _check(original, transformed): ...@@ -13,7 +13,7 @@ def _check(original, transformed):
def test_let_binding(): def test_let_binding():
@T.prim_func @T.prim_func
def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): def before(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)):
for i in range(128): for i in range(128):
for j in range(128): for j in range(128):
with T.block("compute"): with T.block("compute"):
...@@ -22,7 +22,7 @@ def test_let_binding(): ...@@ -22,7 +22,7 @@ def test_let_binding():
B[i, j] = value B[i, j] = value
@T.prim_func @T.prim_func
def expected(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")): def expected(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)):
for i in range(128): for i in range(128):
for j in range(128): for j in range(128):
with T.block("compute"): with T.block("compute"):
...@@ -33,14 +33,14 @@ def test_let_binding(): ...@@ -33,14 +33,14 @@ def test_let_binding():
def test_parallel_scope(): def test_parallel_scope():
@T.prim_func @T.prim_func
def before(A: T.Tensor((128,), "float32")): def before(A: T.Tensor((128,), T.float32)):
for i in T.Parallel(128): for i in T.Parallel(128):
with T.block("parallel"): with T.block("parallel"):
value = T.float32(1.0) value = T.float32(1.0)
A[i] = value A[i] = value
@T.prim_func @T.prim_func
def expected(A: T.Tensor((128,), "float32")): def expected(A: T.Tensor((128,), T.float32)):
for i in T.Parallel(128): for i in T.Parallel(128):
with T.block("parallel"): with T.block("parallel"):
A[i] = T.float32(1.0) A[i] = T.float32(1.0)
......
...@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto")) ...@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype", "block_M, block_N, block_K, threads, vec_load_b, dtype",
[ [
(64, 64, 32, 128, 8, "float16"), (64, 64, 32, 128, 8, T.float16),
], ],
) )
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype): def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
......
...@@ -24,7 +24,7 @@ def _check(original, transformed): ...@@ -24,7 +24,7 @@ def _check(original, transformed):
M = 512 M = 512
N = 512 N = 512
K = 512 K = 512
dtype = "float16" dtype = T.float16
block_M = 64 block_M = 64
block_N = 64 block_N = 64
block_K = 32 block_K = 32
...@@ -39,8 +39,8 @@ def test_multi_version_buffer(): ...@@ -39,8 +39,8 @@ def test_multi_version_buffer():
with T.block(""): with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes() T.writes()
A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn") A_shared = T.alloc_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn") B_shared = T.alloc_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local") C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2): for vec in T.vectorized(2):
...@@ -50,7 +50,7 @@ def test_multi_version_buffer(): ...@@ -50,7 +50,7 @@ def test_multi_version_buffer():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2),
k * 32, k * 32,
by * 64, by * 64,
) )
...@@ -58,16 +58,16 @@ def test_multi_version_buffer(): ...@@ -58,16 +58,16 @@ def test_multi_version_buffer():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 2),
bx * 64, bx * 64,
k * 32, k * 32,
) )
T.call_extern( T.call_extern(
"handle", "handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
@T.prim_func @T.prim_func
...@@ -78,8 +78,8 @@ def test_multi_version_buffer(): ...@@ -78,8 +78,8 @@ def test_multi_version_buffer():
with T.block(""): with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64]) T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes() T.writes()
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn") A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn") B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local") C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}): for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2): for vec in T.vectorized(2):
...@@ -89,7 +89,7 @@ def test_multi_version_buffer(): ...@@ -89,7 +89,7 @@ def test_multi_version_buffer():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0), T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32, k * 32,
by * 64, by * 64,
) )
...@@ -97,16 +97,16 @@ def test_multi_version_buffer(): ...@@ -97,16 +97,16 @@ def test_multi_version_buffer():
T.tma_load( T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0), T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0, 0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64, bx * 64,
k * 32, k * 32,
) )
T.call_extern( T.call_extern(
"handle", "handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>", "tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1), T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3), T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
) )
_check(before, after) _check(before, after)
...@@ -114,10 +114,10 @@ def test_multi_version_buffer(): ...@@ -114,10 +114,10 @@ def test_multi_version_buffer():
def test_multi_version_buffer_with_let(): def test_multi_version_buffer_with_let():
@T.prim_func @T.prim_func
def before(scales: T.Tensor((4,), "float32")): def before(scales: T.Tensor((4,), T.float32)):
with T.block("root"): with T.block("root"):
shared = T.alloc_buffer((8,), "float32", scope="shared.dyn") shared = T.alloc_buffer((8,), T.float32, scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), T.float32, scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
...@@ -126,10 +126,10 @@ def test_multi_version_buffer_with_let(): ...@@ -126,10 +126,10 @@ def test_multi_version_buffer_with_let():
accum[i] = accum[i] + shared[i] accum[i] = accum[i] + shared[i]
@T.prim_func @T.prim_func
def after(scales: T.Tensor((4,), "float32")): def after(scales: T.Tensor((4,), T.float32)):
with T.block("root"): with T.block("root"):
shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn") shared = T.alloc_buffer((2, 8), T.float32, scope="shared.dyn")
accum = T.alloc_buffer((8,), "float32", scope="local") accum = T.alloc_buffer((8,), T.float32, scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}): for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value = scales[k] value = scales[k]
for i in T.serial(8): for i in T.serial(8):
......
...@@ -20,11 +20,11 @@ def _check(original, transformed): ...@@ -20,11 +20,11 @@ def _check(original, transformed):
def test_simple_pipeline(): def test_simple_pipeline():
@T.prim_func @T.prim_func
def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)):
with T.Kernel(8, 8, threads=128) as (bx, by): with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32") A_shared = T.alloc_shared((128, 32), T.float32)
B_shared = T.alloc_shared((32, 128), "float32") B_shared = T.alloc_shared((32, 128), T.float32)
C_local = T.alloc_fragment((128, 128), "float32") C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local) T.clear(C_local)
...@@ -37,11 +37,11 @@ def test_simple_pipeline(): ...@@ -37,11 +37,11 @@ def test_simple_pipeline():
T.copy(C_local, C[by * 128, bx * 128]) T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func @T.prim_func
def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")): def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)):
with T.Kernel(8, 8, threads=128) as (bx, by): with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float32") A_shared = T.alloc_shared((128, 32), T.float32)
B_shared = T.alloc_shared((32, 128), "float32") B_shared = T.alloc_shared((32, 128), T.float32)
C_local = T.alloc_fragment((128, 128), "float32") C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local) T.clear(C_local)
......
...@@ -22,9 +22,9 @@ def modify( ...@@ -22,9 +22,9 @@ def modify(
T.gemm(A, B, D) T.gemm(A, B, D)
else: else:
with T.block(): with T.block():
A_shared = T.alloc_shared((64, 64), dtype="float32") A_shared = T.alloc_shared((64, 64), dtype=T.float32)
C_shared = T.alloc_shared((64, 64), dtype="float32") C_shared = T.alloc_shared((64, 64), dtype=T.float32)
D_shared = T.alloc_shared((64, 64), dtype="float32") D_shared = T.alloc_shared((64, 64), dtype=T.float32)
T.copy(A, A_shared) T.copy(A, A_shared)
T.copy(C, C_shared) T.copy(C, C_shared)
T.gemm(A_shared, C_shared, D_shared) T.gemm(A_shared, C_shared, D_shared)
...@@ -40,7 +40,7 @@ def test_modify(with_B=False, with_bias=False): ...@@ -40,7 +40,7 @@ def test_modify(with_B=False, with_bias=False):
assert mod != mod2 assert mod != mod2
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func @T.prim_func
def main( def main(
a: T.handle, a: T.handle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment