Unverified Commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and the DeepSeek MLA examples, enhancing type handling across the board.
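The motivation for accumulating in float32 can be seen with a small stdlib-only illustration (this is not tilelang code): a running sum that is rounded to float16 after every addition saturates once the total exceeds the half-precision spacing, while a wider accumulator does not.

```python
import struct

def as_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]

def accumulate(n: int, use_fp32: bool) -> float:
    # Sum n ones. In the fp16 case the running total is rounded to
    # half precision after every add, mimicking a float16 accumulator.
    acc = 0.0
    for _ in range(n):
        acc = acc + 1.0 if use_fp32 else as_fp16(acc + 1.0)
    return acc

# Above 2048, adjacent fp16 values are 2 apart, so 2048 + 1 rounds
# back to 2048 and the fp16 accumulator gets stuck there.
assert accumulate(4096, use_fp32=False) == 2048.0
assert accumulate(4096, use_fp32=True) == 4096.0
```

The same effect applies to long reduction (K) loops in GEMM and attention kernels, which is why the examples switch `accum_dtype` to `T.float32`.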

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
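A minimal sketch of the canonical-name mapping and creation logic described above (the table entries mirror spellings seen in this PR, e.g. bare `"float"` displaying as `float32`; the actual tilelang implementation and its full dtype set may differ):

```python
# Hypothetical canonical-name table; not the real tilelang mapping.
CANONICAL_TO_DISPLAY = {
    "float": "float32",   # bare "float" is displayed as float32
    "half": "float16",
    "float16": "float16",
    "float32": "float32",
    "int8": "int8",
    "int32": "int32",
}

def make_dtype(spec):
    """Normalize a user-supplied dtype spec (string or dtype-like object)
    to its canonical display string, with a clear error on bad input."""
    name = spec if isinstance(spec, str) else getattr(spec, "name", str(spec))
    try:
        return CANONICAL_TO_DISPLAY[name]
    except KeyError:
        raise ValueError(f"unknown dtype spec: {spec!r}") from None
```

With such a mapping, both the string spelling (`"float16"`) and the attribute spelling (`T.float16`, via its name) resolve to the same canonical dtype, which is what lets the examples use either style interchangeably.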

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
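The rename can be pictured as an alias table from the old spelling to the new one (illustrative only; the PR simply updated call sites from `T.float8_e4m3` to `T.float8_e4m3fn` rather than keeping an alias):

```python
# Hypothetical alias table illustrating the rename described above.
FP8_ALIASES = {
    # "fn" marks the finite variant (no infinities), matching
    # the torch.float8_e4m3fn naming convention.
    "float8_e4m3": "float8_e4m3fn",
    "float8_e5m2": "float8_e5m2",
}

def canonical_fp8(name: str) -> str:
    """Map a possibly-outdated float8 name to its canonical spelling;
    non-float8 names pass through unchanged."""
    return FP8_ALIASES.get(name, name)
```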

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
......@@ -3,7 +3,7 @@ import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def gemm(
A: T.Tensor((M, K), dtype),
......
......@@ -9,14 +9,14 @@ import tilelang.testing
@tilelang.jit
def dynamic_smem_kernel():
# Symbolic length to drive dynamic shared memory allocation
length = T.symbolic("len", dtype="int32") # noqa: F821
length = T.symbolic("len", dtype=T.int32) # noqa: F821
@T.prim_func
def main(global_tensor: T.Tensor[(length,), "int32"]): # noqa: F821
def main(global_tensor: T.Tensor[(length,), T.int32]): # noqa: F821
# Launch a simple kernel that copies from global memory into shared memory
# using a dynamically-sized allocation. No writes back to global_tensor.
with T.Kernel(1, threads=32) as _:
buffer_shared = T.alloc_shared((length,), dtype="int32") # noqa: F821
buffer_shared = T.alloc_shared((length,), dtype=T.int32) # noqa: F821
T.copy(buffer_shared, global_tensor)
return main
......
import tilelang.language as T
from tilelang import tvm as tvm
import tilelang.testing
import pytest
......@@ -23,8 +24,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
......@@ -112,20 +111,20 @@ def run_gemm_ss(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 128, 32, 2, 128),
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......@@ -153,8 +152,6 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
......@@ -247,20 +244,20 @@ def run_gemm_rs(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......@@ -288,8 +285,6 @@ def matmul_sr(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
......@@ -381,20 +376,20 @@ def run_gemm_sr(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(128, 8, 32, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 0, 128),
(128, 128, 32, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 32, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......@@ -519,22 +514,22 @@ def run_gemm_rr(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float16, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float, 128, 256, 32, 2, 128),
(128, 8, 128, False, True, T.float16, T.float16, T.float16, 128, 8, 32, 2, 128),
(128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 32, 2, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, False, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, False, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
(128, 128, 128, True, True, T.float, T.float, T.float32, 128, 128, 32, 2, 128),
],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......
......@@ -2,6 +2,7 @@ import pytest
import torch
import tilelang
import tilelang.testing
import tilelang.language as T
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.layout import make_cutlass_metadata_layout
......@@ -44,14 +45,12 @@ def matmul_sp_sm90(
trans_A,
trans_B,
):
E_factor = 4 if in_dtype == "float32" else 8
E_factor = 4 if in_dtype == T.float32 else 8
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
......@@ -104,15 +103,13 @@ def matmul_sp_sm80(
trans_B,
):
is_8_bit = "8" in in_dtype
metadata_dtype = "int32" if is_8_bit else "int16"
metadata_dtype = T.int32 if is_8_bit else T.int16
E_factor = SparseTensorCoreIntrinEmitter.E_FACTOR_MAP[in_dtype][metadata_dtype]
A_sparse_shape = (M, K // 2) if not trans_A else (K // 2, M)
B_shape = (K, N) if not trans_B else (N, K)
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_K, block_N) if not trans_B else (block_N, block_K)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
......@@ -312,19 +309,18 @@ def run_gemm_sp_sm80(
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True),
(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 2, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 32, 0, 256, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 0, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 128, 128, 128, 2, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 0, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 128, 256, 2, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, T.float8_e4m3fn, T.float16, T.float16, 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
],
)
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
......@@ -337,21 +333,20 @@ def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False),
(512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 32, 0, 32, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 1, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, T.float16, T.float32, T.float32, 64, 64, 64, 3, 128, False, False),
(512, 1024, 768, T.int8, T.int32, T.int32, 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, T.int8, T.int32, T.int32, 128, 128, 128, 0, 128, False, True),
(512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 1, 128, False, True),
(512, 1024, 768, T.int8, T.int32, T.int32, 64, 64, 64, 2, 128, False, True),
],
)
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
......
......@@ -7,6 +7,7 @@ from tilelang.intrinsics.mma_sp_macro_generator import SparseTensorCoreIntrinEmi
import tilelang.testing
import torch
import tilelang.language as T
def matmul(
......@@ -31,8 +32,6 @@ def matmul(
A_shared_shape = (block_M, block_K // 2) if not trans_A else (block_K // 2, block_M)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A_sparse: T.Tensor(A_sparse_shape, in_dtype),
......@@ -83,7 +82,7 @@ def run_gemm_ss(
num_stages=3,
num_threads=128,
):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul(
M,
N,
......@@ -157,17 +156,17 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float, 128, 128, 32, 2, 128),
(128, 8, 64, False, True, T.float16, T.float16, T.float, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, T.int8, T.int32, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......@@ -252,7 +251,7 @@ def run_gemm_rs(
num_stages=3,
num_threads=128,
):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_rs(
M,
N,
......@@ -308,16 +307,16 @@ def run_gemm_rs(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
@@ -402,7 +401,7 @@ def run_gemm_sr(
num_stages=3,
num_threads=128,
):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_sr(
M,
N,
@@ -458,16 +457,16 @@ def run_gemm_sr(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 64, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 0, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 128, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
@@ -556,7 +555,7 @@ def run_gemm_rr(
num_stages=3,
num_threads=128,
):
metadata_dtype = "int32" if ("8" in in_dtype) else "int16"
metadata_dtype = T.int32 if ("8" in in_dtype) else T.int16
program = matmul_rr(
M,
N,
@@ -612,18 +611,18 @@ def run_gemm_rr(
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(512, 1024, 768, False, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, T.float16, T.float16, T.float32, 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 256, 32, 2, 128),
(128, 8, 128, False, True, T.float16, T.float16, T.float32, 128, 8, 32, 2, 128),
(128, 8, 128, False, True, T.int8, T.int8, T.int32, 128, 8, 64, 2, 128),
(128, 128, 128, False, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, False, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, False, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.int8, T.int8, T.int32, 128, 128, 64, 2, 128),
(128, 128, 128, True, True, T.float8_e5m2, T.float8_e5m2, T.float32, 128, 128, 64, 2, 128),
],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
......
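The `metadata_dtype` selection in the `run_gemm_*` helpers keys off the element width of the input dtype. A plain-Python sketch of that rule, assuming the dtype objects stringify to names like `"int8"` or `"float16"` (`pick_metadata_dtype` is a hypothetical helper, not part of tilelang):

```python
def pick_metadata_dtype(in_dtype) -> str:
    # 8-bit element types (int8, float8_e5m2, ...) pack more elements per
    # tile, so their sparsity metadata uses the wider int32; 16-bit and
    # larger inputs fit in int16 metadata.
    return "int32" if "8" in str(in_dtype) else "int16"

assert pick_metadata_dtype("int8") == "int32"
assert pick_metadata_dtype("float8_e5m2") == "int32"
assert pick_metadata_dtype("float16") == "int16"
```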
@@ -6,8 +6,8 @@ from tilelang.jit.adapter.utils import match_declare_kernel
def _simple_add_kernel():
@T.prim_func
def main(
x: T.Tensor((128,), "float32"),
y: T.Tensor((128,), "float32"),
x: T.Tensor((128,), T.float32),
y: T.Tensor((128,), T.float32),
):
# One-dimensional kernel; writes y from x without modifying x
with T.Kernel(128, threads=32) as pid:
......
@@ -16,13 +16,13 @@ def _check(original, transformed):
def test_trival_pipeline():
@T.prim_func
def before(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")):
def before(A: T.Tensor((16, 1), T.float32), C: T.Tensor((16, 1), T.float32)):
for tx in T.thread_binding(0, 16, thread="threadIdx.x"):
for i in T.serial(0, 1, annotations={"software_pipeline_stage": [0, 1], "software_pipeline_order": [0, 1]}):
with T.block():
T.reads(A[tx, i])
T.writes(C[tx, i])
B = T.alloc_buffer((16, 1), dtype="float32", scope="shared")
B = T.alloc_buffer((16, 1), dtype=T.float32, scope="shared")
with T.block():
T.reads(A[tx, i])
T.writes(B[tx, 0])
......
@@ -22,11 +22,11 @@ def _check(original, transformed):
def test_cluster_planning():
@T.prim_func
def before(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
def before(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)):
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16")
B_shared = T.alloc_shared((32, 128), "float16")
C_local = T.alloc_fragment((128, 128), "float32")
A_shared = T.alloc_shared((128, 32), T.float16)
B_shared = T.alloc_shared((32, 128), T.float16)
C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local)
@@ -39,12 +39,12 @@ def test_cluster_planning():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
def after(A: T.Tensor((1024, 32), "float16"), B: T.Tensor((32, 1024), "float16"), C: T.Tensor((1024, 1024), "float16")):
def after(A: T.Tensor((1024, 32), T.float16), B: T.Tensor((32, 1024), T.float16), C: T.Tensor((1024, 1024), T.float16)):
T.func_attr({"clusterIdx.y": T.int32(2)})
with T.Kernel(8, 8, threads=128) as (bx, by):
A_shared = T.alloc_shared((128, 32), "float16")
B_shared = T.alloc_shared((32, 128), "float16")
C_local = T.alloc_fragment((128, 128), "float32")
A_shared = T.alloc_shared((128, 32), T.float16)
B_shared = T.alloc_shared((32, 128), T.float16)
C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local)
......
@@ -19,8 +19,8 @@ def blocksparse_flashattn(batch, heads, seq_len, dim, downsample_len, is_causal)
shape = [batch, heads, seq_len, dim]
block_mask_shape = [batch, heads, downsample_len, downsample_len]
dtype = "bfloat16"
accum_dtype = "float"
dtype = T.bfloat16
accum_dtype = T.float32
block_mask_dtype = "bool"
def kernel_func(block_M, block_N, num_stages, threads):
......
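The PR standardizes `accum_dtype` to `T.float32` because half-precision accumulators silently drop small increments once the running sum grows large. A self-contained sketch of the effect using Python's built-in half-precision codec (`struct` format `"e"`), no tilelang required:

```python
import struct

def to_fp16(x: float) -> float:
    # Round-trip through IEEE half precision (round-to-nearest-even).
    return struct.unpack("e", struct.pack("e", x))[0]

acc_fp16 = to_fp16(2048.0)
acc_fp32 = 2048.0
for _ in range(1000):
    # The spacing between representable fp16 values at 2048 is 2.0,
    # so 2048 + 1 rounds back down and the accumulator stalls.
    acc_fp16 = to_fp16(acc_fp16 + 1.0)
    acc_fp32 += 1.0

assert acc_fp16 == 2048.0  # every +1.0 was rounded away
assert acc_fp32 == 3048.0
```

The same stall happens inside a fp16 GEMM accumulator over a long K dimension, which is why the examples accumulate in float32 even when inputs and outputs are float16.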
@@ -25,8 +25,8 @@ def test_lower_fence_proxy():
@T.prim_func
def before():
with T.Kernel(8):
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
@@ -34,16 +34,16 @@ def test_lower_fence_proxy():
"handle",
tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
@T.prim_func
def after():
with T.Kernel(8):
A_shared = T.decl_buffer((1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.decl_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.decl_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.decl_buffer((32,), scope="local")
for i in T.unroll(16):
C_local[i * 2 : i * 2 + 2] = T.Broadcast(T.float32(0), 2)
@@ -52,9 +52,9 @@ def test_lower_fence_proxy():
"handle",
tir.op.Op.get("tl.tl_gemm"),
"tl::gemm_ss<128, 128, 32, 4, 1, 0, 0, 0, 32, 128, 0, 0, true>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
_check(before, after)
@@ -64,8 +64,8 @@ def test_async_to_generic_no_double_fence():
@T.prim_func
def before():
with T.Kernel(8):
A_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
B_shared = T.decl_buffer((1024,), "uint8", scope="shared.dyn")
A_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn")
B_shared = T.decl_buffer((1024,), T.uint8, scope="shared.dyn")
T.ptx_cp_async("uint8", A_shared.data, 0, B_shared.data, 0, 16)
T.fence_proxy_async()
T.call_extern("handle", "generic_op")
@@ -129,7 +129,7 @@ def test_tma_store_sync_injection():
@T.prim_func
def before():
with T.Kernel(8):
A_global = T.decl_buffer((128,), "float16", scope="global")
A_global = T.decl_buffer((128,), T.float16, scope="global")
T.evaluate(T.call_intrin("handle", tir.op.Op.get("tl.tma_store"), A_global.data))
mod = tvm.IRModule.from_expr(before.with_attr("global_symbol", "main"))
@@ -159,14 +159,14 @@ def test_wgmma_marked_async():
@T.prim_func
def before():
with T.Kernel(1):
A_shared = T.decl_buffer((1,), "float16", scope="shared")
desc_a = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), "uint64", scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), "float16", scope="local")
A_shared = T.decl_buffer((1,), T.float16, scope="shared")
desc_a = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma")
desc_b = T.decl_buffer((1,), T.uint64, scope="local.descriptor.wgmma")
C_local = T.decl_buffer((32,), T.float16, scope="local")
A_shared[0] = T.float16(0)
T.warpgroup_arrive()
T.ptx_wgmma_ss(
"float16",
T.float16,
"m64n64k16",
T.bool(True),
T.bool(True),
......
@@ -9,7 +9,7 @@ def test_inject_set_max_nreg():
"""Test the InjectSetMaxNReg pass"""
@T.prim_func
def before(A: T.Tensor((512, 512), "float16"), B: T.Tensor((512, 512), "float16")):
def before(A: T.Tensor((512, 512), T.float16), B: T.Tensor((512, 512), T.float16)):
bx = T.launch_thread("blockIdx.x", 8)
by = T.launch_thread("blockIdx.y", 8)
v = T.launch_thread("threadIdx.x", 128)
@@ -22,8 +22,8 @@ def test_inject_set_max_nreg():
T.annotate_producer_reg_dealloc(24) # Producer: decrease to 24
T.annotate_consumer_reg_alloc(240) # Consumer: increase to 240
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
T.create_list_of_mbarrier(128, 128, 128, 128, 128, 128)
@@ -37,7 +37,7 @@ def test_inject_set_max_nreg():
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
T.get_mbarrier(k % 3),
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
@@ -49,9 +49,9 @@ def test_inject_set_max_nreg():
T.call_extern(
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
T.evaluate(tir.Call("handle", "tir.ptx_arrive_barrier", [T.get_mbarrier(k % 3 + 3)]))
@@ -86,7 +86,7 @@ def test_inject_set_max_nreg_no_set_max_nreg():
"""Test the InjectSetMaxNReg pass with no_set_max_nreg"""
@T.prim_func
def before_no_set_max_nreg(A: T.Tensor((512, 512), "float16")):
def before_no_set_max_nreg(A: T.Tensor((512, 512), T.float16)):
bx = T.launch_thread("blockIdx.x", 8)
v = T.launch_thread("threadIdx.x", 128)
......
@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype",
[
(64, 64, 32, 128, 8, "float16"),
(64, 64, 32, 128, 8, T.float16),
],
)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
@@ -102,4 +102,4 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
if __name__ == "__main__":
# tilelang.testing.main()
test_loop_tail_split(64, 64, 32, 128, 8, "float16")
test_loop_tail_split(64, 64, 32, 128, 8, T.float16)
@@ -19,15 +19,15 @@ def test_buffer_load_negative_index_legalized():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
value = A[-1]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
value = A[1023] # A[-1] becomes A[1023]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
_check(before, after)
@@ -39,15 +39,15 @@ def test_buffer_load_mixed_negative_positive_indices():
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
def before(A: T.Tensor((1024, 512), T.float32)):
value = A[-1, 10]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
def after(A: T.Tensor((1024, 512), T.float32)):
value = A[1023, 10] # A[-1, 10] becomes A[1023, 10]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
_check(before, after)
@@ -59,15 +59,15 @@ def test_buffer_load_multiple_negative_indices():
"""
@T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")):
def before(A: T.Tensor((1024, 512, 256), T.float32)):
value = A[-1, -2, -3]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")):
def after(A: T.Tensor((1024, 512, 256), T.float32)):
value = A[1023, 510, 253] # -1+1024=1023, -2+512=510, -3+256=253
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
_check(before, after)
@@ -79,15 +79,15 @@ def test_buffer_load_negative_index_in_expression():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
B = T.alloc_buffer((1024,), "float32")
def before(A: T.Tensor((1024,), T.float32)):
B = T.alloc_buffer((1024,), T.float32)
for i in T.serial(1, 1024):
value = A[-i]
B[-i] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
B = T.alloc_buffer((1024,), "float32")
def after(A: T.Tensor((1024,), T.float32)):
B = T.alloc_buffer((1024,), T.float32)
for i in T.serial(1, 1024):
value = A[1024 - i]
B[1024 - i] = value
@@ -101,16 +101,16 @@ def test_buffer_load_non_negative_index_unchanged():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
value = A[0]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
# No changes expected for non-negative indices
value = A[0]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
_check(before, after)
@@ -123,18 +123,18 @@ def test_buffer_load_unknown_sign_index_warning():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
i = T.Var("i", "int32")
def before(A: T.Tensor((1024,), T.float32)):
i = T.Var("i", T.int32)
value = A[i]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
i = T.Var("i", "int32")
def after(A: T.Tensor((1024,), T.float32)):
i = T.Var("i", T.int32)
# Unknown sign indices should remain unchanged
value = A[i]
B = T.alloc_buffer((1,), "float32")
B = T.alloc_buffer((1,), T.float32)
B[0] = value
_check(before, after)
@@ -146,18 +146,18 @@ def test_buffer_load_vector_index_negative_broadcast():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
vec = T.Broadcast(-1, 4)
value = A[vec]
B = T.alloc_buffer((4,), "float32")
B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be eliminated by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841
value = A[T.Broadcast(1023, 4)]
B = T.alloc_buffer((4,), "float32")
B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value
_check(before, after)
@@ -169,18 +169,18 @@ def test_buffer_load_vector_index_negative_ramp():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
value = A[vec]
B = T.alloc_buffer((4,), "float32")
B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be eliminated by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841
value = A[T.Ramp(1020, 1, 4)]
B = T.alloc_buffer((4,), "float32")
B = T.alloc_buffer((4,), T.float32)
B[T.Ramp(0, 1, 4)] = value
_check(before, after)
@@ -192,17 +192,17 @@ def test_buffer_load_nested_buffer_loads():
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
def before(A: T.Tensor((1024, 512), T.float32)):
inner_val = A[-1, 10]
outer_val = A[inner_val.astype("int32"), -2]
B = T.alloc_buffer((1,), "float32")
outer_val = A[inner_val.astype(T.int32), -2]
B = T.alloc_buffer((1,), T.float32)
B[0] = outer_val
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
def after(A: T.Tensor((1024, 512), T.float32)):
inner_val = A[1023, 10]
outer_val = A[inner_val.astype("int32"), 510]
B = T.alloc_buffer((1,), "float32")
outer_val = A[inner_val.astype(T.int32), 510]
B = T.alloc_buffer((1,), T.float32)
B[0] = outer_val
_check(before, after)
@@ -214,11 +214,11 @@ def test_buffer_store_negative_index():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
A[-1] = 42.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
A[1023] = 42.0
_check(before, after)
@@ -230,11 +230,11 @@ def test_buffer_store_mixed_negative_positive_indices():
"""
@T.prim_func
def before(A: T.Tensor((1024, 512), "float32")):
def before(A: T.Tensor((1024, 512), T.float32)):
A[-1, 10] = 42.0
@T.prim_func
def after(A: T.Tensor((1024, 512), "float32")):
def after(A: T.Tensor((1024, 512), T.float32)):
A[1023, 10] = 42.0
_check(before, after)
@@ -246,11 +246,11 @@ def test_buffer_store_multiple_negative_indices():
"""
@T.prim_func
def before(A: T.Tensor((1024, 512, 256), "float32")):
def before(A: T.Tensor((1024, 512, 256), T.float32)):
A[-1, -2, -3] = 42.0
@T.prim_func
def after(A: T.Tensor((1024, 512, 256), "float32")):
def after(A: T.Tensor((1024, 512, 256), T.float32)):
A[1023, 510, 253] = 42.0 # -1+1024=1023, -2+512=510, -3+256=253
_check(before, after)
@@ -262,12 +262,12 @@ def test_buffer_store_negative_index_in_expression():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
for i in T.serial(1, 1024):
A[-i] = i * 2.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
for i in T.serial(1, 1024):
A[1024 - i] = i * 2.0
@@ -280,13 +280,13 @@ def test_buffer_store_vector_index_negative_broadcast():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
vec = T.Broadcast(-1, 4)
values = T.Broadcast(42.0, 4)
A[vec] = values
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be eliminated by Simplify.
vec = T.Broadcast(-1, 4) # noqa: F841
values = T.Broadcast(42.0, 4)
@@ -301,13 +301,13 @@ def test_buffer_store_vector_index_negative_ramp():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32")):
def before(A: T.Tensor((1024,), T.float32)):
vec = T.Ramp(-4, 1, 4) # indices: [-4, -3, -2, -1]
values = T.Ramp(0.0, 1.0, 4) # values: [0.0, 1.0, 2.0, 3.0]
A[vec] = values
@T.prim_func
def after(A: T.Tensor((1024,), "float32")):
def after(A: T.Tensor((1024,), T.float32)):
# vec is unused and can be eliminated by Simplify.
vec = T.Ramp(-4, 1, 4) # noqa: F841
values = T.Ramp(0.0, 1.0, 4)
@@ -322,14 +322,14 @@ def test_buffer_store_nested_in_condition():
"""
@T.prim_func
def before(A: T.Tensor((1024,), "float32"), flag: T.int32):
def before(A: T.Tensor((1024,), T.float32), flag: T.int32):
if flag > 0:
A[-1] = 42.0
else:
A[-2] = 24.0
@T.prim_func
def after(A: T.Tensor((1024,), "float32"), flag: T.int32):
def after(A: T.Tensor((1024,), T.float32), flag: T.int32):
if flag > 0:
A[1023] = 42.0
else:
......
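All of the `before`/`after` pairs in this file follow one rule: a negative index is legalized by adding the buffer extent along that dimension, while non-negative indices are left untouched. A minimal sketch of that rule (`legalize_index` is a hypothetical name, not the pass's API):

```python
def legalize_index(idx: int, extent: int) -> int:
    # A negative index -k addresses element extent - k, exactly as the
    # transformed programs show (A[-1] -> A[1023] for extent 1024).
    return idx + extent if idx < 0 else idx

assert legalize_index(-1, 1024) == 1023
assert legalize_index(-2, 512) == 510
assert legalize_index(-3, 256) == 253
assert legalize_index(0, 1024) == 0  # non-negative indices unchanged
```

The real pass applies the same shift per dimension, including to the base of `Ramp` and `Broadcast` vector indices, and leaves symbolic indices of unknown sign alone.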
@@ -5,7 +5,7 @@ import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
dtype = T.float32
@T.prim_func
def main(
@@ -41,39 +41,8 @@ def assert_vectorize_access(M: int = 64, N: int = 64):
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
# def issue_1013_buggy_kernel():
# # NOTE: This kernel is mainly to test some corner cases in boundary check
# num_tokens = T.dynamic('num_tokens')
# num_threads = 128
# @T.prim_func
# def main(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += x[idx] == 2
# # NOTE(chaofan): Ideally, the prover should be able to prove that the access is safe
# # and the padding value is not used. However, the current prover cannot handle this case.
# # So for now the expected kernel is a if-else statement to check the boundary.
# @T.prim_func
# def expected(x: T.Tensor((num_tokens,), dtype="int64")):
# with T.Kernel(1, threads=num_threads) as _:
# count = T.alloc_var('int')
# thread_idx = T.get_thread_binding()
# for i in T.serial(0, T.ceildiv(num_tokens - thread_idx, num_threads)):
# idx = thread_idx + i * num_threads
# count += T.Cast("int32",
# value=T.if_then_else(idx < num_tokens, x[idx], T.int64(0)) == T.int64(2))
# return main, expected
def vectorize_access_with_atmoic_add_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
dtype = T.float32
@T.prim_func
def main(
@@ -115,7 +84,7 @@ def assert_vectorize_access_with_atmoic_add(M: int = 64, N: int = 64):
def oob_store_legalize(M: int = 64, N: int = 64, M_offset: int = 2, N_offset: int = 2):
dtype = "float32"
dtype = T.float32
@T.prim_func
def main(
@@ -152,13 +121,6 @@ def test_vectorize_access():
assert_vectorize_access(64, 64)
# def test_issue_1013():
# func, expected = issue_1013_buggy_kernel()
# mod = tvm.IRModule({func.attrs["global_symbol"]: func})
# transformed = tl.transform.LegalizeSafeMemoryAccess()(mod)
# tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
def test_vectorize_access_with_atmoic_add():
assert_vectorize_access_with_atmoic_add(64, 64)
......
@@ -5,12 +5,12 @@ import tilelang.testing
def vectorize_access_legalize(M: int = 64, N: int = 64):
dtype = "float32"
dtype = T.float32
vec_len = 8
@T.prim_func
def main(
A: T.Tensor((M, N, vec_len), dtype="float32"),
A: T.Tensor((M, N, vec_len), dtype=T.float32),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
@@ -21,7 +21,7 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
@T.prim_func
def expected(
A: T.Tensor((M, N, vec_len), dtype="float32"),
A: T.Tensor((M, N, vec_len), dtype=T.float32),
):
with T.Kernel(1, 1, threads=M) as (bx, by):
A_shared = T.alloc_shared((M, N, vec_len), dtype=dtype)
......
@@ -13,7 +13,7 @@ def _check(original, transformed):
def test_let_binding():
@T.prim_func
def before(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
def before(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)):
for i in range(128):
for j in range(128):
with T.block("compute"):
@@ -22,7 +22,7 @@ def test_let_binding():
B[i, j] = value
@T.prim_func
def expected(A: T.Tensor((128, 128), "float32"), B: T.Tensor((128, 128), "float32")):
def expected(A: T.Tensor((128, 128), T.float32), B: T.Tensor((128, 128), T.float32)):
for i in range(128):
for j in range(128):
with T.block("compute"):
@@ -33,14 +33,14 @@ def test_let_binding():
def test_parallel_scope():
@T.prim_func
def before(A: T.Tensor((128,), "float32")):
def before(A: T.Tensor((128,), T.float32)):
for i in T.Parallel(128):
with T.block("parallel"):
value = T.float32(1.0)
A[i] = value
@T.prim_func
def expected(A: T.Tensor((128,), "float32")):
def expected(A: T.Tensor((128,), T.float32)):
for i in T.Parallel(128):
with T.block("parallel"):
A[i] = T.float32(1.0)
......
@@ -11,7 +11,7 @@ auto_target = tvm.target.Target(determine_target("auto"))
@pytest.mark.parametrize(
"block_M, block_N, block_K, threads, vec_load_b, dtype",
[
(64, 64, 32, 128, 8, "float16"),
(64, 64, 32, 128, 8, T.float16),
],
)
def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
......
@@ -24,7 +24,7 @@ def _check(original, transformed):
M = 512
N = 512
K = 512
dtype = "float16"
dtype = T.float16
block_M = 64
block_N = 64
block_K = 32
@@ -39,8 +39,8 @@ def test_multi_version_buffer():
with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes()
A_shared = T.alloc_buffer((1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.alloc_buffer((1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2):
@@ -50,7 +50,7 @@ def test_multi_version_buffer():
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 2),
k * 32,
by * 64,
)
@@ -58,16 +58,16 @@ def test_multi_version_buffer():
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 2),
bx * 64,
k * 32,
)
T.call_extern(
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, 0, 2048, 1),
T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
@T.prim_func
@@ -78,8 +78,8 @@ def test_multi_version_buffer():
with T.block(""):
T.reads(A[by * 64, 0:481], B[0:481, bx * 64])
T.writes()
A_shared = T.alloc_buffer((3, 1, 8, 256), "float16", scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), "float16", scope="shared.dyn")
A_shared = T.alloc_buffer((3, 1, 8, 256), T.float16, scope="shared.dyn")
B_shared = T.alloc_buffer((3, 1, 4, 512), T.float16, scope="shared.dyn")
C_local = T.alloc_buffer((32,), scope="local")
for i in T.unroll(16, annotations={"pragma_unroll_explicit": T.bool(False)}):
for vec in T.vectorized(2):
@@ -89,7 +89,7 @@ def test_multi_version_buffer():
T.tma_load(
T.create_tma_descriptor(6, 2, A.data, 512, 512, 2, 1024, 32, 64, 1, 1, 0, 2, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 2),
k * 32,
by * 64,
)
@@ -97,16 +97,16 @@ def test_multi_version_buffer():
T.tma_load(
T.create_tma_descriptor(6, 2, B.data, 512, 512, 2, 1024, 64, 32, 1, 1, 0, 3, 2, 0),
0,
T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 2),
T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 2),
bx * 64,
k * 32,
)
T.call_extern(
"handle",
"tl::gemm_ss<64, 64, 32, 4, 1, 0, 0>",
-T.tvm_access_ptr(T.type_annotation("float16"), A_shared.data, k % 3 * 2048, 2048, 1),
-T.tvm_access_ptr(T.type_annotation("float16"), B_shared.data, k % 3 * 2048, 2048, 1),
-T.tvm_access_ptr(T.type_annotation("float32"), C_local.data, 0, 32, 3),
+T.tvm_access_ptr(T.type_annotation(T.float16), A_shared.data, k % 3 * 2048, 2048, 1),
+T.tvm_access_ptr(T.type_annotation(T.float16), B_shared.data, k % 3 * 2048, 2048, 1),
+T.tvm_access_ptr(T.type_annotation(T.float32), C_local.data, 0, 32, 3),
)
_check(before, after)
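For readers puzzling over the trailing integer in the `T.tvm_access_ptr` calls above: going by TVM's convention (an assumption, since this diff only shows the generated script), it is a read/write bitmask — 1 = read, 2 = write, 3 = read|write — which lines up with the `tma_load` destinations (2), the `gemm_ss` inputs (1), and the `C_local` accumulator (3). A minimal stand-alone sketch:

```python
# Hedged sketch of the access-pointer flag, based on TVM's documented
# rw-mask convention (not code from this PR).
READ, WRITE = 1, 2

def rw_mask(reads: bool, writes: bool) -> int:
    # Combine the two permission bits into the mask tvm_access_ptr expects.
    return (READ if reads else 0) | (WRITE if writes else 0)

# tma_load only writes its shared-memory destination:
assert rw_mask(reads=False, writes=True) == 2
# gemm_ss only reads A_shared / B_shared:
assert rw_mask(reads=True, writes=False) == 1
# C_local is accumulated into, so it is read and written:
assert rw_mask(reads=True, writes=True) == 3
```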
@@ -114,10 +114,10 @@ def test_multi_version_buffer():
def test_multi_version_buffer_with_let():
@T.prim_func
-def before(scales: T.Tensor((4,), "float32")):
+def before(scales: T.Tensor((4,), T.float32)):
with T.block("root"):
-shared = T.alloc_buffer((8,), "float32", scope="shared.dyn")
-accum = T.alloc_buffer((8,), "float32", scope="local")
+shared = T.alloc_buffer((8,), T.float32, scope="shared.dyn")
+accum = T.alloc_buffer((8,), T.float32, scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value = scales[k]
for i in T.serial(8):
@@ -126,10 +126,10 @@ def test_multi_version_buffer_with_let():
accum[i] = accum[i] + shared[i]
@T.prim_func
-def after(scales: T.Tensor((4,), "float32")):
+def after(scales: T.Tensor((4,), T.float32)):
with T.block("root"):
-shared = T.alloc_buffer((2, 8), "float32", scope="shared.dyn")
-accum = T.alloc_buffer((8,), "float32", scope="local")
+shared = T.alloc_buffer((2, 8), T.float32, scope="shared.dyn")
+accum = T.alloc_buffer((8,), T.float32, scope="local")
for k in T.serial(4, annotations={"num_stages": T.int32(2)}):
value = scales[k]
for i in T.serial(8):
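The `with_let` hunk illustrates what the multi-version-buffer pass does under `num_stages=2`: `shared` gains a leading version dimension, `(8,)` becoming `(2, 8)`, so consecutive pipeline stages write disjoint slots selected by `k % 2`. A plain-Python sketch of that idea (a stand-in for intuition, not TileLang code):

```python
# Hedged sketch: simulate the multi-versioned buffer the pass creates.
# With num_stages=2, stage k uses slot k % 2 so producer and consumer
# of adjacent iterations never touch the same storage.
STAGES = 2
shared = [[0.0] * 8 for _ in range(STAGES)]  # the (2, 8) buffer
accum = [0.0] * 8
scales = [1.0, 2.0, 3.0, 4.0]

for k in range(4):
    slot = k % STAGES          # version selector inserted by the pass
    value = scales[k]          # the hoisted "let" binding
    for i in range(8):
        shared[slot][i] = value
    for i in range(8):
        accum[i] += shared[slot][i]

# Every lane accumulated 1 + 2 + 3 + 4 = 10.
assert accum == [10.0] * 8
```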
@@ -20,11 +20,11 @@ def _check(original, transformed):
def test_simple_pipeline():
@T.prim_func
-def before(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
+def before(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)):
with T.Kernel(8, 8, threads=128) as (bx, by):
-A_shared = T.alloc_shared((128, 32), "float32")
-B_shared = T.alloc_shared((32, 128), "float32")
-C_local = T.alloc_fragment((128, 128), "float32")
+A_shared = T.alloc_shared((128, 32), T.float32)
+B_shared = T.alloc_shared((32, 128), T.float32)
+C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local)
@@ -37,11 +37,11 @@ def test_simple_pipeline():
T.copy(C_local, C[by * 128, bx * 128])
@T.prim_func
-def after(A: T.Tensor((1024, 32), "float32"), B: T.Tensor((32, 1024), "float32"), C: T.Tensor((1024, 1024), "float32")):
+def after(A: T.Tensor((1024, 32), T.float32), B: T.Tensor((32, 1024), T.float32), C: T.Tensor((1024, 1024), T.float32)):
with T.Kernel(8, 8, threads=128) as (bx, by):
-A_shared = T.alloc_shared((128, 32), "float32")
-B_shared = T.alloc_shared((32, 128), "float32")
-C_local = T.alloc_fragment((128, 128), "float32")
+A_shared = T.alloc_shared((128, 32), T.float32)
+B_shared = T.alloc_shared((32, 128), T.float32)
+C_local = T.alloc_fragment((128, 128), T.float32)
T.clear(C_local)
@@ -22,9 +22,9 @@ def modify(
T.gemm(A, B, D)
else:
with T.block():
-A_shared = T.alloc_shared((64, 64), dtype="float32")
-C_shared = T.alloc_shared((64, 64), dtype="float32")
-D_shared = T.alloc_shared((64, 64), dtype="float32")
+A_shared = T.alloc_shared((64, 64), dtype=T.float32)
+C_shared = T.alloc_shared((64, 64), dtype=T.float32)
+D_shared = T.alloc_shared((64, 64), dtype=T.float32)
T.copy(A, A_shared)
T.copy(C, C_shared)
T.gemm(A_shared, C_shared, D_shared)
@@ -40,7 +40,7 @@ def test_modify(with_B=False, with_bias=False):
assert mod != mod2
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
a: T.handle,
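The last hunk captures the PR's overall theme: dtype parameters move from bare strings (`"float16"`, `"float"`) to `T.float16` / `T.float32` objects. A hedged sketch of why both spellings can coexist, built on a toy stand-in for `tilelang.language` (the real `T.float16` is not this class; this only illustrates the string-compatibility idea):

```python
# Toy stand-in for tilelang.language -- NOT the real module. It shows
# how dtype objects can stay interchangeable with their string names.
class _DType(str):
    """A dtype object that still compares equal to its string name."""

class T:
    float16 = _DType("float16")
    float32 = _DType("float32")

def matmul(M, N, K, block_M=64, block_N=64, block_K=32,
           dtype=T.float16, accum_dtype=T.float32):
    # Old call sites passing dtype="float16" and new ones passing
    # T.float16 both land here and compare equal.
    return dtype, accum_dtype

d, a = matmul(512, 512, 512)
assert d == "float16" and a == "float32"
assert matmul(512, 512, 512, dtype="float16")[0] == T.float16
```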