Unverified commit c750fb8a authored by Lei Wang, committed by GitHub

[Enhancement] Update examples and tests for improved type handling functionality (#1448)

* [Enhancement] Update examples and tests for improved type handling and functionality

- Enhanced various example scripts to support new data types and improve compatibility with PyTorch.
- Updated tests across multiple modules to ensure correct functionality with the latest changes in type handling.
- Refactored code in examples to streamline operations and improve clarity, particularly in tensor operations and memory management.
- Added comprehensive tests for new features and fixed existing issues related to type conversions and buffer handling.

* [Refactor] Update accumulation data type to float32 across examples

- Changed accumulation data type from "float" to T.float32 in multiple example scripts to ensure consistency and improve numerical stability.
- This update affects various modules including flash attention, GEMM analysis, convolution, and deepseek MLA examples, enhancing type handling across the board.
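The numerical-stability argument above can be demonstrated without any framework. The sketch below (illustrative only, not tilelang code) simulates a float16 accumulator using the standard library's half-precision struct format and compares it with a full-precision sum:

```python
import struct

def to_f16(x: float) -> float:
    # Round x to the nearest representable float16 value (struct's 'e' format).
    return struct.unpack("<e", struct.pack("<e", x))[0]

# Accumulate 2048 copies of 0.1 with a simulated float16 accumulator:
# every intermediate sum is rounded back to half precision, as a float16
# accumulator register would.
acc16 = 0.0
for _ in range(2048):
    acc16 = to_f16(acc16 + to_f16(0.1))

# Full-precision reference sum.
acc64 = 2048 * 0.1

# The float16 accumulator drifts noticeably once the running sum grows
# large enough that 0.1 is below half an ulp of the accumulator.
print(acc16, acc64)
```

This is exactly why the examples accumulate in T.float32 even when the operands are float16: the accumulator's ulp stays small relative to each addend.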

* [Refactor] Standardize data type usage across benchmark scripts

- Updated data type definitions in benchmark scripts to use T.float16 and T.float32 consistently, enhancing clarity and type handling.
- Adjusted dtype assignments in matmul functions and configuration setups to align with the new standard.
- Improved overall code consistency and maintainability by ensuring uniform data type usage across various modules.

* [Refactor] Standardize data type usage in templates and scripts

- Updated data type definitions in various templates and scripts to use string representations (e.g., "float16", "int32") instead of T.float16 and T.int32 for improved consistency and clarity.
- Enhanced overall code maintainability by ensuring uniform data type usage across multiple modules, including convolution, elementwise operations, and matrix multiplication templates.
- This change aims to streamline type handling and improve compatibility with existing workflows.

* [Refactor] Standardize data type usage in examples and benchmarks

- Updated data type definitions in various example and benchmark scripts to use T.float16 and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in kernel functions and configuration setups to align with the new standard.
- Improved overall code consistency by ensuring uniform data type usage across multiple modules, including attention mechanisms, matrix multiplication, and GEMM examples.

* [Refactor] Import dtypes from language.v2 module

- Added import statement for dtypes from the language.v2 module to enhance type handling and maintain consistency across the codebase.
- This change aims to streamline data type management and improve overall code clarity.

* fix

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use string representations (e.g., "float16", "int8") instead of T.float16 and T.int8 for improved consistency and clarity.
- Adjusted dtype assignments in functions and configuration setups to align with the new standard, enhancing overall code maintainability.
- This change affects multiple modules, including benchmark and attention mechanisms, ensuring uniform data type usage throughout the codebase.

* [Refactor] Update data type handling for consistency and clarity

- Changed string representations of data types in the Hint class to use T.float32 and T.int32 for improved consistency.
- Added new data types "int4" and "int16" to the dtypes module, enhancing type support across the codebase.
- Updated function signatures and assertions in the lop3 and mxfp modules to utilize the new data types, ensuring uniformity in type handling.
- This refactor aims to streamline data type management and improve overall code clarity and maintainability.
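One way such dtype objects can stay drop-in compatible with the older string spellings is string equality on the object itself. A minimal sketch (an illustrative stand-in, not tilelang's actual dtypes module):

```python
# A minimal, illustrative dtype object. String equality is what lets code
# migrate from "float16" to T.float16 without breaking == comparisons or
# dict lookups keyed by the old string spellings.
class DType:
    def __init__(self, name: str, bits: int):
        self.name = name
        self.bits = bits

    def __str__(self):
        return self.name

    def __eq__(self, other):
        if isinstance(other, str):
            return self.name == other
        return isinstance(other, DType) and self.name == other.name

    def __hash__(self):
        # Hash like the plain string so DType and str interoperate as keys.
        return hash(self.name)

float16 = DType("float16", 16)
int4 = DType("int4", 4)
int16 = DType("int16", 16)
```

Under this design, checks like `in_dtype == T.float32` and `in_dtype == "float32"` behave identically, which is what allows the mixed spellings seen across this PR to coexist.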

* [Enhancement] Improve data type handling and error messaging

- Introduced a mapping for canonical data types to their display strings, enhancing clarity in type representation.
- Updated the dtype creation logic to utilize the new mapping, ensuring more intuitive handling of string inputs.
- Refined error messages in the lop3 module to provide clearer feedback on invalid source formats, improving debugging and user experience.
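The canonical-to-display mapping described above might look like the following hypothetical sketch (the real table lives in tilelang's dtypes module and is richer than this):

```python
# Hypothetical canonical-name -> display-string table. "float" and "int"
# are the legacy TVM-style spellings that should display with an explicit
# bit width; everything else passes through unchanged.
_CANONICAL_DISPLAY = {
    "float": "float32",
    "int": "int32",
}

def display_name(name: str) -> str:
    """Return the explicit-width display string for a dtype spelling."""
    if not isinstance(name, str):
        raise TypeError(f"dtype name must be str, got {type(name).__name__}")
    return _CANONICAL_DISPLAY.get(name, name)
```

Routing dtype creation through one table like this keeps `"float"` and `"float32"` from producing two distinct-looking types, which is the kind of ambiguity the refined error messages guard against.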

* [Fix] Correct boolean flag in GEMM SP test case

- Corrected a boolean flag passed in the test_gemm_sp_sm90 test case so the test exercises the intended configuration and matches the expected behavior of the GEMM SP implementation.

* [Refactor] Standardize data type usage across scripts

- Updated data type definitions in various scripts to use T.float16 and T.bfloat16 consistently, enhancing clarity and maintainability.
- Adjusted dtype assignments in function signatures and argument parsing to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change affects multiple modules, including benchmarks and examples, improving overall code consistency and readability.

* [Refactor] Standardize data type usage in various modules

- Updated data type assignments in multiple scripts to utilize T.float32, T.int8, and T.int32 consistently, enhancing clarity and maintainability.
- Adjusted function signatures and parameter types across benchmarks, examples, and tests to align with the new standard, ensuring uniform data type usage throughout the codebase.
- This change improves overall code consistency and readability, impacting modules related to matrix multiplication, GEMM, and tensor operations.

* [Refactor] Update argument parsing for data types in benchmarks

- Changed argument parsing for data types in benchmark_matmul_intrinsic.py and benchmark_matmul_sp.py to use string representations ("float16", "int8", "float") instead of T.float16 and T.float.
- This update enhances consistency in data type handling across benchmark scripts, improving clarity and maintainability.

* [Refactor] Update data type handling in benchmark and example scripts

- Changed data type arguments in benchmark and example scripts to use string representations ("float16") instead of T.float16 for improved consistency.
- Updated function signatures and argument parsing to align with the new standard, enhancing clarity and maintainability across the codebase.
- This change affects multiple modules related to attention mechanisms and tensor operations, ensuring uniform data type usage throughout the examples.

* [Refactor] Fix data type conversion in multiple scripts

- Corrected the data type conversion calls from `dtype..as_torch()` (a doubled-dot typo) to `dtype.as_torch()` across various benchmark and example scripts.
- This change enhances consistency in data type handling and improves code readability, impacting modules related to attention mechanisms and tensor operations.

* [Refactor] Update float8 data type usage across multiple scripts

- Changed instances of T.float8_e4m3 to T.float8_e4m3fn in various benchmark, example, and test scripts to ensure consistency in data type handling.
- This update enhances clarity and maintainability across the codebase, particularly in modules related to matrix multiplication and tensor operations.
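The rename matters because PyTorch spells the finite-only e4m3 type `torch.float8_e4m3fn` and has no plain `float8_e4m3` attribute, so string dtype names must use the `fn` spelling to round-trip through torch. A pure-string sketch of the normalization (the table itself is an assumption for illustration):

```python
# torch exposes torch.float8_e4m3fn (and float8_e4m3fnuz), but no plain
# torch.float8_e4m3 attribute -- hence the rename in this PR.
_LEGACY_FLOAT8 = {"float8_e4m3": "float8_e4m3fn"}

def torch_float8_name(name: str) -> str:
    """Map a legacy float8 spelling to its torch-compatible name."""
    return _LEGACY_FLOAT8.get(name, name)
```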

* [Refactor] Enhance float8 data type handling in CUDA code generation

- Updated the handling of float8 data types in the CUDA code generation to include additional float8 variants, improving type conversion logic.
- Adjusted conditions to ensure proper type checks for float8 conversions, enhancing clarity and maintainability in the codebase.
- Modified layout inference to streamline float8 type checks, ensuring consistency across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Streamline float8 data type handling in CUDA and related modules

- Enhanced float8 data type handling in CUDA code generation by refining type conversion logic and ensuring consistent type checks.
- Updated layout inference for float8 types to improve clarity and maintainability across the implementation.
- This change impacts modules related to matrix operations and CUDA code generation, improving overall type handling and conversion accuracy.

* [Refactor] Remove unnecessary cache disabling in float8 example script

- Eliminated the call to tilelang.disable_cache() in example_group_per_split_token_cast_to_fp8.py to streamline the code.
- This change enhances clarity and maintainability of the example script without affecting its functionality.

* [Refactor] Update data type usage in debug print tests

- Changed the argument for dtype in the test_debug_print_buffer function from a string representation to the corresponding T.bool type.
- This update enhances consistency in data type handling within the test suite, improving clarity and maintainability.

* lint fix

* Update function parameter types from `str` to `T.dtype` for improved type safety in attention sink and related examples

* Refactor `gemv_alloc_reducer` function signature for improved readability by formatting parameters across multiple lines.
parent 0c25c4f3
@@ -5,7 +5,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
@@ -10,7 +10,7 @@ import tilelang.language as T
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
},
)
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
warp_group_num = 2
threads = 128 * warp_group_num
......
@@ -5,7 +5,7 @@ import tilelang.language as T
# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
@tilelang.jit(out_idx=[2])
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor[(M, K), dtype],
......
@@ -2,6 +2,8 @@
import pytest
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
import torch
def matmul(
@@ -24,8 +26,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -74,13 +74,11 @@ def _compile_and_check(
profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
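The `(A.view(torch.int32) - 0x1000).view(torch.float32)` trick in the reference program can be reproduced scalar-by-scalar with only the standard library; a sketch of the same bit manipulation:

```python
import struct

def nudge_fp32(x: float) -> float:
    # Reinterpret the float32 bit pattern as int32, subtract 0x1000, and
    # reinterpret back. For a positive normal float this shifts the value
    # down by 0x1000 ulps -- a tiny relative perturbation of the low
    # mantissa bits, just like the torch .view() trick above.
    (bits,) = struct.unpack("<i", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<i", bits - 0x1000))
    return y
```

For example, `nudge_fp32(1.0)` lands just below 1.0, so the float32 reference no longer sits exactly on values that reduced-precision accumulation in the kernel would round anyway.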
@@ -148,8 +146,6 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -235,8 +231,6 @@ def matmul_sr(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -323,8 +317,6 @@ def matmul_rr(
A_frag_shape = A_shared_shape
B_frag_shape = B_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -399,9 +391,9 @@ FALSE_TRUE_CASES = (
[
pytest.param(
k,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
id=f"K{k}-float16-float16-float16",
)
for k in K_VALUES
@@ -409,9 +401,9 @@ FALSE_TRUE_CASES = (
+ [
pytest.param(
k,
"int8",
"int32",
"int32",
T.int8,
T.int32,
T.int32,
id="K32-int8-int32-int32",
)
for k in K_VALUES_8Bit
@@ -419,9 +411,9 @@ FALSE_TRUE_CASES = (
+ [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
T.float8_e5m2,
T.float32,
T.float32,
id="K32-float8_e5m2-float32-float32",
)
for k in K_VALUES_8Bit
@@ -429,9 +421,9 @@ FALSE_TRUE_CASES = (
+ [
pytest.param(
k,
"float8_e4m3",
"float32",
"float32",
T.float8_e4m3fn,
T.float32,
T.float32,
id="K32-float8_e4m3-float32-float32",
)
for k in K_VALUES_8Bit
@@ -452,15 +444,15 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rs_true_false(m, n, k):
run_gemm_rs(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rs_true_true(m, n, k):
run_gemm_rs(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_rs(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@@ -468,15 +460,15 @@ def run_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_sr_false_false(m, n, k):
run_gemm_sr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_true_false(m, n, k):
run_gemm_sr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_sr_true_true(m, n, k):
run_gemm_sr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_sr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@@ -484,15 +476,15 @@ def run_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rr_false_false(m, n, k):
run_gemm_rr(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_true_false(m, n, k):
run_gemm_rr(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k)
def run_gemm_rr_true_true(m, n, k):
run_gemm_rr(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k)
run_gemm_rr(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k)
TRANS_CASES = [
@@ -548,9 +540,9 @@ def test_gemm_false_false(m, n, k):
k * 3,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
@@ -567,9 +559,9 @@ def test_gemm_true_false(m, n, k):
k * 3,
True,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
@@ -586,9 +578,9 @@ def test_gemm_true_true(m, n, k):
k * 3,
True,
True,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
@@ -607,7 +599,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_false_false(m, n, k)
@@ -615,7 +607,7 @@ def test_gemm_rs_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_true_false(m, n, k)
@@ -623,7 +615,7 @@ def test_gemm_rs_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_true_true(m, n, k)
@@ -639,7 +631,7 @@ def test_gemm_sr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_false_false(m, n, k)
@@ -647,7 +639,7 @@ def test_gemm_sr_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_true_false(m, n, k)
@@ -655,7 +647,7 @@ def test_gemm_sr_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_sr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_sr_true_true(m, n, k)
@@ -671,7 +663,7 @@ def test_gemm_rr_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_false_false(m, n, k)
@@ -679,7 +671,7 @@ def test_gemm_rr_false_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_true_false(m, n, k)
@@ -687,7 +679,7 @@ def test_gemm_rr_true_false(m, n, k):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rr_true_true(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rr_true_true(m, n, k)
@@ -699,7 +691,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
@@ -707,7 +699,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
@@ -715,7 +707,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True False =============================")
# run_gemm(m, n, k * 3, True, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, True, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
@@ -724,7 +716,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} True True =============================")
# run_gemm(m, n, k * 3, True, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, True, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m}, {n} {k} Pass")
# print(f"Test {n} Pass")
@@ -733,15 +725,15 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm_rs(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm_rs(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm_rs(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# run_gemm_rs(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# run_gemm(64, n, k, False, False, "float16", "float16", "float16", 64, n, k, 0, 256)
# run_gemm(64, n, k, False, False, T.float16, T.float16, T.float16, 64, n, k, 0, 256)
# print(f"Test {64} {n} {k} Pass")
@@ -2,6 +2,7 @@
import pytest
from tilelang import tvm as tvm
import tilelang.testing
from tilelang import language as T
def matmul(
@@ -24,8 +25,6 @@ def matmul(
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -81,7 +80,7 @@ def _compile_and_check(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
@@ -147,8 +146,6 @@ def matmul_rs(
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
A_frag_shape = A_shared_shape
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
@@ -217,18 +214,18 @@ K_VALUES = [16, 32, 64]
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
id=f"K{k}-float16-float16-float16",
)
for k in K_VALUES
] + [
pytest.param(
k,
"float16",
"float16",
"float32",
T.float16,
T.float16,
T.float32,
id=f"K{k}-float16-float16-float32",
)
for k in K_VALUES
@@ -248,7 +245,7 @@ def run_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
def run_gemm_rs_false_false(m, n, k):
run_gemm_rs(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
run_gemm_rs(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
TRANS_CASES = [
@@ -306,9 +303,9 @@ def test_gemm_false_false(m, n, k):
k * 3,
False,
False,
"float16",
"float16",
"float16",
T.float16,
T.float16,
T.float16,
m,
n,
k,
@@ -329,7 +326,7 @@ def test_gemm_rs_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k", K_VALUES, ids=lambda v: f"K{v}")
def test_gemm_rs_false_false(m, n, k):
_ensure_torch_dtypes("float16")
_ensure_torch_dtypes(T.float16)
run_gemm_rs_false_false(m, n, k)
@@ -341,7 +338,7 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
@@ -349,5 +346,5 @@ if __name__ == "__main__":
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64]:
# print(f"======================= Test {m} {n} {k} False False =============================")
# run_gemm(m, n, k * 3, False, False, "float16", "float16", "float16", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, False, T.float16, T.float16, T.float16, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
@@ -80,7 +80,7 @@ def _compile_and_check(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
B = (B.view(torch.int32) - 0x1000).view(torch.float32)
C = torch.matmul(A.to(torch.float), B.to(torch.float))
@@ -134,18 +134,18 @@ K_VALUES_8Bit = [32, 64, 128]
FALSE_TRUE_CASES = [
pytest.param(
k,
"float16",
"float32",
"float32",
T.float16,
T.float32,
T.float32,
id=f"K{k}-float16-float-float",
)
for k in K_VALUES
] + [
pytest.param(
k,
"float8_e5m2",
"float32",
"float32",
T.float8_e5m2,
T.float32,
T.float32,
id="K32-float8_e5m2-float32-float32",
)
for k in K_VALUES_8Bit
@@ -195,7 +195,7 @@ if __name__ == "__main__":
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
@@ -205,7 +205,7 @@ if __name__ == "__main__":
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 256)
# run_gemm(m, n, k * 3, False, True, T.float16, T.float, T.float, m, n, k, 2, 256)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
@@ -215,4 +215,4 @@ if __name__ == "__main__":
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# run_gemm(m, n, k * 3, False, True, T.float8_e5m2, T.float, T.float, m, n, k, 2, 128)
@@ -13,7 +13,7 @@ use_v2 = args.use_v2
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
......
@@ -13,7 +13,7 @@ use_v2 = args.use_v2
# target currently can be "cuda" or "hip" or "cpu".
# if not specified, it will be inferred from the input tensors during compile time
@tilelang.jit
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def matmul_relu_kernel(
A: T.Tensor((M, K), dtype),
......
@@ -38,8 +38,8 @@ def flashattn(batch, heads, seq_q, seq_kv, dim, is_causal, block_M=64, block_N=6
scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e)
q_shape = [batch, heads, seq_q, dim]
kv_shape = [batch, heads, seq_kv, dim]
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
past_len = seq_kv - seq_q
assert past_len >= 0, "seq_kv must be greater than or equal to seq_q"
......
@@ -3,7 +3,7 @@ import tilelang.language as T
import torch
def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype="float16", accum_dtype="float"):
def make_matmul_prim(M, N, K, block_M=128, block_N=128, block_K=32, dtype=T.float16, accum_dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((M, K), dtype),
......
@@ -186,8 +186,8 @@ def make_tilelang_unary_kernel(M: int, N: int, op_id: int, use_fastmath: bool =
@T.prim_func
def tilelang_unary_kernel(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), T.float32),
B: T.Tensor((M, N), T.float32),
):
with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
@@ -224,9 +224,9 @@ def make_tilelang_binary_kernel(M: int, N: int):
@T.prim_func
def tilelang_binary_kernel(
A: T.Tensor((M, N), "float32"),
B: T.Tensor((M, N), "float32"),
C: T.Tensor((M, N), "float32"),
A: T.Tensor((M, N), T.float32),
B: T.Tensor((M, N), T.float32),
C: T.Tensor((M, N), T.float32),
):
with T.Kernel(T.ceildiv(N, TILELANG_BLOCK_N), T.ceildiv(M, TILELANG_BLOCK_M), threads=TILELANG_THREADS) as (bx, by):
for i, j in T.Parallel(TILELANG_BLOCK_M, TILELANG_BLOCK_N):
......
@@ -30,8 +30,8 @@ def run(M, N, K):
thread_num=None,
enable_rasteration=None,
):
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
......
@@ -1124,8 +1124,10 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
}
// Handle conversion from float32 to float8 (E4M3/E5M2)
if (from_ty.is_float() &&
(target_ty.is_float8_e4m3() || target_ty.is_float8_e5m2())) {
if (from_ty.is_float() && (target_ty.is_float8())) {
bool target_type_is_e4m3 = target_ty.is_float8_e4m3() ||
target_ty.is_float8_e4m3fn() ||
target_ty.is_float8_e4m3fnuz();
// FP32 -> FP8: Use __nv_cvt_float2_to_fp8x2 for vectorized conversion
// (float2 -> fp8x2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
@@ -1134,8 +1136,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "*reinterpret_cast<__nv_fp8x2_storage_t*>(&(" << sret
<< ")) = __nv_cvt_float2_to_fp8x2(*reinterpret_cast<float2*>(&("
<< src << ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
@@ -1144,14 +1145,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
} else if (from_ty.lanes() == 8 && target_ty.lanes() == 8) {
@@ -1160,33 +1159,31 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[0] = "
<< "__nv_cvt_float2_to_fp8x2(*(float2*)(&(" << src
<< ")), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[1] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+1), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[2] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+2), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
PrintIndent();
stream << "((__nv_fp8x2_storage_t*)(&" << sret << "))[3] = "
<< "__nv_cvt_float2_to_fp8x2(*((float2*)(&(" << src
<< "))+3), __NV_SATFINITE, "
<< (target_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
<< (target_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2") << ");\n";
os << sret;
return;
}
}
if ((from_ty.is_float8_e4m3() || from_ty.is_float8_e5m2()) &&
target_ty.is_float()) {
if (from_ty.is_float8() && target_ty.is_float()) {
bool from_type_is_e4m3 = from_ty.is_float8_e4m3() ||
from_ty.is_float8_e4m3fn() ||
from_ty.is_float8_e4m3fnuz();
// FP8 -> FP32: Use __tl_cvt_fp8x2_to_float2 for vectorized conversion
// (fp8x2 -> float2)
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
......@@ -1196,8 +1193,7 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
<< ")) = "
"__tl_cvt_fp8x2_to_float2(*reinterpret_cast<__nv_fp8x2_storage_"
"t*>(&("
<< src << ")), "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< src << ")), " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
......@@ -1206,14 +1202,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
......@@ -1222,26 +1216,22 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
PrintIndent();
stream << "*(float2*)(&" << sret << ") = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[0], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[0], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+1) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[1], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[1], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+2) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[2], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[2], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
PrintIndent();
stream << "*((float2*)(&" << sret << ")+3) = "
<< "__tl_cvt_fp8x2_to_float2(((__nv_fp8x2_storage_t*)(&" << src
<< "))[3], "
<< (from_ty.is_float8_e4m3() ? "__NV_E4M3" : "__NV_E5M2")
<< "))[3], " << (from_type_is_e4m3 ? "__NV_E4M3" : "__NV_E5M2")
<< ");\n";
os << sret;
return;
......
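For reference, the scalar decode that `__tl_cvt_fp8x2_to_float2` performs pairwise can be sketched in plain Python by unpacking the sign, exponent, and mantissa fields. This is a simplified model that ignores NaN encodings; `decode_fp8` is a hypothetical helper, not an API from this repository:

```python
def decode_fp8(byte: int, e4m3: bool = True) -> float:
    # Split the 8-bit pattern into sign, exponent, and mantissa fields.
    sign = -1.0 if byte & 0x80 else 1.0
    if e4m3:   # 4 exponent bits, 3 mantissa bits, bias 7
        exp, man, bias, mbits = (byte >> 3) & 0xF, byte & 0x7, 7, 3
    else:      # e5m2: 5 exponent bits, 2 mantissa bits, bias 15
        exp, man, bias, mbits = (byte >> 2) & 0x1F, byte & 0x3, 15, 2
    if exp == 0:  # subnormal: no implicit leading one
        return sign * man * 2.0 ** (1 - bias - mbits)
    return sign * (1.0 + man / 2.0 ** mbits) * 2.0 ** (exp - bias)

# Pairwise use, mirroring the fp8x2 -> float2 unrolling in the codegen above.
pair = [decode_fp8(0x38), decode_fp8(0x40)]  # 1.0 and 2.0 in e4m3
```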
......@@ -1179,10 +1179,10 @@ private:
// Check if this is a non-reducer store with Cast operation
DataType src_type = cast->value.dtype();
DataType dst_type = cast->dtype;
bool src_ok =
src_type.is_float() || src_type.is_bfloat() || src_type.is_float8();
bool dst_ok =
dst_type.is_float() || dst_type.is_bfloat() || dst_type.is_float8();
if (src_ok && dst_ok && TargetIsCuda(Target::Current())) {
has_cast_operations = true;
}
......
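The consolidated `is_float8()` check above replaces the explicit e4m3/e5m2 pair, which did not cover the `e4m3fn`/`e4m3fnuz` variants. A hedged Python analogue of the `src_ok`/`dst_ok` predicate (the type-name strings are assumptions for illustration, not the TVM `DataType` API):

```python
FLOAT8_VARIANTS = {"float8_e4m3", "float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2"}
FLOAT_FAMILY = {"float16", "float32", "float64", "bfloat16"} | FLOAT8_VARIANTS

def cast_is_float_to_float(src: str, dst: str) -> bool:
    # Mirrors src_ok/dst_ok: both sides must be float, bfloat, or any FP8 variant.
    return src in FLOAT_FAMILY and dst in FLOAT_FAMILY
```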
......@@ -26,7 +26,7 @@ def tl_matmul(
):
micro_size_x = micro_size_y = micro_size_k = 16
if in_dtype in {"float8_e4m3fnuz", "int8"}:
if in_dtype in {T.float8_e4m3fnuz, T.int8}:
micro_size_k = 32
block_row_warps = 2
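The doubled `micro_size_k` for 8-bit inputs reflects that an MMA K-slice holds twice as many 8-bit elements as 16-bit ones. A one-line sketch of that rule; `micro_tile_k` is illustrative, not a function from the repository:

```python
def micro_tile_k(in_dtype: str, base: int = 16) -> int:
    # Eight-bit element types pack two elements where a 16-bit type packs one,
    # so the K dimension of the MMA micro tile doubles.
    eight_bit = {"float8_e4m3fnuz", "float8_e5m2", "int8", "uint8"}
    return base * 2 if in_dtype in eight_bit else base
```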
......@@ -160,7 +160,7 @@ def tl_matmul(
return main
def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype=T.float32, a_transposed=False, b_transposed=True, k_pack=1):
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack)
print(matmul)
kernel = tilelang.compile(matmul)
......@@ -169,10 +169,10 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
assert src_code is not None
A_shape = (K, M) if a_transposed else (M, K)
B_shape = (N, K) if b_transposed else (K, N)
if in_dtype == "int8":
if in_dtype == T.int8:
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
elif in_dtype == T.float8_e4m3fnuz:
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
......@@ -211,15 +211,15 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack",
[
(128, 128, 128, "float16", "float16", "float32", False, True, 1),
(128, 256, 256, "float16", "float32", "float32", False, True, 1),
(128, 256, 256, "float16", "float32", "float32", False, True, 2),
(128, 128, 128, "int8", "int32", "int32", False, True, 1),
(128, 256, 256, "int8", "int32", "int32", False, True, 1),
(128, 256, 256, "int8", "int32", "int32", False, True, 2),
(128, 256, 256, "int8", "int32", "int32", False, False, 1),
(128, 256, 256, "int8", "int32", "int32", False, False, 2),
(128, 128, 128, "float8_e4m3fnuz", "float16", "float32", False, True, 1),
(128, 128, 128, T.float16, T.float16, T.float32, False, True, 1),
(128, 256, 256, T.float16, T.float32, T.float32, False, True, 1),
(128, 256, 256, T.float16, T.float32, T.float32, False, True, 2),
(128, 128, 128, T.int8, T.int32, T.int32, False, True, 1),
(128, 256, 256, T.int8, T.int32, T.int32, False, True, 1),
(128, 256, 256, T.int8, T.int32, T.int32, False, True, 2),
(128, 256, 256, T.int8, T.int32, T.int32, False, False, 1),
(128, 256, 256, T.int8, T.int32, T.int32, False, False, 2),
(128, 128, 128, T.float8_e4m3fnuz, T.float16, T.float32, False, True, 1),
],
)
@tilelang.testing.requires_rocm
......@@ -235,10 +235,10 @@ def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transpose
b_transposed=b_transposed,
k_pack=k_pack,
)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False, k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32)
assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False)
assert_tl_matmul_correctness(128, 256, 256, T.float8_e4m3fnuz, T.float32, b_transposed=False, k_pack=2)
if __name__ == "__main__":
......
......@@ -26,7 +26,7 @@ def tl_matmul(
):
micro_size_x = micro_size_y = micro_size_k = 16
if in_dtype in {"float8_e4m3fnuz", "int8"}:
if in_dtype in {T.float8_e4m3fnuz, T.int8}:
micro_size_k = 32
block_row_warps = 2
......@@ -196,7 +196,7 @@ def assert_tl_matmul_correctness(
K,
in_dtype,
out_dtype,
accum_dtype="float32",
accum_dtype=T.float32,
a_transposed=False,
b_transposed=True,
k_pack=1,
......@@ -211,10 +211,10 @@ def assert_tl_matmul_correctness(
assert src_code is not None
A_shape = (K, M) if a_transposed else (M, K)
B_shape = (N, K) if b_transposed else (K, N)
if in_dtype == "int8":
if in_dtype == T.int8:
A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
elif in_dtype == "float8_e4m3fnuz":
elif in_dtype == T.float8_e4m3fnuz:
A = torch.rand(A_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
B = torch.rand(B_shape, device="cuda", dtype=torch.float16).to(getattr(torch, in_dtype))
else:
......@@ -261,14 +261,14 @@ def assert_tl_matmul_correctness(
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load",
[
(256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False),
(256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False),
(256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False),
(256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False),
(256, 256, 512, T.int8, T.int32, T.int32, False, True, 1, True, False),
(256, 256, 512, T.int8, T.int32, T.int32, False, False, 1, True, False),
(256, 256, 512, T.int8, T.int32, T.int32, False, True, 2, True, False),
(256, 256, 512, T.int8, T.int32, T.int32, False, False, 2, True, False),
(256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 1, True, False),
(256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 1, True, False),
(256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, True, 2, True, False),
(256, 256, 512, T.float8_e4m3fnuz, T.float32, T.float32, False, False, 2, True, False),
],
)
@tilelang.testing.requires_rocm
......
......@@ -108,7 +108,7 @@ def run_gemm(
)
@tilelang.testing.requires_rocm
def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "float16", "float32", "float32", 128, 128, 32, k_pack=k_pack)
run_gemm(1024, 1024, 1024, trans_A, trans_B, T.float16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack)
@pytest.mark.parametrize(
......@@ -123,7 +123,7 @@ def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
)
@tilelang.testing.requires_rocm
def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=k_pack)
run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.float32, T.float32, 128, 128, 32, k_pack=k_pack)
@pytest.mark.parametrize(
......@@ -138,7 +138,7 @@ def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack):
)
@tilelang.testing.requires_rocm
def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=k_pack)
run_gemm(1024, 1024, 1024, trans_A, trans_B, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32, k_pack=k_pack)
def matmul_rs(
......@@ -241,24 +241,24 @@ def run_gemm_rs(
# @tilelang.testing.requires_rocm
# def test_gemm_rs_f16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, False, T.float16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, T.float16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, T.float16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, T.float16, T.float32, T.float32, 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16f32f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, T.bfloat16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.float32, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.float32, T.float32, 128, 128, 32)
# @tilelang.testing.requires_rocm
# def test_gemm_rs_bf16bf16f32_nt():
# run_gemm_rs(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, False, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, True, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32)
# run_gemm_rs(1024, 1024, 1024, True, False, T.bfloat16, T.bfloat16, T.float32, 128, 128, 32)
if __name__ == "__main__":
tilelang.testing.main()
......@@ -5,7 +5,7 @@ import pytest
@tilelang.jit
def simple_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -26,7 +26,7 @@ def simple_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", n
@tilelang.jit
def nested_invalid_loop(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -48,7 +48,7 @@ def nested_invalid_loop(dtype: str = "bfloat16", accum_dtype: str = "float32", n
@tilelang.jit
def invalid_loop_with_complex_dataflow(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -69,7 +69,7 @@ def invalid_loop_with_complex_dataflow(dtype: str = "bfloat16", accum_dtype: str
@tilelang.jit
def valid_loop_not_use_loop_var(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -91,7 +91,7 @@ def valid_loop_not_use_loop_var(dtype: str = "bfloat16", accum_dtype: str = "flo
@tilelang.jit
def valid_loop_not_frag(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......@@ -112,7 +112,7 @@ def valid_loop_not_frag(dtype: str = "bfloat16", accum_dtype: str = "float32", n
@tilelang.jit
def valid_loop_serial(dtype: T.dtype = T.bfloat16, accum_dtype: T.dtype = T.float32, num_threads: int = 128):
A = T.dynamic("A")
@T.prim_func
......
......@@ -29,7 +29,7 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_parallels(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -44,7 +44,7 @@ def nested_continuous_parallels(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -60,7 +60,7 @@ def nested_triple_continuous_parallels(length=256, block1=8, block2=2, dtype="fl
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_parallels(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -149,9 +149,9 @@ def run_gemm_nested_pipelines(
block_K = 32
trans_A = False
trans_B = False
in_dtype = "float16"
out_dtype = "float16"
dtypeAccum = "float32"
in_dtype = T.float16
out_dtype = T.float16
dtypeAccum = T.float32
num_threads = 128
program = matmul_nested_pipelines(
M,
......@@ -188,7 +188,7 @@ def run_gemm_nested_pipelines(
A = A.T
if trans_B:
B = B.T
if in_dtype == "float32":
if in_dtype == T.float32:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
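The reference-side trick above forces float32 operands onto the tf32 grid before comparison. An equivalent stdlib-only sketch, assuming the usual tf32 layout of 10 mantissa bits (so the low 13 of float32's 23 mantissa bits are dropped); `to_tf32` is a hypothetical helper, and it truncates by masking, one common variant, whereas the snippet above uses a subtractive adjustment:

```python
import struct

def to_tf32(x: float) -> float:
    # Reinterpret the float32 bit pattern as a 32-bit integer.
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    # tf32 keeps 10 of float32's 23 mantissa bits: zero the low 13 bits.
    return struct.unpack("<f", struct.pack("<I", bits & ~0x1FFF))[0]
```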
......@@ -215,7 +215,7 @@ is OK.
@tilelang.jit(out_idx=[1])
def nested_continuous_serials(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -230,7 +230,7 @@ def nested_continuous_serials(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_noncontinuous_serials(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -272,7 +272,7 @@ Rule:
@tilelang.jit(out_idx=[1])
def nested_continuous_sp(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -287,7 +287,7 @@ def nested_continuous_sp(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_ps(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -302,7 +302,7 @@ def nested_continuous_ps(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_psp(length=256, block1=8, block2=2, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -318,7 +318,7 @@ def nested_continuous_psp(length=256, block1=8, block2=2, dtype="float32"):
@tilelang.jit(out_idx=[1])
def nested_continuous_sps(length=256, block1=8, block2=2, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -469,9 +469,9 @@ def run_gemm_mixed_pp(
block_M = 128
block_N = 128
block_K = 32
in_dtype = "float16"
out_dtype = "float16"
dtypeAccum = "float32"
in_dtype = T.float16
out_dtype = T.float16
dtypeAccum = T.float32
num_threads = 128
program = matmul_nested_pipa(
......@@ -502,7 +502,7 @@ def run_gemm_mixed_pp(
def ref_program(A, B):
import torch
if in_dtype == "float32":
if in_dtype == T.float32:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
......@@ -603,9 +603,9 @@ def run_gemm_tiled_op_with_parallel(
block_M = 128
block_N = 128
block_K = 32
in_dtype = "float16"
out_dtype = "float16"
dtypeAccum = "float32"
in_dtype = T.float16
out_dtype = T.float16
dtypeAccum = T.float32
num_threads = 128
program = matmul_nested_pipa(
......@@ -636,7 +636,7 @@ def run_gemm_tiled_op_with_parallel(
def ref_program(A, B):
import torch
if in_dtype == "float32":
if in_dtype == T.float32:
# Convert float32 to tfloat32 because tfloat32 mma cannot truncate
# float32 automatically, -0x1000 meas
A = (A.view(torch.int32) - 0x1000).view(torch.float32)
......@@ -673,7 +673,7 @@ def run_gemm_tiled_op_with_parallel(
@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......@@ -688,7 +688,7 @@ def tir_op_with_parallel(length=256, block=16, dtype="float32"):
@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype=T.float32):
@T.prim_func
def main(
A: T.Tensor((length,), dtype),
......
......@@ -57,9 +57,9 @@ def get_configs(M, N, K, with_roller=False):
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
in_dtype=T.float16,
out_dtype=T.float16,
accum_dtype=T.float16,
).with_arch(arch)
func = carve_template.equivalent_function()
......@@ -187,8 +187,8 @@ def matmul(M, N, K, with_roller):
"""
# Use half-precision for input data to reduce memory bandwidth,
# accumulate in float for better numerical accuracy
dtype = "float16"
accum_dtype = "float"
dtype = T.float16
accum_dtype = T.float32
@T.prim_func
def main(
......
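The comment above about accumulating in float for numerical accuracy can be demonstrated without a GPU: emulating an fp16 accumulator via `struct`'s half-precision format shows the running sum stalling once the addend drops below the accumulator's ulp. A stdlib-only sketch under that assumption (illustrative, not part of the benchmark):

```python
import struct

def to_fp16(x: float) -> float:
    # Round to IEEE binary16 and back, emulating an fp16 register.
    return struct.unpack("<e", struct.pack("<e", x))[0]

acc16, acc32 = 0.0, 0.0
for _ in range(4096):
    acc16 = to_fp16(acc16 + 1.0)  # fp16 accumulator
    acc32 += 1.0                  # float accumulator
# fp16 integers above 2048 have spacing 2, so 2048 + 1 rounds back to 2048
# and the fp16 sum stalls, while the float accumulator reaches 4096.
```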