Unverified commit 0b6336b5 authored by Kuris, committed by GitHub

[Refactor] Use `pytest.mark.parametrize` to speed up parallel testing (#1447)



* Refactor GEMM tests to use `pytest.mark.parametrize`

- Converted multiple test cases for GEMM operations in `test_tilelang_tilelibrary_gemm_sp.py` to use `pytest.mark.parametrize` for better maintainability and readability.
- Similar refactoring applied to `test_tilelang_tilelibrary_gemm_sp_v2.py`, consolidating test cases for `run_gemm_ss`, `run_gemm_rs`, `run_gemm_sr`, and `run_gemm_rr` into parameterized tests.
- This change reduces code duplication and clarifies the test configurations; a minimal sketch of the pattern is shown just before the diff below.

* Update testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

---------
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
parent dda45126
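For context, the conversion applied across these files follows the pattern sketched below. This is a schematic example rather than code from the PR: `run_gemm` here is an illustrative stand-in for the various correctness helpers (`run_gemm_ss`, `assert_tl_matmul_correctness`, and so on), and its signature is simplified. The point is that each hard-coded call inside a monolithic test body becomes one tuple in a `pytest.mark.parametrize` case list, so pytest collects every configuration as its own test item.

```python
import pytest


def run_gemm(M, N, K, trans_A, trans_B, k_pack=1):
    """Illustrative stand-in for the real GEMM correctness helpers touched by this PR."""
    # The real helpers build a TileLang kernel and compare against a torch reference.
    assert M > 0 and N > 0 and K > 0


# Before: one function ran every configuration serially, so pytest saw a single
# long-running test item and a failure in one case masked the remaining cases.
#
#   def test_gemm_f16f32f32_nt():
#       run_gemm(1024, 1024, 1024, False, True, k_pack=1)
#       run_gemm(1024, 1024, 1024, False, True, k_pack=2)

# After: each tuple is collected as its own item, e.g.
# test_gemm_f16f32f32_nt[False-True-1] and test_gemm_f16f32f32_nt[False-True-2].
@pytest.mark.parametrize(
    "trans_A, trans_B, k_pack",
    [
        (False, True, 1),
        (False, True, 2),
    ],
)
def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
    run_gemm(1024, 1024, 1024, trans_A, trans_B, k_pack=k_pack)
```

With the cases split out like this, a parallel run such as `pytest -n auto testing/python` (assuming pytest-xdist is installed) can distribute individual parameter combinations across workers instead of serializing them inside one function, which is where the speedup referred to in the title comes from.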
import pytest
import torch
import tilelang.testing
from tilelang import tvm as tvm
......@@ -207,17 +208,33 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype="floa
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack",
[
(128, 128, 128, "float16", "float16", "float32", False, True, 1),
(128, 256, 256, "float16", "float32", "float32", False, True, 1),
(128, 256, 256, "float16", "float32", "float32", False, True, 2),
(128, 128, 128, "int8", "int32", "int32", False, True, 1),
(128, 256, 256, "int8", "int32", "int32", False, True, 1),
(128, 256, 256, "int8", "int32", "int32", False, True, 2),
(128, 256, 256, "int8", "int32", "int32", False, False, 1),
(128, 256, 256, "int8", "int32", "int32", False, False, 2),
(128, 128, 128, "float8_e4m3fnuz", "float16", "float32", False, True, 1),
],
)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "float8_e4m3fnuz", "float16")
def test_assert_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack):
assert_tl_matmul_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
k_pack=k_pack,
)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 256, 256, "float8_e4m3fnuz", "float32", b_transposed=False)
......
import pytest
import torch
import tilelang.testing
from tilelang import tvm as tvm
......@@ -257,19 +258,46 @@ def assert_tl_matmul_correctness(
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed, k_pack, b_preshuffle, b_g2l_load",
[
(256, 256, 512, "int8", "int32", "int32", False, True, 1, True, False),
(256, 256, 512, "int8", "int32", "int32", False, False, 1, True, False),
(256, 256, 512, "int8", "int32", "int32", False, True, 2, True, False),
(256, 256, 512, "int8", "int32", "int32", False, False, 2, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 1, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 1, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, True, 2, True, False),
(256, 256, 512, "float8_e4m3fnuz", "float32", "float32", False, False, 2, True, False),
],
)
@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "int8", "int32", b_transposed=False, accum_dtype="int32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", b_transposed=False, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_preshuffle=True)
assert_tl_matmul_correctness(256, 256, 512, "float8_e4m3fnuz", "float32", k_pack=2, b_transposed=False, b_preshuffle=True)
def test_assert_tl_matmul(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
a_transposed,
b_transposed,
k_pack,
b_preshuffle,
b_g2l_load,
):
assert_tl_matmul_correctness(
M,
N,
K,
in_dtype,
out_dtype,
accum_dtype=accum_dtype,
a_transposed=a_transposed,
b_transposed=b_transposed,
k_pack=k_pack,
b_preshuffle=b_preshuffle,
b_g2l_load=b_g2l_load,
)
if __name__ == "__main__":
......
import pytest
from tilelang import tvm as tvm
import tilelang as tl
import tilelang.language as T
......@@ -95,31 +96,49 @@ def run_gemm(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize(
"trans_A, trans_B, k_pack",
[
(False, False, 1),
(False, True, 1),
(True, True, 1),
(True, False, 1),
(False, True, 2),
],
)
@tilelang.testing.requires_rocm
def test_gemm_f16f32f32_nt():
run_gemm(1024, 1024, 1024, False, False, "float16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, True, "float16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, False, "float16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "float16", "float32", "float32", 128, 128, 32, k_pack=2)
def test_gemm_f16f32f32_nt(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "float16", "float32", "float32", 128, 128, 32, k_pack=k_pack)
@pytest.mark.parametrize(
"trans_A, trans_B, k_pack",
[
(False, False, 1),
(False, True, 1),
(True, True, 1),
(True, False, 1),
(False, True, 2),
],
)
@tilelang.testing.requires_rocm
def test_gemm_bf16f32f32_nt():
run_gemm(1024, 1024, 1024, False, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, True, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, False, "bfloat16", "float32", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=2)
def test_gemm_bf16f32f32_nt(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "float32", "float32", 128, 128, 32, k_pack=k_pack)
@pytest.mark.parametrize(
"trans_A, trans_B, k_pack",
[
(False, False, 1),
(False, True, 1),
(True, True, 1),
(True, False, 1),
(False, True, 2),
],
)
@tilelang.testing.requires_rocm
def test_gemm_bf16bf16f32():
run_gemm(1024, 1024, 1024, False, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, True, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, True, False, "bfloat16", "bfloat16", "float32", 128, 128, 32)
run_gemm(1024, 1024, 1024, False, True, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=2)
def test_gemm_bf16bf16f32(trans_A, trans_B, k_pack):
run_gemm(1024, 1024, 1024, trans_A, trans_B, "bfloat16", "bfloat16", "float32", 128, 128, 32, k_pack=k_pack)
def matmul_rs(
......
import pytest
import tilelang
import tilelang.language as T
import torch
......@@ -242,13 +243,9 @@ def run_fastmath_mathop_test(mathop_name, mathop_func, M=128, N=128, block_M=32,
print(f"✓ {mathop_name} numerical test passed")
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath():
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
# Based on test results, our tl.* intrinsics actually generate
# no fastmath versions
# This appears to be the intended behavior
single_arg_mathops = [
@pytest.mark.parametrize(
"name, func",
[
("exp", T.exp),
("exp2", T.exp2),
("exp10", T.exp10),
......@@ -270,24 +267,26 @@ def test_mathops_generate_no_fastmath():
("trunc", T.trunc),
("round", T.round),
("nearbyint", T.nearbyint),
]
for name, func in single_arg_mathops:
run_single_arg_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
],
)
@tilelang.testing.requires_cuda
def test_mathops_generate_no_fastmath(name, func):
"""Test that our tl.* mathops generate fastmath CUDA code (__expf etc.)"""
run_single_arg_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath():
"""Test all two-argument mathops"""
# Two argument mathops
two_arg_mathops = [
@pytest.mark.parametrize(
"name, func",
[
("pow", T.pow),
("fmod", T.fmod),
]
for name, func in two_arg_mathops:
run_two_arg_mathop_test(name, func, dtype="float32")
],
)
@tilelang.testing.requires_cuda
def test_two_arg_mathops_fastmath(name, func):
"""Test all two-argument mathops"""
run_two_arg_mathop_test(name, func, dtype="float32")
@tilelang.testing.requires_cuda
......@@ -296,11 +295,9 @@ def test_abs_maps_to_fabs():
run_abs_test()
@tilelang.testing.requires_cuda
def test_fastmath_versions():
"""Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
# Test fastmath versions
fastmath_mathops = [
@pytest.mark.parametrize(
"name, func",
[
("__exp", T.__exp),
("__exp10", T.__exp10),
("__log", T.__log),
......@@ -309,11 +306,13 @@ def test_fastmath_versions():
("__tan", T.__tan),
("__cos", T.__cos),
("__sin", T.__sin),
]
for name, func in fastmath_mathops:
run_fastmath_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
],
)
@tilelang.testing.requires_cuda
def test_fastmath_versions(name, func):
"""Test that __exp, __exp10, __log, __log2, __log10, __tan, __cos, __sin generate fastmath CUDA code"""
run_fastmath_mathop_test(name, func, dtype="float32")
print(f"✓ {name} test passed")
if __name__ == "__main__":
......
import pytest
import torch
import tilelang.testing
import tilelang.language as T
......@@ -77,38 +78,29 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
assert check_str in code and check_str in code_parallel, f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
def test_vectorized_cast():
# fp32 -> fp16
run_vectorized_cast("float32", "float16", "__float22half2_rn", 2)
run_vectorized_cast("float32", "float16", "__float22half2_rn", 4)
# fp16 -> fp32
run_vectorized_cast("float16", "float32", "__half22float2", 2)
run_vectorized_cast("float16", "float32", "__half22float2", 4)
# fp32 -> fp8_e4m3
run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2)
run_vectorized_cast("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4)
# fp32 -> fp8_e5m2
run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2)
run_vectorized_cast("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4)
# fp32 -> bf16
run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 2)
run_vectorized_cast("float32", "bfloat16", "__float22bfloat162_rn", 4)
# bf16 -> fp32
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 2)
run_vectorized_cast("bfloat16", "float32", "__bfloat1622float2", 4)
# fp8_e4m3 -> fp32
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4)
# fp8_e5m2 -> fp32
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2)
run_vectorized_cast("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4)
@pytest.mark.parametrize(
"src_dtype, dst_dtype, check_str, lanes",
[
("float32", "float16", "__float22half2_rn", 2),
("float32", "float16", "__float22half2_rn", 4),
("float16", "float32", "__half22float2", 2),
("float16", "float32", "__half22float2", 4),
("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 2),
("float32", "float8_e4m3", "__nv_cvt_float2_to_fp8x2", 4),
("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 2),
("float32", "float8_e5m2", "__nv_cvt_float2_to_fp8x2", 4),
("float32", "bfloat16", "__float22bfloat162_rn", 2),
("float32", "bfloat16", "__float22bfloat162_rn", 4),
("bfloat16", "float32", "__bfloat1622float2", 2),
("bfloat16", "float32", "__bfloat1622float2", 4),
("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 2),
("float8_e4m3", "float32", "__tl_cvt_fp8x2_to_float2", 4),
("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 2),
("float8_e5m2", "float32", "__tl_cvt_fp8x2_to_float2", 4),
],
)
def test_vectorized_cast(src_dtype, dst_dtype, check_str, lanes):
run_vectorized_cast(src_dtype, dst_dtype, check_str, lanes)
if __name__ == "__main__":
......
......@@ -109,30 +109,27 @@ def run_gemm_ss(
@pytest.mark.skip(reason="Temporarily disabling until GEMM SS issues are resolved")
def test_gemm_ss():
# More test case can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16
run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2)
# n8 test
run_gemm_ss(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 test
run_gemm_ss(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# tfloat32 test
run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 128, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_rs(
......@@ -247,30 +244,27 @@ def run_gemm_rs(
@pytest.mark.skip(reason="Temporarily disabling until GEMM RS issues are resolved")
def test_gemm_rs():
# GEMM tests for float16
run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
# n8 tests
run_gemm_rs(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_sr(
......@@ -384,31 +378,27 @@ def run_gemm_sr(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_sr():
# GEMM tests for float16
run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
# n8 tests
run_gemm_sr(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_sr(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
# TODO(lei): fix in future
run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(128, 8, 32, False, True, "float16", "float16", "float16", 128, 8, 32, 0, 128),
(128, 128, 32, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 32, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_rr(
......@@ -526,31 +516,29 @@ def run_gemm_rr(
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_rr():
# GEMM tests for float16
run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2)
run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2)
# int8 tests
run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2)
# float8 tests
run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2)
# float32 tests
run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float16", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float16", 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 32, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2, 128),
(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2, 128),
],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
if __name__ == "__main__":
......
import pytest
import torch
import tilelang
import tilelang.testing
......@@ -303,50 +304,53 @@ def run_gemm_sp_sm80(
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version(9, 0)
def test_gemm_sp_sm90():
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False)
run_gemm_sp_sm90(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True)
run_gemm_sp_sm90(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True)
run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 32, 0, 256, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 128, 128, 128, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 128, 256, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, True, True),
(512, 1024, 768, "float8_e4m3", "float16", "float16", 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
],
)
def test_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
run_gemm_sp_sm90(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)
@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(8, 0)
@tilelang.testing.requires_cuda_compute_version_le(8, 9)
def test_gemm_sp_sm80():
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128)
run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)
run_gemm_sp_sm80(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True)
@pytest.mark.parametrize(
"M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B",
[
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 128, False, True),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 1, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 2, 128, False, False),
(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 3, 128, False, False),
(512, 1024, 768, "int8", "int32", "int32", 32, 32, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 0, 32, False, True),
(512, 1024, 768, "int8", "int32", "int32", 128, 128, 128, 0, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 1, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True),
(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 3, 128, False, True),
],
)
def test_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B):
run_gemm_sp_sm80(M, N, K, in_dtype, out_dtype, accum_dtype, block_M, block_N, block_K, num_stages, num_threads, trans_A, trans_B)
if __name__ == "__main__":
......
import pytest
from tilelang import tvm as tvm
from tilelang.utils.sparse import compress, randn_semi_sparse, randint_semi_sparse
from tilelang.utils.tensor import torch_assert_close, map_torch_type
......@@ -153,33 +154,24 @@ def generate_dense_input(M, N, K, trans_A, trans_B, in_dtype):
return A, B
def test_gemm_ss():
# More test case can be found in kernel/test_tilelang_kernel_gemm.py
# GEMM tests for float16
# TODO: support transposed A compressor
run_gemm_ss(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2)
run_gemm_ss(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2)
# n8 test
run_gemm_ss(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 test
run_gemm_ss(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_ss(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
run_gemm_ss(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# tfloat32 test
# run_gemm_ss(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_ss(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 128, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 128, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int32", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
],
)
def test_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_ss(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_rs(
......@@ -313,30 +305,23 @@ def run_gemm_rs(
print("pass")
def test_gemm_rs():
# GEMM tests for float16
run_gemm_rs(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rs(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rs(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_rs(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rs(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_rs(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_rs(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rs(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
],
)
def test_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_rs(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_sr(
......@@ -470,30 +455,23 @@ def run_gemm_sr(
print("pass")
def test_gemm_sr():
# GEMM tests for float16
run_gemm_sr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_sr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_sr(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128)
# int8 tests
run_gemm_sr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2)
run_gemm_sr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2)
run_gemm_sr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_sr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_sr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_sr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_sr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(128, 8, 64, False, True, "float16", "float16", "float", 128, 8, 32, 0, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 128, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 128, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
],
)
def test_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_sr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
def matmul_rr(
......@@ -631,31 +609,25 @@ def run_gemm_rr(
print("pass")
def test_gemm_rr():
# GEMM tests for float16
run_gemm_rr(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2)
run_gemm_rr(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2)
# n8 tests
run_gemm_rr(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2)
run_gemm_rr(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2)
# int8 tests
run_gemm_rr(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2)
run_gemm_rr(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2)
# float8 tests
run_gemm_rr(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2)
# float32 tests
# run_gemm_rr(128, 128, 128, False, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, False, True, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, True, False, "float", "float", "float32", 128, 128, 32, 2)
# run_gemm_rr(128, 128, 128, True, True, "float", "float", "float32", 128, 128, 32, 2)
@pytest.mark.parametrize(
"M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads",
[
(512, 1024, 768, False, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, False, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, True, True, "float16", "float16", "float", 128, 256, 32, 2, 128),
(512, 1024, 768, False, True, "bfloat16", "bfloat16", "float", 128, 256, 32, 2, 128),
(128, 8, 128, False, True, "float16", "float16", "float", 128, 8, 32, 2, 128),
(128, 8, 128, False, True, "int8", "int8", "int32", 128, 8, 64, 2, 128),
(128, 128, 128, False, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, False, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, False, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "int8", "int8", "int32", 128, 128, 64, 2, 128),
(128, 128, 128, True, True, "float8_e5m2", "float8_e5m2", "float32", 128, 128, 64, 2, 128),
],
)
def test_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads):
run_gemm_rr(M, N, K, trans_A, trans_B, in_dtype, out_dtype, dtypeAccum, block_M, block_N, block_K, num_stages, num_threads)
if __name__ == "__main__":
......