Unverified Commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.
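The interchangeability these bullets rely on, where `"float16"` and `T.float16` name the same type, could work as sketched below. This is a hypothetical model for illustration only, not tilelang's actual implementation; the `DType` and `T` names here are assumptions.

```python
class DType(str):
    """Illustrative dtype model: a dtype object that *is* its own string name.

    Because DType subclasses str, equality with plain strings is symmetric,
    so T.float16 == "float16" and "float16" == T.float16 both hold, letting
    call sites use either spelling.
    """


class T:
    """Hypothetical stand-in for the tilelang.language dtype namespace."""
    float16 = DType("float16")
    float32 = DType("float32")
    int32 = DType("int32")
```

Under this model, switching a script between string literals and `T.*` attributes changes spelling only, not behavior.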

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.
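One plausible shape for registering the new "int4" and "int16" types, sketched as a minimal registry keyed by bit width. The `IntType` record and `INT_TYPES` table are illustrative assumptions, not the dtypes module's real data structures.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IntType:
    """Illustrative integer-dtype record keyed by bit width."""
    bits: int

    @property
    def name(self) -> str:
        return f"int{self.bits}"


# Hypothetical registry: adding "int4" and "int16" is just two more entries.
INT_TYPES = {t.name: t for t in (IntType(4), IntType(8), IntType(16), IntType(32))}
```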

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
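A minimal sketch of what such a canonical-name mapping and the refined error path might look like; the alias set and the `make_dtype` name are assumptions, not tilelang's actual table or API.

```python
# Hypothetical mapping from accepted input spellings to canonical display
# strings; bare "float" is treated as an alias for "float32".
_CANONICAL_DISPLAY = {
    "float": "float32",
    "float32": "float32",
    "float16": "float16",
    "bfloat16": "bfloat16",
    "int8": "int8",
    "int32": "int32",
}


def make_dtype(name: str) -> str:
    """Resolve a user-supplied dtype string to its canonical display string."""
    try:
        return _CANONICAL_DISPLAY[name]
    except KeyError:
        # Refined error message: name the offending input and list accepted spellings.
        raise ValueError(
            f"unknown dtype {name!r}; expected one of {sorted(_CANONICAL_DISPLAY)}"
        ) from None
```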

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
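The corrected call (`dtype.as_torch()`) presumably resolves a dtype to the corresponding torch type. A hedged sketch of that lookup: a real implementation would return `getattr(torch, name)`, but returning the attribute name keeps this sketch runnable without torch installed. The `_TORCH_NAMES` table and `as_torch_name` helper are illustrative assumptions.

```python
# Hypothetical name table: canonical dtype names mapped to the matching
# torch attribute names (bare "float" resolves to torch.float32).
_TORCH_NAMES = {
    "float": "float32",
    "float16": "float16",
    "bfloat16": "bfloat16",
    "float8_e4m3fn": "float8_e4m3fn",
    "int8": "int8",
}


def as_torch_name(dtype_name: str) -> str:
    """Return the torch attribute name for a dtype name (illustrative only)."""
    return _TORCH_NAMES[dtype_name]
```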

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
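The rename aligns with PyTorch's own float8 type names (`torch.float8_e4m3fn`, `torch.float8_e5m2`). One way to keep legacy spellings working, sketched under the assumption of a simple alias table (the `FLOAT8_ALIASES` name is hypothetical):

```python
# Hypothetical alias table: legacy float8 spellings map to the suffixed
# names matching torch's dtypes; already-canonical names pass through.
FLOAT8_ALIASES = {
    "float8_e4m3": "float8_e4m3fn",
    "float8_e4m3fn": "float8_e4m3fn",
    "float8_e5m2": "float8_e5m2",
    "float8_e5m2fnuz": "float8_e5m2fnuz",
}


def normalize_float8(name: str) -> str:
    """Map a possibly-legacy float8 name to its canonical spelling."""
    return FLOAT8_ALIASES.get(name, name)
```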

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
-from tilelang import tvm as tvm
+from tilelang import language as T
 import tilelang.testing
 import tilelang
 from tilelang.engine.callback import register_cuda_postproc_callback
@@ -25,8 +25,6 @@ def matmul(
 A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
 B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-import tilelang.language as T
 @T.prim_func
 def main(
 A: T.Tensor(A_shape, in_dtype),
@@ -107,9 +105,9 @@ def test_gemm_f16f16f16_nn():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -137,8 +135,6 @@ def matmu_jit_kernel(
 A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
 B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-import tilelang.language as T
 @T.prim_func
 def main(
 A: T.Tensor(A_shape, in_dtype),
@@ -207,8 +203,6 @@ def run_gemm_jit_kernel(
 B = B.T
 def ref_program(A, B):
-import torch
 C = torch.matmul(A.to(torch.float), B.to(torch.float))
 C = C.to(torch.__getattribute__(out_dtype))
 return C
@@ -226,9 +220,9 @@ def test_gemm_jit_kernel():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
...
-from tilelang import tvm as tvm
+from tilelang import language as T
 import tilelang.testing
 import tilelang
 import torch
@@ -27,8 +27,6 @@ def matmul_kernel_jit(
 A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
 B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-import tilelang.language as T
 @T.prim_func
 def main(
 A: T.Tensor(A_shape, in_dtype),
@@ -95,8 +93,6 @@ def run_gemm_kernel_jit(
 B = B.T
 def ref_program(A, B):
-import torch
 C = torch.matmul(A.to(torch.float), B.to(torch.float))
 C = C.to(torch.__getattribute__(out_dtype))
 return C
@@ -114,9 +110,9 @@ def test_gemm_f16f16f16_nn_kernel_jit():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 128,
 32,
...
@@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -226,9 +226,9 @@ def test_gemm_jit_kernel():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -278,7 +278,7 @@ def run_cython_kernel_do_bench(
 def test_cython_kernel_do_bench():
-run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_cython_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_cython_kernel_multi_stream(
@@ -322,7 +322,7 @@ def run_cython_kernel_multi_stream(
 def test_cython_kernel_multi_stream():
-run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_cython_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_cython_dynamic_shape(
@@ -371,11 +371,11 @@ def run_cython_dynamic_shape(
 def test_cython_dynamic_shape():
-run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
+run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
+run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_cython_dynamic_shape_with_out_idx(
@@ -424,7 +424,7 @@ def run_cython_dynamic_shape_with_out_idx(
 def test_cython_dynamic_shape_with_out_idx():
-run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def matmul_int_variable(
@@ -495,7 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
 def test_matmul_int_variable():
-run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
+run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128)
 def matmul_float_variable(
@@ -566,7 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
 def test_matmul_float_variable():
-run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
+run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128)
 if __name__ == "__main__":
...
@@ -7,7 +7,7 @@ from tilelang.utils import map_torch_type
 @tl.jit
-def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
+def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, with_bias=False):
 @T.prim_func
 def main(
 A: T.Tensor((M, K), dtype),
@@ -38,7 +38,7 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_
 return main
-def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def run_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
 a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
 b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
 c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
...
@@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -224,9 +224,9 @@ def test_gemm_jit_kernel():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -269,7 +269,7 @@ def run_nvrtc_kernel_do_bench(
 def test_nvrtc_kernel_do_bench():
-run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_nvrtc_kernel_multi_stream(
@@ -311,7 +311,7 @@ def run_nvrtc_kernel_multi_stream(
 def test_nvrtc_kernel_multi_stream():
-run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_nvrtc_dynamic_shape(
@@ -360,11 +360,11 @@ def run_nvrtc_dynamic_shape(
 def test_nvrtc_dynamic_shape():
-run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
+run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
+run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def check_hopper():
@@ -375,7 +375,7 @@ def check_hopper():
 return compute_capability == (9, 0)
-def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
+def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
 KH, KW = K, K
 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@@ -463,7 +463,7 @@ def test_nvrtc_l2_persistent_map():
 M,
 N,
 block_size=256,
-dtype="float32",
+dtype=T.float32,
 ):
 @T.prim_func
 def kernel(
...
 import tilelang.testing
 import tilelang
 import torch
+from tilelang import language as T
 @tilelang.jit(
@@ -16,9 +17,9 @@ def matmul_kernel_jit(
 block_K,
 trans_A=False,
 trans_B=True,
-in_dtype="float16",
-out_dtype="float32",
-accum_dtype="float32",
+in_dtype=T.float16,
+out_dtype=T.float32,
+accum_dtype=T.float32,
 num_stages=2,
 threads=128,
 ):
...
@@ -162,9 +162,9 @@ def test_gemm_jit_kernel():
 768,
 False,
 False,
-"float16",
-"float16",
-"float16",
+T.float16,
+T.float16,
+T.float16,
 128,
 256,
 32,
@@ -207,7 +207,7 @@ def run_tvm_ffi_kernel_do_bench(
 def test_tvm_ffi_kernel_do_bench():
-run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_tvm_ffi_kernel_multi_stream(
@@ -249,7 +249,7 @@ def run_tvm_ffi_kernel_multi_stream(
 def test_tvm_ffi_kernel_multi_stream():
-run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 def run_tvm_ffi_dynamic_shape(
@@ -298,12 +298,12 @@ def run_tvm_ffi_dynamic_shape(
 def test_tvm_ffi_dynamic_shape():
-run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
-run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
+run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
+run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
 run_tvm_ffi_dynamic_shape(
-T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
+T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2
 )
@@ -315,7 +315,7 @@ def check_hopper():
 return compute_capability == (9, 0)
-def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
+def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
 KH, KW = K, K
 OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
 OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@@ -403,7 +403,7 @@ def test_tvm_ffi_l2_persistent_map():
 M,
 N,
 block_size=256,
-dtype="float32",
+dtype=T.float32,
 ):
 @T.prim_func
 def kernel(
...
@@ -39,27 +39,27 @@ def tl_matmul(
 accum_dtype,
 ):
 assert in_dtype in [
-"float16",
-"bfloat16",
-"float8_e4m3",
-"float8_e5m2",
-"int8",
+T.float16,
+T.bfloat16,
+T.float8_e4m3fn,
+T.float8_e5m2,
+T.int8,
 ], "Currently only float16 and int8 are supported"
 assert out_dtype in [
-"float16",
-"float32",
-"int32",
+T.float16,
+T.float32,
+T.int32,
 ], "Currently only float16, float32 and int32 are supported"
 micro_size_x = micro_size_y = micro_size_k = 16
 is_float8 = in_dtype in [
-"float8_e4m3",
-"float8_e5m2",
-"float8_e4m3fn",
-"float8_e5m2fnuz",
+T.float8_e4m3fn,
+T.float8_e5m2,
+T.float8_e4m3fn,
+T.float8_e5m2fnuz,
 ]
-if out_dtype == "int32" or is_float8:
+if out_dtype == T.int32 or is_float8:
 micro_size_k = 32
 # This is a debug config
@@ -67,7 +67,7 @@ def tl_matmul(
 block_col_warps = 2
 warp_row_tiles = 32
 warp_col_tiles = 32
-chunk = 32 if in_dtype == "float16" else 64
+chunk = 32 if in_dtype == T.float16 else 64
 shared_scope = "shared.dyn"
 # Pipeline Stage
@@ -221,7 +221,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_assert_tl_matmul_bfloat16():
-assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32")
+assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32)
 if __name__ == "__main__":
...
-from tilelang import tvm as tvm
 import tilelang.testing
+from tilelang import language as T
 import torch
@@ -12,8 +12,6 @@ def elementwise_add(
 out_dtype,
 threads,
 ):
-import tilelang.language as T
 @T.prim_func
 def main(
 A: T.Tensor((M, N), in_dtype),
@@ -67,8 +65,8 @@ def test_elementwise_add_f32():
 run_elementwise_add(
 512,
 1024,
-"float32",
-"float32",
+T.float32,
+T.float32,
 128,
 256,
 )
@@ -78,8 +76,8 @@ def test_elementwise_add_f16():
 run_elementwise_add(
 512,
 1024,
-"float16",
-"float16",
+T.float16,
+T.float16,
 128,
 256,
 )
@@ -89,8 +87,8 @@ def test_elementwise_add_i32():
 run_elementwise_add(
 512,
 1024,
-"int32",
-"int32",
+T.int32,
+T.int32,
 128,
 256,
 )
@@ -100,8 +98,8 @@ def test_elementwise_add_f32f16():
 run_elementwise_add(
 512,
 1024,
-"float32",
-"float16",
+T.float32,
+T.float16,
 128,
 256,
 )
...
@@ -54,8 +54,8 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(9)
 def test_assert_matmul():
-assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e4m3", "float32", "float32")
-assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e5m2", "float32", "float32")
+assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e4m3fn, T.float32, T.float32)
+assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e5m2, T.float32, T.float32)
 if __name__ == "__main__":
...
@@ -39,26 +39,26 @@ def tl_matmul(
 accum_dtype,
 ):
 assert in_dtype in [
-"float16",
-"float8_e4m3",
-"float8_e5m2",
-"int8",
+T.float16,
+T.float8_e4m3fn,
+T.float8_e5m2,
+T.int8,
 ], "Currently only float16 and int8 are supported"
 assert out_dtype in [
-"float16",
-"float32",
-"int32",
+T.float16,
+T.float32,
+T.int32,
 ], "Currently only float16, float32 and int32 are supported"
 micro_size_x = micro_size_y = micro_size_k = 16
 is_float8 = in_dtype in [
-"float8_e4m3",
-"float8_e5m2",
-"float8_e4m3fn",
-"float8_e5m2fnuz",
+T.float8_e4m3fn,
+T.float8_e5m2,
+T.float8_e4m3fn,
+T.float8_e5m2fnuz,
 ]
-if out_dtype == "int32" or is_float8:
+if out_dtype == T.int32 or is_float8:
 micro_size_k = 32
 # This is a debug config
@@ -66,7 +66,7 @@ def tl_matmul(
 block_col_warps = 2
 warp_row_tiles = 32
 warp_col_tiles = 32
-chunk = 32 if in_dtype == "float16" else 64
+chunk = 32 if in_dtype == T.float16 else 64
 shared_scope = "shared.dyn"
 # Pipeline Stage
@@ -221,8 +221,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 9)
 def test_assert_tl_matmul():
-assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
-assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
+assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
+assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
 if __name__ == "__main__":
...
@@ -46,7 +46,7 @@ def gemv_simt(
 C_shape = (M, N)
 dp4a_size = 4
-use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
+use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
 @T.prim_func
 def main(
@@ -164,8 +164,8 @@ def evaluate_gemv_simt(
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 9)
 def test_gemv_simt():
-evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
-evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
+evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False)
+evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False)
 if __name__ == "__main__":
...
 from tilelang import tvm as tvm
 import tilelang.testing
+import tilelang.language as T
 def matmul(
@@ -22,8 +23,6 @@ def matmul(
     A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
     B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
-    import tilelang.language as T
     @T.prim_func
     def main(
         A: T.Tensor(A_shape, in_dtype),
@@ -92,7 +91,7 @@ def run_gemm(
         A = A.T
     if trans_B:
         B = B.T
-    if in_dtype == "float32":
+    if in_dtype == T.float32:
         # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
         # float32 automatically, -0x1000 meas
         A = (A.view(torch.int32) - 0x1000).view(torch.float32)
@@ -111,9 +110,9 @@ def test_gemm_f16f16f16_nn():
         768,
         False,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         128,
         32,
@@ -128,9 +127,9 @@ def test_gemm_f16f16f32_nn():
         768,
         False,
         False,
-        "float16",
-        "float16",
-        "float32",
+        T.float16,
+        T.float16,
+        T.float32,
         128,
         128,
         32,
@@ -144,9 +143,9 @@ def test_gemm_bf16bf16f32_nn():
         768,
         False,
         False,
-        "bfloat16",
-        "bfloat16",
-        "float32",
+        T.bfloat16,
+        T.bfloat16,
+        T.float32,
         128,
         128,
         32,
@@ -160,9 +159,9 @@ def test_gemm_f32f32f32_nn():
         768,
         False,
         False,
-        "float32",
-        "float32",
-        "float32",
+        T.float32,
+        T.float32,
+        T.float32,
         64,
         128,
         32,
@@ -176,9 +175,9 @@ def test_gemm_f16f16f16_tn():
         768,
         True,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         128,
         32,
@@ -193,9 +192,9 @@ def test_gemm_f16f16f16_nt():
         768,
         False,
         True,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         128,
         32,
@@ -204,15 +203,15 @@ def test_gemm_f16f16f16_nt():
 def test_gemm_i8i8i32_nt():
-    run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64)
+    run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)
 def test_gemm_i8i8i32_tn():
-    run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64)
+    run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)
 def test_gemm_f64f64f64_nt():
-    run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16)
+    run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)
 def test_gemm_f32f32f32_nt():
@@ -222,9 +221,9 @@ def test_gemm_f32f32f32_nt():
         768,
         False,
         True,
-        "float32",
-        "float32",
-        "float32",
+        T.float32,
+        T.float32,
+        T.float32,
         64,
         128,
         32,
@@ -238,9 +237,9 @@ def test_gemm_f32f32f32_tn():
         768,
         True,
         False,
-        "float32",
-        "float32",
-        "float32",
+        T.float32,
+        T.float32,
+        T.float32,
         64,
         128,
         32,
@@ -254,9 +253,9 @@ def test_pad_aligned_f16f16f16_nn():
         768 - 24,
         False,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         256,
         32,
@@ -271,9 +270,9 @@ def test_pad_f16f16f16_nn():
         768 - 5,
         False,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         256,
         32,
@@ -288,9 +287,9 @@ def test_pad_f16f16f32_nn():
         768 + 15,
         False,
         False,
-        "float16",
-        "float16",
-        "float32",
+        T.float16,
+        T.float16,
+        T.float32,
         128,
         64,
         32,
@@ -407,9 +406,9 @@ def test_gemm_f16f16f16_sr():
         768,
         False,
         True,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         128,
         32,
@@ -526,9 +525,9 @@ def test_gemm_f16f16f16_rs():
         768,
         True,
         False,
-        "float16",
-        "float16",
-        "float16",
+        T.float16,
+        T.float16,
+        T.float16,
         128,
         128,
         32,
...
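The `run_gemm` helper in the file above pre-truncates float32 inputs before comparing against a tfloat32 MMA, because the tensor core only consumes the top 10 of float32's 23 mantissa bits. A minimal standalone sketch of that truncation, using a plain bit mask rather than the `- 0x1000` variant in the diff:

```python
import struct

def to_tf32(x: float) -> float:
    """Truncate a float32 value to tfloat32 precision (10 mantissa bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # tfloat32 keeps the sign bit, the 8 exponent bits, and the top 10
    # mantissa bits; zero out the discarded low 13 mantissa bits.
    return struct.unpack("<f", struct.pack("<I", bits & ~0x1FFF))[0]
```

Anything below the 10th mantissa bit is lost, so `to_tf32(1.5 + 2**-20)` collapses back to `1.5` while `1.5 + 2**-10` survives; applying the truncation on the host keeps the torch reference and the kernel numerically comparable.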
@@ -39,27 +39,27 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "bfloat16",
-        "float8_e4m3",
-        "float8_e5m2",
-        "int8",
+        T.float16,
+        T.bfloat16,
+        T.float8_e4m3fn,
+        T.float8_e5m2,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
     is_float8 = in_dtype in [
-        "float8_e4m3",
-        "float8_e5m2",
-        "float8_e4m3fn",
-        "float8_e5m2fnuz",
+        T.float8_e4m3fn,
+        T.float8_e5m2,
+        T.float8_e4m3fn,
+        T.float8_e5m2fnuz,
     ]
-    if out_dtype == "int32" or is_float8:
+    if out_dtype == T.int32 or is_float8:
         micro_size_k = 32
     # This is a debug config
@@ -67,7 +67,7 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32 if in_dtype == "float16" else 64
+    chunk = 32 if in_dtype == T.float16 else 64
     shared_scope = "shared.dyn"
     # Pipeline Stage
@@ -219,22 +219,22 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
-    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
+    assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
+    assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32)
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_assert_tl_matmul_bfloat16():
-    assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32")
+    assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32)
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 9)
 def test_assert_tl_matmul_fp8():
-    assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
-    assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
+    assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
+    assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
 if __name__ == "__main__":
...
@@ -35,13 +35,13 @@ def tl_matmul_simt(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     # This is a debug config
@@ -72,7 +72,7 @@ def tl_matmul_simt(
     micro_size_k = 128 // DataType(in_dtype).bits
     dp4a_size = 4
-    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
+    use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
     @T.prim_func
     def main(
@@ -139,7 +139,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
     # src_code is the generated cuda source
     assert src_code is not None
-    if in_dtype == "int8":
+    if in_dtype == T.int8:
         A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
         B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
     else:
@@ -161,9 +161,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
-    assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
+    assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
+    assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32)
 if __name__ == "__main__":
...
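The `use_dp4a` flag above gates the int8 SIMT path onto the DP4A instruction, which multiplies four packed int8 pairs and adds their int32 sum into an accumulator in one step. A scalar reference of its semantics (the helper name is illustrative, not a tilelang API):

```python
def dp4a_ref(a: list, b: list, c: int) -> int:
    """Reference semantics of DP4A: c += dot(a, b) over four int8 lanes,
    with the dot product accumulated in int32 so it cannot overflow."""
    assert len(a) == len(b) == 4
    assert all(-128 <= v <= 127 for v in a + b)
    return c + sum(x * y for x, y in zip(a, b))
```

With four lanes the worst-case magnitude per instruction is 4 * 128 * 128, far inside int32 range, which is why this path requires `accum_dtype` to be int32.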
@@ -4,7 +4,7 @@ import tilelang.language as T
 import torch
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
...
@@ -46,7 +46,7 @@ def gemv_simt(
     C_shape = (M, N)
     dp4a_size = 4
-    use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
+    use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
     @T.prim_func
     def main(
@@ -164,15 +164,15 @@ def evaluate_gemv_simt(
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_gemv_simt():
-    evaluate_gemv_simt(1, 1024, 1024, "float16", "float16", "float16", with_bias=False)
-    evaluate_gemv_simt(1, 1024, 1024, "int8", "int32", "int32", with_bias=False)
+    evaluate_gemv_simt(1, 1024, 1024, T.float16, T.float16, T.float16, with_bias=False)
+    evaluate_gemv_simt(1, 1024, 1024, T.int8, T.int32, T.int32, with_bias=False)
 @tilelang.testing.requires_cuda
 @tilelang.testing.requires_cuda_compute_version(8, 9)
 def test_gemv_simt_fp8():
-    evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
-    evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
+    evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False)
+    evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False)
 if __name__ == "__main__":
...
@@ -26,20 +26,20 @@ def tl_matmul(
     accum_dtype,
 ):
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
     ], "Currently only float16, float32 and int32 are supported"
     K = K // 2
     micro_size_x = micro_size_y = micro_size_k = 16
-    if accum_dtype == "int32":
+    if accum_dtype == T.int32:
         micro_size_k = 32
     # This is a debug config
@@ -47,7 +47,7 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 64
     warp_col_tiles = 64
-    chunk = 32 if in_dtype == "float16" else 64
+    chunk = 32 if in_dtype == T.float16 else 64
     shared_scope = "shared.dyn"
     # Pipeline Stage
@@ -197,8 +197,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
 def test_assert_tl_matmul_correctness():
-    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
-    assert_tl_matmul_correctness(128, 128, 64, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
+    assert_tl_matmul_correctness(128, 128, 64, T.int8, T.int32, T.int32)
 @simplify_prim_func
@@ -212,18 +212,18 @@ def tl_matmul_weight_only_transform(
 ):
     K = K // 2
     assert in_dtype in [
-        "float16",
-        "int8",
+        T.float16,
+        T.int8,
     ], "Currently only float16 and int8 are supported"
     assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
+        T.float16,
+        T.float32,
+        T.int32,
    ], "Currently only float16, float32 and int32 are supported"
     micro_size_x = micro_size_y = micro_size_k = 16
-    if out_dtype == "int32":
+    if out_dtype == T.int32:
         micro_size_k = 32
     transform_b = 3
@@ -233,7 +233,7 @@ def tl_matmul_weight_only_transform(
     block_col_warps = 2
     warp_row_tiles = 64
     warp_col_tiles = 64
-    chunk = 32 if in_dtype == "float16" else 64
+    chunk = 32 if in_dtype == T.float16 else 64
     shared_scope = "shared.dyn"
     # Pipeline Stage
@@ -375,8 +375,8 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
     ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
         M=N,
         N=(K // 2),
-        datatype="int8",
-        storage_dtype="int8",
+        datatype=T.int8,
+        storage_dtype=T.int8,
         transform_kind=transform_b,
         transpose_matrix=True,
     )
@@ -400,9 +400,9 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
 @tilelang.testing.requires_package("bitblas")
 @tilelang.testing.requires_llvm
 def test_assert_tl_matmul_weight_only_transform():
-    assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")
+    assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, T.int8, T.int32, T.int32)
 if __name__ == "__main__":
     # tilelang.testing.main()
-    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
+    assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
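The int8 cases above accumulate into int32 and store B transposed as an (N, K) buffer. A tiny pure-Python reference of the computation those tests check against (an illustrative helper, not part of the test file):

```python
def int8_gemm_ref(A, B):
    """C[m][n] = sum_k A[m][k] * B[n][k], i.e. C = A @ B.T, with the
    accumulation done in full integer range as an int32 accumulator would.
    A is (M, K); B is stored transposed as (N, K), as in the tests above."""
    M, K, N = len(A), len(A[0]), len(B)
    return [[sum(A[m][k] * B[n][k] for k in range(K)) for n in range(N)]
            for m in range(M)]
```

Because each product is bounded by 128 * 128 and K is modest, the running sum stays well inside int32, matching the kernels' `accum_dtype`.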
@@ -4,7 +4,7 @@ import tilelang.language as T
 # add decorator @tilelang.jit if you want to return a torch function
 # @tilelang.jit
-def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     @T.prim_func
     def main(
         A: T.Tensor((M, K), dtype),
@@ -43,7 +43,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     return main
-def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
+def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
     program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
     kernel = tilelang.compile(program, out_idx=[2], target="cuda")
     kernel.run_once()
...
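These defaults replace the ambiguous string "float" with an explicit T.float32 accumulator for float16 inputs. Accumulating in a wider type than the operands matters numerically; a standalone sketch that emulates a float16 accumulator with Python's half-precision packing (names are illustrative):

```python
import struct

def to_f16(x: float) -> float:
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def accumulate(values, fp16_accumulator: bool) -> float:
    acc = 0.0
    for v in values:
        acc += v
        if fp16_accumulator:
            acc = to_f16(acc)  # emulate storing the running sum in float16
    return acc
```

Summing 10,000 copies of 1e-4 should give about 1.0; the float16 accumulator stalls well below that, because once the running sum grows, each 1e-4 increment falls under half a ULP and rounds away, which is the failure mode a float32 `accum_dtype` avoids.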
@@ -31,8 +31,8 @@ def blocksparse_matmul_global(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@@ -75,8 +75,8 @@ def blocksparse_matmul_shared(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
@@ -124,8 +124,8 @@ def blocksparse_matmul_local(
     num_stages,
     thread_num,
     enable_rasteration,
-    dtype="float16",
-    accum_dtype="float",
+    dtype=T.float16,
+    accum_dtype=T.float32,
 ):
     block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
...
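The three blocksparse variants above differ only in where the block mask is staged (global, shared, or local/register memory); the computation they all implement is a GEMM that skips whole (block_M x block_N x block_K) tiles. A dense-Python reference of that computation, simplified to a 3-D mask (the example kernels carry an extra `condition_dim` axis; this helper is illustrative only):

```python
def blocksparse_matmul_ref(A, B, mask, block_M, block_N, block_K):
    """C = A @ B, except a K-tile contributes to an output tile only when
    mask[bi][bj][bk] is True; masked tiles are skipped entirely."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for bi in range(M // block_M):
        for bj in range(N // block_N):
            for bk in range(K // block_K):
                if not mask[bi][bj][bk]:
                    continue  # a masked tile's contribution is defined as zero
                for i in range(bi * block_M, (bi + 1) * block_M):
                    for j in range(bj * block_N, (bj + 1) * block_N):
                        C[i][j] += sum(
                            A[i][k] * B[k][j]
                            for k in range(bk * block_K, (bk + 1) * block_K))
    return C
```

With an all-True mask this reduces to a dense matmul, which is the natural correctness check for the sparse kernels.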