"src/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "a671ef89ea4cc131f30d7b0fa5d4651c187e6f0b"
Unverified Commit 6882bd50 authored by Lei Wang, committed by GitHub

[Bugfix] Minor fix for tcgen05 (#1242)



* Add correctness evaluation script for GEMM v2

- Introduced a new Python script `correctness_evaluation_tcgen05.py` for testing the correctness of GEMM v2 implementations using pytest (a minimal invocation sketch follows this list).
- Implemented matrix multiplication and compilation checks, along with parameterized tests for various input configurations.
- Enhanced the testing framework to validate GEMM operations with different data types and configurations, ensuring robustness in the implementation.
- Updated logging in `legalize_negative_index.cc` to reduce verbosity by switching `LOG(WARNING)` to `DLOG(WARNING)`.
- Relaxed assertions in `tcgen05_macro_generator.py` so that `warp_row_tiles` only needs to be a multiple of 8 (previously 16), accommodating the new tcgen05 warp-size requirements.
- Removed unused variable in `gemm_tcgen05.py` to streamline the codebase.
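A minimal invocation sketch, assuming the script sits in the current working directory and that pytest-xdist provides the `-n` worker option referenced in the script's own header comment:

    # Sketch only: the path and worker count below are assumptions.
    import pytest

    # Run the parameterized tcgen05 GEMM correctness tests across 32 workers.
    pytest.main(["-n", "32", "correctness_evaluation_tcgen05.py"])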

* lint fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent 4370309b
correctness_evaluation_tcgen05.py (new file)
# pytest correctness_evaluation_tcgen05.py -n 32
import pytest
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    num_stages,
    threads,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            # The accumulator lives in tensor memory (tmem) for tcgen05 MMA.
            C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
            mbar = T.alloc_barrier(1)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(
                    A_shared,
                    B_shared,
                    C_tmem,
                    trans_A,
                    trans_B,
                    mbar=mbar,
                    wg_wait=-1,
                    clear_accum=k == 0)
                # Wait on the mbarrier signalled by the MMA for this k-step.
                T.mbarrier_wait_parity(mbar, k % 2)
            # Move the accumulator out of tmem and write the tile back to C.
            T.copy(C_tmem, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return main

def _compile_and_check(
    program,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
):
    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)

    def ref_program(A, B):
        import torch
        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            # Perturb the fp32 inputs via their int32 bit patterns before
            # computing the reference matmul.
            A = (A.view(torch.int32) - 0x1000).view(torch.float32)
            B = (B.view(torch.int32) - 0x1000).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
    print("assert_allclose")

def run_gemm(
    M,
    N,
    K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    dtypeAccum,
    block_M,
    block_N,
    block_K,
    num_stages=2,
    num_threads=128,
):
    # Large tiles fall back to an unpipelined schedule.
    if block_N >= 256 or block_M >= 256 or block_K >= 256:
        num_stages = 0
    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_stages,
        num_threads,
    )
    _compile_and_check(program, trans_A, trans_B, in_dtype, out_dtype)

M_VALUES = [32, 64, 128, 256]
N_VALUES = [64, 128, 256, 512]
K_VALUES = [16, 32, 64, 128]
K_VALUES_8Bit = [32, 64, 128]

FALSE_TRUE_CASES = ([
    pytest.param(
        k,
        "float16",
        "float32",
        "float32",
        id=f"K{k}-float16-float32-float32",
    ) for k in K_VALUES
] + [
    pytest.param(
        k,
        "float8_e5m2",
        "float32",
        "float32",
        id=f"K{k}-float8_e5m2-float32-float32",
    ) for k in K_VALUES_8Bit
])

TRANS_CASES = [
    pytest.param(False, True, id="nt"),
]

@pytest.mark.parametrize("m", M_VALUES, ids=lambda v: f"M{v}")
@pytest.mark.parametrize("n", N_VALUES, ids=lambda v: f"N{v}")
@pytest.mark.parametrize("k,in_dtype,out_dtype,accum_dtype", FALSE_TRUE_CASES)
def test_gemm_false_true(m, n, k, in_dtype, out_dtype, accum_dtype):
import torch
required_torch_attrs = {
in_dtype,
out_dtype,
accum_dtype,
}
for attr in required_torch_attrs:
if not hasattr(torch, attr):
pytest.skip(f"Torch does not expose dtype {attr}")
run_gemm(
m,
n,
k * 3,
False,
True,
in_dtype,
out_dtype,
accum_dtype,
m,
n,
k,
)
if __name__ == "__main__":
# tilelang.testing.main()
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [16, 32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float16", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
# # Test Pass
# for m in [32, 64, 128, 256]:
# for n in [16, 32, 64, 128]:
# for k in [32, 64, 128]:
# if m in [32, 64] and (n not in [64, 128, 256]):
# continue
# print(f"======================= Test {m} {n} {k} False True =============================")
# run_gemm(m, n, k * 3, False, True, "float8_e5m2", "float", "float", m, n, k, 2, 128)
# print(f"Test {m} {n} {k} Pass")
tilelang.disable_cache()
run_gemm(32, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(32, 512, 32, False, True, "float16", "float32", "float32", 32, 512, 32, 0, 128)
run_gemm(32, 512, 64, False, True, "float16", "float32", "float32", 32, 512, 64, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 64, 512, 16, 0, 128)
run_gemm(64, 512, 16, False, True, "float16", "float32", "float32", 32, 512, 16, 0, 128)
run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
# run_gemm(64, 512, 32, False, True, "float16", "float32", "float32", 64, 512, 32, 0, 128)
# run_gemm(64, 512, 64, False, True, "float16", "float32", "float32", 64, 512, 64, 0, 128)
# run_gemm(128, 512, 16, False, True, "float16", "float32", "float32", 128, 512, 16, 0, 128)
legalize_negative_index.cc
@@ -52,7 +52,7 @@ public:
       }
       states.push_back(IndexSignState::kUnknown);
       needs_record = true;
-      LOG(WARNING)
+      DLOG(WARNING)
           << "LegalizeNegativeIndex: cannot prove non-negative index "
           << simplified << " for buffer " << load->buffer->name << " (axis "
           << i << ").";
...
tcgen05_macro_generator.py
@@ -103,13 +103,14 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
     def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16):
         warp_row_tiles = self.warp_row_tiles
         warp_col_tiles = self.warp_col_tiles
-        assert warp_row_tiles >= 16, f"warp_row_tiles must be greater than 16, got {warp_row_tiles}"
-        assert warp_row_tiles % 16 == 0, f"warp_row_tiles must be divisible by 16, got {warp_row_tiles}"
+        # For tcgen05, warp_row_tiles is 8 as we can use .ws to support m32
+        assert warp_row_tiles >= 8, f"warp_row_tiles must be greater than 8, got {warp_row_tiles}"
+        assert warp_row_tiles % 8 == 0, f"warp_row_tiles must be divisible by 8, got {warp_row_tiles}"
         assert warp_col_tiles >= 8, f"warp_col_tiles must be greater than 8, got {warp_col_tiles}"
         assert warp_col_tiles % 8 == 0, f"warp_col_tiles must be divisible by 8, got {warp_col_tiles}"
         # four warps per block
-        self.warp_rows = warp_row_tiles // m_dim
+        self.warp_rows = warp_row_tiles // 8
         if warp_col_tiles % 16 == 0:
             self.n_dim = 16
             self.micro_size_y = 16
...
@@ -246,6 +247,9 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
         mask_zero = T.Cast("int32", 0)
         mask0 = mask1 = mask2 = mask3 = mask_zero
+        num_inst_m = 4 * self.warp_row_tiles // atom_m
+        num_inst_n = self.warp_col_tiles // atom_n
         # Helper to allow BufferRegion/BufferLoad as inputs
         def access_ptr_from(buffer_or_load_or_region, access_type: str = "r"):
             if isinstance(buffer_or_load_or_region, Buffer):
...
@@ -302,20 +306,27 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter):
                 int(b_swizzle_mode),
             )
-            for ki in T.serial(0, (k_dim // micro_size_k)):
-                scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1))
-                for i in T.serial(m_dim // atom_m):
+            tmem_col_step = atom_n // (128 // atom_m)
+            for j in T.unroll(num_inst_n):
+                for i in T.unroll(num_inst_m):
+                    for ki in T.unroll(0, (k_dim // micro_size_k)):
+                        scale_out = T.Select(ki != 0, 1, T.Select(clear_accum, 0, 1))
                         A_elem_offset = (
                             ki % ak_atom_size
                         ) * micro_size_k + i * atom_m * a_swizzle_atom_elems + (
                             ki // ak_atom_size
                         ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * atom_m * k_dim + ki * a_swizzle_atom_elems * micro_size_k
                         B_elem_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + (
                             ki % bk_atom_size
-                        ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k
+                        ) * micro_size_k + j * atom_n * b_swizzle_atom_elems if b_is_k_major else (
+                            ki * b_swizzle_atom_elems * micro_size_k + j * atom_n *
+                            (k_dim if n_dim // b_swizzle_atom_elems > 1 else 1))
                         A_byte_offset = A_elem_offset * elems_in_bytes
                         B_byte_offset = B_elem_offset * elems_in_bytes
-                        C_offset = i * atom_n * accum_dtype_in_bits // 32  # 32 bits per tmem bank
+                        C_offset = (i * n_dim + j * tmem_col_step
+                                    ) * accum_dtype_in_bits // 32  # 32 bits per tmem bank
                         T.ptx_tcgen05_mma_ss(
                             a_dtype_abbrv,
...
gemm_tcgen05.py
@@ -85,8 +85,6 @@ class GemmTCGEN5(GemmBase):
             raise ValueError(f"TCGEN5MMA currently only supports gemm_ss, got "
                              f"A scope {self.A.scope()}, B scope {self.B.scope()}")
-        atom_m, atom_n, atom_k = mma_emitter.get_tcgen5_mma_meta(self.M, self.N, self.K)
-
         if self.A.scope() not in {"shared", "shared.dyn", "shared.tmem"}:
             raise ValueError(f"Unsupported A scope for TCGEN5MMA: {self.A.scope()}")
         if self.B.scope() not in {"shared", "shared.dyn"}:
...