"example/vscode:/vscode.git/clone" did not exist on "87fd11526f32b09baa1c72735d1ee4c4fb412dc0"
Unverified Commit 0b6336b5 authored by Kuris, committed by GitHub

[Refactor] Use `pytest.mark.parametrize` to speed up parallel testing (#1447)



* Refactor GEMM tests to use parameterized pytest fixtures

- Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability.
- Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests.
- This change reduces code duplication and enhances the clarity of test configurations; the sketch below illustrates the pattern.
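
For reference, a minimal sketch of the pattern (hypothetical helper and test names, not code from this repository): each tuple passed to `pytest.mark.parametrize` is collected as an independent test node, so one failing configuration no longer aborts the rest, and pytest-xdist (`pytest -n auto`) can distribute the cases across workers.

```python
# Minimal, hypothetical sketch of the refactoring pattern used in this PR;
# `check_config` stands in for helpers like `run_gemm` / `assert_tl_matmul_correctness`.
import pytest


def check_config(m, n, k, trans_a, trans_b, k_pack=1):
    # Placeholder for the real kernel-correctness check.
    assert m > 0 and n > 0 and k > 0 and k_pack in (1, 2)


# Before: one test body enumerating every configuration in sequence.
# After: each tuple becomes its own collected node, e.g. test_gemm[False-True-1],
# which `pytest -n auto` (pytest-xdist) can schedule on separate workers.
@pytest.mark.parametrize(
    "trans_a, trans_b, k_pack",
    [
        (False, True, 1),
        (False, True, 2),
        (True, False, 1),
    ],
)
def test_gemm(trans_a, trans_b, k_pack):
    check_config(1024, 1024, 1024, trans_a, trans_b, k_pack=k_pack)
```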

* Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent dda45126
import pytest
import torch
import tilelang.testing
from tilelang import tvm as tvm

@@ -207,17 +208,33 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack",
    [
        (128, 128, 128, "float16", "float16", "float32", False, True, 1),
        (128, 256, 256, "float16", "float32", "float32", False, True, 1),
        (128, 256, 256, "float16", "float32", "float32", False, True, 2),
        (128, 128, 128, "int8", "int32", "int32", False, True, 1),
        (128, 256, 256, "int8", "int32", "int32", False, True, 1),
        (128, 256, 256, "int8", "int32", "int32", False, True, 2),
        (128, 256, 256, "int8", "int32", "int32", False, False, 1),
        (128, 256, 256, "int8", "int32", "int32", False, False, 2),
        (128, 128, 128, "float8_e4m3fnuz", "float16", "float32", False, True, 1),
    ],
)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
    assert_tl_matmul_correctness(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        k_pack=k_pack,
    )
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
    assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
...
import pytest
import torch
import tilelang.testing
from tilelang import tvm as tvm

@@ -257,19 +258,46 @@ def assert_tl_matmul_correctness(
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load",
    [
        (256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False),
        (256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False),
        (256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False),
    ],
)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed,
    b_transposed,
    k_pack,
    b_preshuffle,
    b_g2l_load,
):
    assert_tl_matmul_correctness(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
        b_g2l_load=b_g2l_load,
    )


if __name__ == "__main__":
...
import pytest
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T

@@ -95,31 +96,49 @@ def run_gemm(
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "trans_A, trans_B, k_pack",
    [
        (False, False, 1),
        (False, True, 1),
        (True, True, 1),
        (True, False, 1),
        (False, True, 2),
    ],
)
@tilelang.testing.requires_rocm
def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
    run_gemm(1024, 1024, 1024, trans_A, trans_B, "float16", "float32", "float32", 128, 128, 32, k_pack=k_pack)


@pytest.mark.parametrize(
    "trans_A, trans_B, k_pack",
    [
        (False, False, 1),
        (False, True, 1),
        (True, True, 1),
        (True, False, 1),
        (False, True, 2),
    ],
)
@tilelang.testing.requires_rocm
def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack):
    run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=k_pack)


@pytest.mark.parametrize(
    "trans_A, trans_B, k_pack",
    [
        (False, False, 1),
        (False, True, 1),
        (True, True, 1),
        (True, False, 1),
        (False, True, 2),
    ],
)
@tilelang.testing.requires_rocm
def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack):
    run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=k_pack)


def matmul_rs(
...
import pytest
import tilelang
import tilelang.language as T
import torch

@@ -242,13 +243,9 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
    print(f"✓ {mathop_name} numerical test passed")


@pytest.mark.parametrize(
    "name, func",
    [
        ("exp", T.exp),
        ("exp2", T.exp2),
        ("exp10", T.exp10),
@@ -270,24 +267,26 @@ def test_mathops_generate_no_fastmath():
        ("trunc", T.trunc),
        ("round", T.round),
        ("nearbyint", T.nearbyint),
    ],
)
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath(name, func):
    """Test that our tl.* mathops generate non-fastmath CUDA code (e.g. expf rather than __expf)"""
    run_single_arg_mathop_test(name, func, dtype="float32")
    print(f"✓ {name} test passed")


@pytest.mark.parametrize(
    "name, func",
    [
        ("pow", T.pow),
        ("fmod", T.fmod),
    ],
)
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath(name, func):
    """Test all two-argument mathops"""
    run_two_arg_mathop_test(name, func, dtype="float32")


@tilelang.testing.requires_cuda
@@ -296,11 +295,9 @@ def test_abs_maps_to_fabs():
    run_abs_test()


@pytest.mark.parametrize(
    "name, func",
    [
        ("__exp", T.__exp),
        ("__exp10", T.__exp10),
        ("__log", T.__log),
@@ -309,11 +306,13 @@ def test_fastmath_versions():
        ("__tan", T.__tan),
        ("__cos", T.__cos),
        ("__sin", T.__sin),
    ],
)
@tilelang.testing.requires_cuda
def test_fastmath_versions(name, func):
    """Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
    run_fastmath_mathop_test(name, func, dtype="float32")
    print(f"✓ {name} test passed")


if __name__ == "__main__":
...
import pytest
import torch
import tilelang.testing
import tilelang.language as T

@@ -77,38 +78,29 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
    assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


@pytest.mark.parametrize(
    "src_dtype, dst_dtype, check_str, lanes",
    [
        ("float32", "float16", "__float22half2_rn", 2),
        ("float32", "float16", "__float22half2_rn", 4),
        ("float16", "float32", "__half22float2", 2),
        ("float16", "float32", "__half22float2", 4),
        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2),
        ("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4),
        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2),
        ("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4),
        ("float32", "bfloat16", "__float22bfloat162_rn", 2),
        ("float32", "bfloat16", "__float22bfloat162_rn", 4),
        ("bfloat16", "float32", "__bfloat1622float2", 2),
        ("bfloat16", "float32", "__bfloat1622float2", 4),
        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2),
        ("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4),
        ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2),
        ("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4),
    ],
)
def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
    run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)


if __name__ == "__main__":
...
@@ -109,30 +109,27 @@ def run_gemm_ss(


@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved")
@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
        (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
        (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
    ],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_rs(
@@ -247,30 +244,27 @@ def run_gemm_rs(


@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved")
@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
    ],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_sr(
@@ -384,31 +378,27 @@ def run_gemm_sr(
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
        (128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
    ],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_rr(
@@ -526,31 +516,29 @@ def run_gemm_rr(
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
        (128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128),
        (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
        (128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
    ],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


if __name__ == "__main__":
...
import pytest
import torch
import tilelang
import tilelang.testing

@@ -303,50 +304,53 @@ def run_gemm_sp_sm80(


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True),
        (512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
    ],
)
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
@pytest.mark.parametrize(
    "M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
    [
        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
        (512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False),
        (512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
        (512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
    ],
)
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
    run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)


if __name__ == "__main__":
...
import pytest
from tilelang import tvm as tvm
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.utils.tensor import torch_assert_close, map_torch_type

@@ -153,33 +154,24 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
    return A, B


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
        (128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
    ],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_rs(
@@ -313,30 +305,23 @@ def run_gemm_rs(
    print("pass")


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
    ],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_sr(
@@ -470,30 +455,23 @@ def run_gemm_sr(
    print("pass")


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
    ],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


def matmul_rr(
@@ -631,31 +609,25 @@ def run_gemm_rr(
    print("pass")


@pytest.mark.parametrize(
    "M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
    [
        (512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
        (512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
        (128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128),
        (128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128),
        (128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
        (128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
    ],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
    run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)


if __name__ == "__main__":
...