"docs/source/Tutorial/HowToLaunchFromPython.rst" did not exist on "8931b1c5158ae2d88932dcf1d409d1a962782d3d"
Commit 0acb8586 authored by Lei Wang, committed by LeiWang1999

[Bugfix] Fix Transposed Fragment Layout for AMD GEMM_RS matrix core (#346)

* [Refactor] Update GEMM Fragment Layout and Improve Matrix Multiplication Functionality

- Adjusted the layout configuration in `gemm_layouts.cc` to correct the repeat extents and transposition flags for the warp and block layouts, so the transposed GEMM fragment is generated correctly.
- Refactored the `matmul_rs` function in `test_tilelang_test_amd.py` to improve readability by restructuring the function signature and ensuring consistent formatting.
- Updated the test execution call to run the new `test_gemm_rs_f16f32f32_nt` function, enhancing test coverage for the GEMM functionality.

* lint fix

* bugfix
parent 847a461b
gemm_layouts.cc
@@ -179,8 +179,8 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
     auto base_layout =
         makeGemmFragmentAB16x16CDNATransposed()->Repeat({1, 1}, false, false);
     auto warp_layout =
-        base_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
-    auto block_layout = warp_layout->Repeat({block_m / warp_m, 1}, true, true)
+        base_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
+    auto block_layout = warp_layout->Repeat({1, block_m / warp_m}, true, true)
                             ->Replicate(block_n / warp_n);
     return block_layout;
   } else {
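For intuition, this change swaps the repeat extents and marks the warp-level repeat as transposed, so the 16x16 base tiles covering the warp_m x block_k region of the transposed A fragment are enumerated K-major instead of M-major. Below is a conceptual sketch in plain Python of that visiting-order difference; the helper function and the 32x32 example sizes are illustrative assumptions, not TileLang's actual Fragment::Repeat implementation.

# Conceptual sketch only: models the order in which 16x16 base tiles are visited,
# not TileLang's Fragment::Repeat semantics.
def tile_order(outer, inner, transposed):
    order = []
    for r in range(outer):
        for c in range(inner):
            # Record each visit as an (m_tile, k_tile) pair.
            order.append((c, r) if transposed else (r, c))
    return order

warp_m, block_k = 32, 32  # assumed example: a 2x2 grid of 16x16 base tiles
old = tile_order(warp_m // 16, block_k // 16, transposed=False)
new = tile_order(block_k // 16, warp_m // 16, transposed=True)
print(old)  # [(0, 0), (0, 1), (1, 0), (1, 1)] -> k varies fastest (M-major)
print(new)  # [(0, 0), (1, 0), (0, 1), (1, 1)] -> m varies fastest (K-major)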
test_tilelang_test_amd.py
@@ -128,8 +128,11 @@ def matmul_rs(
     vec_size = 4 * k_pack

     @T.prim_func
-    def main(A: T.Tensor(A_shape, in_dtype), B: T.Tensor(B_shape, in_dtype), C: T.Tensor(
-            (M, N), out_dtype)):
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, in_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
             A_shared = T.alloc_shared(A_shared_shape, in_dtype)
             A_local = T.alloc_fragment(A_shared_shape, in_dtype)
@@ -142,6 +145,7 @@ def matmul_rs(
                     T.copy(A_shared, A_local)
                 else:
                     T.copy(A[by * block_M, k * block_K], A_shared, coalesced_width=vec_size)
+                    T.copy(A_shared, A_local)
                 if trans_B:
                     T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=vec_size)
                 else:
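The single added T.copy above is the functional part of the bugfix: in the non-transposed A branch, the tile was loaded from global memory into A_shared but never refreshed into the register fragment A_local, so the matrix core kept consuming a stale fragment. A minimal NumPy model of the staging pattern (shapes, names, and the NT layout are illustrative assumptions, not TileLang code) shows why the shared-to-fragment copy belongs in both branches of every k iteration:

import numpy as np

# Plain-NumPy model of the k-loop staging: global -> "shared" -> register "fragment".
M, N, K, block_K = 32, 32, 64, 16
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(N, K).astype(np.float32)   # NT: B stored as (N, K)
C = np.zeros((M, N), dtype=np.float32)

A_local = np.zeros((M, block_K), dtype=np.float32)  # stands in for the register fragment
for k in range(K // block_K):
    A_shared = A[:, k * block_K:(k + 1) * block_K]  # global -> shared
    A_local = A_shared.copy()                       # shared -> fragment: the copy this commit adds
    B_shared = B[:, k * block_K:(k + 1) * block_K]
    C += A_local @ B_shared.T                       # accumulate one K slice

np.testing.assert_allclose(C, A @ B.T, rtol=1e-4)   # skipping the per-iteration copy leaves C wrong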
@@ -168,7 +172,7 @@ def run_gemm_rs(
     num_threads=128,
     k_pack=1,
 ):
-    program = matmul(
+    program = matmul_rs(
         M,
         N,
         K,
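For reference, the NT, f16-in/f32-accumulate/f32-out configuration exercised by the new test_gemm_rs_f16f32f32_nt amounts to C = A @ B^T. A NumPy sketch of that reference computation follows; the problem sizes and tolerance are illustrative assumptions, not values taken from the test file.

import numpy as np

M, N, K = 256, 256, 256
A = np.random.randn(M, K).astype(np.float16)
B = np.random.randn(N, K).astype(np.float16)          # NT: B stored as (N, K)
ref = A.astype(np.float32) @ B.astype(np.float32).T   # fp32 accumulation, fp32 output
# The kernel built by run_gemm_rs(M, N, K, ...) would be checked against `ref`
# within a small tolerance (for fp16 inputs, rtol/atol on the order of 1e-2).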