Commit 0acb8586 authored by Lei Wang, committed by LeiWang1999
Browse files

[Bugfix] Fix Transposed Fragment Layout for AMD GEMM_RS matrix core (#346)

* [Refactor] Update GEMM Fragment Layout and Improve Matrix Multiplication Functionality

- Adjusted the layout configuration in `gemm_layouts.cc` to correct the repetition parameters for warp and block layouts, enhancing the efficiency of the GEMM fragment generation.
- Refactored the `matmul_rs` function in `test_tilelang_test_amd.py` to improve readability by restructuring the function signature and ensuring consistent formatting.
- Updated the test execution call to run the new `test_gemm_rs_f16f32f32_nt` function, enhancing test coverage for the GEMM functionality.

* lint fix

* bugfix
parent 847a461b
......@@ -179,8 +179,8 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
auto base_layout =
makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
auto warp_layout =
base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
base_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
->Replicate(block_n / warp_n);
return block_layout;
} else {
......
......@@ -128,8 +128,11 @@ def matmul_rs(
vec_size = 4 * k_pack
@T.prim_func
def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
(M, N), out_dtype)):
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_shared_shape, in_dtype)
......@@ -142,6 +145,7 @@ def matmul_rs(
T.copy(A_shared, A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size)
T.copy(A_shared, A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size)
else:
......@@ -168,7 +172,7 @@ def run_gemm_rs(
num_threads=128,
k_pack=1,
):
program = matmul(
program = matmul_rs(
M,
N,
K,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment