Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
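Why float32 accumulation improves numerical stability can be shown with a small plain-NumPy sketch (illustrative only, independent of tilelang): a float16 accumulator stalls once its rounding step matches the addend, while a float32 accumulator recovers the exact sum.

```python
import numpy as np

# Sum 4096 copies of 1/4096 (exact answer: 1.0). Each addend is exactly
# representable in float16, so only the accumulator dtype differs below.
values = np.full(4096, 1.0 / 4096, dtype=np.float16)

acc16 = np.float16(0.0)
for v in values:
    acc16 = np.float16(acc16 + v)  # round back to float16 after every add

acc32 = np.float32(0.0)
for v in values:
    acc32 = acc32 + np.float32(v)  # float32 accumulator, float16 inputs

# Once acc16 reaches 0.5, adding 2**-12 is a half-ulp tie that rounds back
# down (round-to-even), so the float16 running sum never grows past 0.5.
print(float(acc16), float(acc32))  # → 0.5 1.0
```

The same effect is why GEMM-style examples accumulate partial products in `T.float32` even when the inputs are float16.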

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
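A minimal sketch of the canonical-name mapping this bullet describes (the table entries here are illustrative assumptions; the real mapping lives in tilelang's dtypes module and may differ):

```python
# Hypothetical canonicalization table: maps accepted spellings to the
# display string of the canonical dtype. Entries are assumptions for
# illustration, not tilelang's actual table.
_CANONICAL_DISPLAY = {
    "float": "float32",              # bare "float" means float32
    "int": "int32",                  # bare "int" means int32
    "float8_e4m3": "float8_e4m3fn",  # legacy alias
}

def canonical_dtype(name: str) -> str:
    """Return the canonical display string for a dtype spelling."""
    if not isinstance(name, str):
        raise TypeError(f"expected a dtype string, got {type(name).__name__}")
    return _CANONICAL_DISPLAY.get(name, name)

print(canonical_dtype("float"))    # → float32
print(canonical_dtype("float16"))  # → float16 (already canonical)
```

Routing every string input through one table also keeps error messages consistent: an unknown spelling can be reported against the canonical names instead of failing deep inside a conversion.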

* [Fix] Correct boolean flag in GEMM SP test case

- Updated the boolean flag in the test_gemm_sp_sm90 function to ensure proper functionality in the test case.
- This change enhances the accuracy of the test and aligns it with expected behavior for the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the usage of the data type conversion method from dtype..as_torch() to dtype.as_torch() across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.
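For context on the `.as_torch()` fix, here is a toy stand-in for a dtype object (hypothetical: tilelang's real method returns a `torch.dtype`; this sketch returns the dotted attribute name so it runs without torch installed):

```python
class DType:
    """Toy dtype wrapper; `name` follows the string spellings used above."""

    def __init__(self, name: str):
        self.name = name

    def as_torch(self) -> str:
        # A real implementation would return something like
        # getattr(torch, self.name); returning the dotted name
        # keeps this sketch torch-free.
        return f"torch.{self.name}"

dtype = DType("float16")
print(dtype.as_torch())  # → torch.float16
# The pattern being fixed, `dtype..as_torch()`, is a SyntaxError in Python.
```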

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
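The rename can be summarized in a short sketch (field widths follow the commonly used E4M3/E5M2 definitions; the rename map mirrors what this commit does, and the `fn` suffix matches PyTorch's spelling for the finite-only E4M3 format):

```python
# E4M3 ("fn" = finite values and NaN only, no infinities) vs. E5M2
# (IEEE-style, with infinities). Layout: 1 sign + exponent + mantissa.
FLOAT8_FORMATS = {
    "float8_e4m3fn": {"exp_bits": 4, "man_bits": 3, "has_inf": False},
    "float8_e5m2":   {"exp_bits": 5, "man_bits": 2, "has_inf": True},
}

def normalize_float8(name: str) -> str:
    """Map the legacy spelling to the torch-compatible one, as in this commit."""
    return {"float8_e4m3": "float8_e4m3fn"}.get(name, name)

print(normalize_float8("float8_e4m3"))  # → float8_e4m3fn
print(normalize_float8("float8_e5m2"))  # → float8_e5m2
```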

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
from tilelang import tvm as tvm
from tilelang import language as T
import tilelang.testing
import tilelang
from tilelang.engine.callback import register_cuda_postproc_callback
@@ -25,8 +25,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -107,9 +105,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -137,8 +135,6 @@ def matmu_jit_kernel(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -207,8 +203,6 @@ def run_gemm_jit_kernel(
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
@@ -226,9 +220,9 @@ def test_gemm_jit_kernel():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
......
from tilelang import tvm as tvm
from tilelang import language as T
import tilelang.testing
import tilelang
import torch
@@ -27,8 +27,6 @@ def matmul_kernel_jit(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -95,8 +93,6 @@ def run_gemm_kernel_jit(
B = B.T
def ref_program(A, B):
import torch
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
@@ -114,9 +110,9 @@ def test_gemm_f16f16f16_nn_kernel_jit():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
......
@@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -226,9 +226,9 @@ def test_gemm_jit_kernel():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -278,7 +278,7 @@ def run_cython_kernel_do_bench(
def test_cython_kernel_do_bench():
run_cython_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_cython_kernel_multi_stream(
@@ -322,7 +322,7 @@ def run_cython_kernel_multi_stream(
def test_cython_kernel_multi_stream():
run_cython_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_cython_dynamic_shape(
@@ -371,11 +371,11 @@ def run_cython_dynamic_shape(
def test_cython_dynamic_shape():
run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_cython_dynamic_shape_with_out_idx(
@@ -424,7 +424,7 @@ def run_cython_dynamic_shape_with_out_idx(
def test_cython_dynamic_shape_with_out_idx():
run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_cython_dynamic_shape_with_out_idx(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def matmul_int_variable(
@@ -495,7 +495,7 @@ def run_matmul_int_variable(M, N, K, block_M, block_N, block_K, trans_A, trans_B
def test_matmul_int_variable():
run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
run_matmul_int_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128)
def matmul_float_variable(
@@ -566,7 +566,7 @@ def run_matmul_float_variable(M, N, K, block_M, block_N, block_K, trans_A, trans
def test_matmul_float_variable():
run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, "float16", "float16", "float32", 0, 128)
run_matmul_float_variable(1024, 1024, 1024, 128, 128, 32, False, False, T.float16, T.float16, T.float32, 0, 128)
if __name__ == "__main__":
......
@@ -7,7 +7,7 @@ from tilelang.utils import map_torch_type
@tl.jit
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float", with_bias=False):
def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, with_bias=False):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
@@ -38,7 +38,7 @@ def tensor_null_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_
return main
def run_test(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def run_test(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
a = torch.randn(M, K, device="cuda", dtype=map_torch_type(dtype))
b = torch.randn(N, K, device="cuda", dtype=map_torch_type(dtype))
c = torch.zeros(M, N, device="cuda", dtype=map_torch_type(accum_dtype))
......
@@ -104,9 +104,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -224,9 +224,9 @@ def test_gemm_jit_kernel():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -269,7 +269,7 @@ def run_nvrtc_kernel_do_bench(
def test_nvrtc_kernel_do_bench():
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_nvrtc_kernel_multi_stream(
@@ -311,7 +311,7 @@ def run_nvrtc_kernel_multi_stream(
def test_nvrtc_kernel_multi_stream():
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_nvrtc_dynamic_shape(
@@ -360,11 +360,11 @@ def run_nvrtc_dynamic_shape(
def test_nvrtc_dynamic_shape():
run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_nvrtc_dynamic_shape(T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def check_hopper():
@@ -375,7 +375,7 @@ def check_hopper():
return compute_capability == (9, 0)
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@@ -463,7 +463,7 @@ def test_nvrtc_l2_persistent_map():
M,
N,
block_size=256,
dtype="float32",
dtype=T.float32,
):
@T.prim_func
def kernel(
......
import tilelang.testing
import tilelang
import torch
from tilelang import language as T
@tilelang.jit(
@@ -16,9 +17,9 @@ def matmul_kernel_jit(
block_K,
trans_A=False,
trans_B=True,
in_dtype="float16",
out_dtype="float32",
accum_dtype="float32",
in_dtype=T.float16,
out_dtype=T.float32,
accum_dtype=T.float32,
num_stages=2,
threads=128,
):
......
@@ -162,9 +162,9 @@ def test_gemm_jit_kernel():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -207,7 +207,7 @@ def run_tvm_ffi_kernel_do_bench(
def test_tvm_ffi_kernel_do_bench():
run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_kernel_do_bench(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_tvm_ffi_kernel_multi_stream(
@@ -249,7 +249,7 @@ def run_tvm_ffi_kernel_multi_stream(
def test_tvm_ffi_kernel_multi_stream():
run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_kernel_multi_stream(512, 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
def run_tvm_ffi_dynamic_shape(
@@ -298,12 +298,12 @@ def run_tvm_ffi_dynamic_shape(
def test_tvm_ffi_dynamic_shape():
run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(T.dynamic("m"), 1024, 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(T.dynamic("m"), T.dynamic("n"), 768, False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2)
run_tvm_ffi_dynamic_shape(
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, "float16", "float16", "float16", 128, 256, 32, 2
T.dynamic("m"), T.dynamic("n"), T.dynamic("k"), False, False, T.float16, T.float16, T.float16, 128, 256, 32, 2
)
@@ -315,7 +315,7 @@ def check_hopper():
return compute_capability == (9, 0)
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype="float16", accum_dtype="float"):
def convolution_im2col(N, C, H, W, F, K, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
KH, KW = K, K
OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
@@ -403,7 +403,7 @@ def test_tvm_ffi_l2_persistent_map():
M,
N,
block_size=256,
dtype="float32",
dtype=T.float32,
):
@T.prim_func
def kernel(
......
@@ -39,27 +39,27 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"bfloat16",
"float8_e4m3",
"float8_e5m2",
"int8",
T.float16,
T.bfloat16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == "int32" or is_float8:
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
@@ -67,7 +67,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
@@ -221,7 +221,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul_bfloat16():
assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32")
assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32)
if __name__ == "__main__":
......
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
import torch
@@ -12,8 +12,6 @@ def elementwise_add(
out_dtype,
threads,
):
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor((M, N), in_dtype),
@@ -67,8 +65,8 @@ def test_elementwise_add_f32():
run_elementwise_add(
512,
1024,
"float32",
"float32",
T.float32,
T.float32,
128,
256,
)
@@ -78,8 +76,8 @@ def test_elementwise_add_f16():
run_elementwise_add(
512,
1024,
"float16",
"float16",
T.float16,
T.float16,
128,
256,
)
@@ -89,8 +87,8 @@ def test_elementwise_add_i32():
run_elementwise_add(
512,
1024,
"int32",
"int32",
T.int32,
T.int32,
128,
256,
)
@@ -100,8 +98,8 @@ def test_elementwise_add_f32f16():
run_elementwise_add(
512,
1024,
"float32",
"float16",
T.float32,
T.float16,
128,
256,
)
......
@@ -54,8 +54,8 @@ def assert_matmul_correctness(M, N, K, block_M, block_N, block_K, in_dtype, out_
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9)
def test_assert_matmul():
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e4m3", "float32", "float32")
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, "float8_e5m2", "float32", "float32")
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e4m3fn, T.float32, T.float32)
assert_matmul_correctness(1024, 1024, 1024, 128, 128, 64, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
......
@@ -39,26 +39,26 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"float8_e4m3",
"float8_e5m2",
"int8",
T.float16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == "int32" or is_float8:
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
@@ -66,7 +66,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
@@ -221,8 +221,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
......
@@ -46,7 +46,7 @@ def gemv_simt(
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
@T.prim_func
def main(
@@ -164,8 +164,8 @@ def evaluate_gemv_simt(
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False)
if __name__ == "__main__":
......
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def matmul(
@@ -22,8 +23,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -92,7 +91,7 @@ def run_gemm(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 means
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
@@ -111,9 +110,9 @@ def test_gemm_f16f16f16_nn():
768,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
@@ -128,9 +127,9 @@ def test_gemm_f16f16f32_nn():
768,
False,
False,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
128,
128,
32,
@@ -144,9 +143,9 @@ def test_gemm_bf16bf16f32_nn():
768,
False,
False,
"bfloat16",
"bfloat16",
"float32",
T.bfloat16,
T.bfloat16,
T.float32,
128,
128,
32,
@@ -160,9 +159,9 @@ def test_gemm_f32f32f32_nn():
768,
False,
False,
"float32",
"float32",
"float32",
T.float32,
T.float32,
T.float32,
64,
128,
32,
@@ -176,9 +175,9 @@ def test_gemm_f16f16f16_tn():
768,
True,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
@@ -193,9 +192,9 @@ def test_gemm_f16f16f16_nt():
768,
False,
True,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
@@ -204,15 +203,15 @@ def test_gemm_f16f16f16_nt():
def test_gemm_i8i8i32_nt():
run_gemm(512, 1024, 768, False, True, "int8", "int8", "int32", 128, 128, 64)
run_gemm(512, 1024, 768, False, True, T.int8, T.int8, T.int32, 128, 128, 64)
def test_gemm_i8i8i32_tn():
run_gemm(512, 1024, 768, True, False, "int8", "int8", "int32", 128, 128, 64)
run_gemm(512, 1024, 768, True, False, T.int8, T.int8, T.int32, 128, 128, 64)
def test_gemm_f64f64f64_nt():
run_gemm(512, 512, 512, False, True, "float64", "float64", "float64", 64, 32, 16)
run_gemm(512, 512, 512, False, True, T.float64, T.float64, T.float64, 64, 32, 16)
def test_gemm_f32f32f32_nt():
@@ -222,9 +221,9 @@ def test_gemm_f32f32f32_nt():
768,
False,
True,
"float32",
"float32",
"float32",
T.float32,
T.float32,
T.float32,
64,
128,
32,
@@ -238,9 +237,9 @@ def test_gemm_f32f32f32_tn():
768,
True,
False,
"float32",
"float32",
"float32",
T.float32,
T.float32,
T.float32,
64,
128,
32,
@@ -254,9 +253,9 @@ def test_pad_aligned_f16f16f16_nn():
768 - 24,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -271,9 +270,9 @@ def test_pad_f16f16f16_nn():
768 - 5,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
256,
32,
@@ -288,9 +287,9 @@ def test_pad_f16f16f32_nn():
768 + 15,
False,
False,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
128,
64,
32,
@@ -407,9 +406,9 @@ def test_gemm_f16f16f16_sr():
768,
False,
True,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
@@ -526,9 +525,9 @@ def test_gemm_f16f16f16_rs():
768,
True,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
128,
128,
32,
......
@@ -39,27 +39,27 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"bfloat16",
"float8_e4m3",
"float8_e5m2",
"int8",
T.float16,
T.bfloat16,
T.float8_e4m3fn,
T.float8_e5m2,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
is_float8 = in_dtype in [
"float8_e4m3",
"float8_e5m2",
"float8_e4m3fn",
"float8_e5m2fnuz",
T.float8_e4m3fn,
T.float8_e5m2,
T.float8_e4m3fn,
T.float8_e5m2fnuz,
]
if out_dtype == "int32" or is_float8:
if out_dtype == T.int32 or is_float8:
micro_size_k = 32
# This is a debug config
@@ -67,7 +67,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
@@ -219,22 +219,22 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_assert_tl_matmul_bfloat16():
assert_tl_matmul_correctness(256, 256, 256, "bfloat16", "float32", "float32")
assert_tl_matmul_correctness(256, 256, 256, T.bfloat16, T.float32, T.float32)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_assert_tl_matmul_fp8():
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32")
assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32)
assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32)
if __name__ == "__main__":
......
@@ -35,13 +35,13 @@ def tl_matmul_simt(
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
# This is a debug config
@@ -72,7 +72,7 @@ def tl_matmul_simt(
micro_size_k = 128 // DataType(in_dtype).bits
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
@T.prim_func
def main(
@@ -139,7 +139,7 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
# src_code is the generated cuda source
assert src_code is not None
if in_dtype == "int8":
if in_dtype == T.int8:
A = torch.randint(-128, 127, (M, K), device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, (N, K), device="cuda", dtype=torch.int8)
else:
......
@@ -161,9 +161,9 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", "float32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", "int32")
assert_tl_matmul_correctness(128, 128, 128, T.float16, T.float16, T.float16)
assert_tl_matmul_correctness(128, 256, 256, T.float16, T.float32, T.float32)
assert_tl_matmul_correctness(128, 256, 256, T.int8, T.int32, T.int32)
if __name__ == "__main__":
......
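In `tl_matmul_simt` above, `micro_size_k = 128 // DataType(in_dtype).bits` sizes each micro step so a thread always consumes a fixed 128-bit slice of K; narrower element types pack more elements per step. A plain-Python sketch with assumed bit widths (not tilelang's `DataType` class):

```python
# Assumed element bit widths for the dtypes used in these tests.
BITS = {"float16": 16, "int8": 8, "float32": 32}

def micro_size_k(in_dtype: str) -> int:
    # Each micro step covers 128 bits of the K dimension per thread.
    return 128 // BITS[in_dtype]

assert micro_size_k("float16") == 8   # 128 / 16
assert micro_size_k("int8") == 16     # 128 / 8
```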
......
@@ -4,7 +4,7 @@ import tilelang.language as T
import torch
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
......
@@ -46,7 +46,7 @@ def gemv_simt(
C_shape = (M, N)
dp4a_size = 4
use_dp4a = in_dtype == "int8" and accum_dtype == "int32"
use_dp4a = in_dtype == T.int8 and accum_dtype == T.int32
@T.prim_func
def main(
......
@@ -164,15 +164,15 @@ def evaluate_gemv_simt(
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 0)
def test_gemv_simt():
evaluate_gemv_simt(1, 1024, 1024, "float16", "float16", "float16", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "int8", "int32", "int32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.float16, T.float16, T.float16, with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.int8, T.int32, T.int32, with_bias=False)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(8, 9)
def test_gemv_simt_fp8():
evaluate_gemv_simt(1, 1024, 1024, "float8_e4m3", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, "float8_e5m2", "float32", "float32", with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.float8_e4m3fn, T.float32, T.float32, with_bias=False)
evaluate_gemv_simt(1, 1024, 1024, T.float8_e5m2, T.float32, T.float32, with_bias=False)
if __name__ == "__main__":
......
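Both the SIMT matmul and GEMV kernels above enable DP4A only when `in_dtype` is int8 and `accum_dtype` is int32 (`dp4a_size = 4`). A scalar Python model of what one DP4A step computes — an illustration, not the generated CUDA: four int8 products accumulated into a single int32 lane:

```python
def dp4a(a4, b4, c):
    """Model of one DP4A step: dot product of four int8 pairs added to an
    int32 accumulator. int8*int8 products fit safely in int32, which is why
    the kernels require accum_dtype == int32 to use this path."""
    assert len(a4) == len(b4) == 4
    return c + sum(x * y for x, y in zip(a4, b4))

# 1*5 + (-2)*6 + 3*(-7) + 4*8 + 10 = 14
acc = dp4a([1, -2, 3, 4], [5, 6, -7, 8], 10)
assert acc == 14
```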
......
@@ -26,20 +26,20 @@ def tl_matmul(
accum_dtype,
):
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
K = K // 2
micro_size_x = micro_size_y = micro_size_k = 16
if accum_dtype == "int32":
if accum_dtype == T.int32:
micro_size_k = 32
# This is a debug config
......
@@ -47,7 +47,7 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......
@@ -197,8 +197,8 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype):
def test_assert_tl_matmul_correctness():
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_correctness(128, 128, 64, "int8", "int32", "int32")
assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
assert_tl_matmul_correctness(128, 128, 64, T.int8, T.int32, T.int32)
@simplify_prim_func
......
@@ -212,18 +212,18 @@ def tl_matmul_weight_only_transform(
):
K = K // 2
assert in_dtype in [
"float16",
"int8",
T.float16,
T.int8,
], "Currently only float16 and int8 are supported"
assert out_dtype in [
"float16",
"float32",
"int32",
T.float16,
T.float32,
T.int32,
], "Currently only float16, float32 and int32 are supported"
micro_size_x = micro_size_y = micro_size_k = 16
if out_dtype == "int32":
if out_dtype == T.int32:
micro_size_k = 32
transform_b = 3
......
@@ -233,7 +233,7 @@ def tl_matmul_weight_only_transform(
block_col_warps = 2
warp_row_tiles = 64
warp_col_tiles = 64
chunk = 32 if in_dtype == "float16" else 64
chunk = 32 if in_dtype == T.float16 else 64
shared_scope = "shared.dyn"
# Pipeline Stage
......
@@ -375,8 +375,8 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
ladder_permutate_config = bitblas.ops.LadderPermutateConfig(
M=N,
N=(K // 2),
datatype="int8",
storage_dtype="int8",
datatype=T.int8,
storage_dtype=T.int8,
transform_kind=transform_b,
transpose_matrix=True,
)
......
@@ -400,9 +400,9 @@ def assert_tl_matmul_weight_only_transform_correctness(M, N, K, in_dtype, out_dt
@tilelang.testing.requires_package("bitblas")
@tilelang.testing.requires_llvm
def test_assert_tl_matmul_weight_only_transform():
assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_weight_only_transform_correctness(128, 128, 128, T.int8, T.int32, T.int32)
if __name__ == "__main__":
# tilelang.testing.main()
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", "int32")
assert_tl_matmul_correctness(128, 128, 128, T.int8, T.int32, T.int32)
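The `chunk = 32 if in_dtype == T.float16 else 64` choice in the kernels above keeps the K-tile the same size in bytes across element widths. A plain-Python sketch of that heuristic (assumed bit widths; `chunk_for` is a hypothetical helper name):

```python
def chunk_for(in_dtype: str) -> int:
    # int8 elements are half as wide as float16, so doubling the K chunk
    # keeps the shared-memory tile footprint constant in bytes.
    return 32 if in_dtype == "float16" else 64

# Same number of bits (hence bytes) per K slice either way:
assert chunk_for("float16") * 16 == chunk_for("int8") * 8
```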
......
@@ -4,7 +4,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
@@ -43,7 +43,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
return main
def run_matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def run_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
program = matmul(M, N, K, block_M, block_N, block_K, dtype, accum_dtype)
kernel = tilelang.compile(program, out_idx=[2], target="cuda")
kernel.run_once()
......
......
@@ -31,8 +31,8 @@ def blocksparse_matmul_global(
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
......
@@ -75,8 +75,8 @@ def blocksparse_matmul_shared(
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
......
@@ -124,8 +124,8 @@ def blocksparse_matmul_local(
num_stages,
thread_num,
enable_rasteration,
dtype="float16",
accum_dtype="float",
dtype=T.float16,
accum_dtype=T.float32,
):
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
......
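Each blocksparse variant above builds `block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)`: one mask entry per (M, N, K) tile, replicated along a condition axis. A quick sketch with hypothetical problem and tile sizes:

```python
# Hypothetical problem and tile sizes, mirroring block_mask_shape above.
M, N, K = 1024, 1024, 1024
block_M, block_N, block_K, condition_dim = 128, 128, 32, 2

# One mask entry per output/reduction tile, times the condition axis.
block_mask_shape = (M // block_M, N // block_N, K // block_K, condition_dim)
assert block_mask_shape == (8, 8, 32, 2)
```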