Commit d4f096ef authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999

[Enhancement] Support register input for gemm when trans_a or trans_b is true (#490)

* [Refactor] Enhance makeGemmFragmentB to support transposition

* Updated the `makeGemmFragmentB` function to take a `transposed` parameter, so the fragment layout is generated according to whether B is transposed.
* Adjusted layout calculations for both transposed and non-transposed cases to ensure correct fragment generation.
* Modified the function signature in `layout.h` and updated all relevant calls in `gemm.cc` to accommodate the new parameter (a call-site sketch follows this list).
* Added a new `matmul_sr` function in the test suite to validate the behavior of the updated fragment generation with transposition support.
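
A minimal sketch, assuming the updated signature shown in the diff below, of how the B-operand call site changes; the include path, wrapper name, and argument values are illustrative, not part of the actual change:

// Hedged sketch only: mirrors the updated makeGemmFragmentB call site in
// gemm.cc from this diff. The include path and wrapper name are assumptions
// for illustration, not real project API.
#include "layout/layout.h"

Fragment InferFragmentLayoutForB(int M, int N, int K, int warp_m, int warp_n,
                                 bool trans_B) {
  // With the new trailing `transposed` argument, a B operand held in
  // registers (scope "local.fragment") may now be transposed; previously
  // the call site asserted trans_B == false.
  return makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
}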

* [Refactor] Enhance makeGemmFragmentA and makeGemmFragmentB for transposition support

* Updated the `makeGemmFragmentA` and `makeGemmFragmentB` functions to take a `transposed` parameter, so the fragment layout is generated according to whether the operand is transposed.
* Adjusted layout calculations for both transposed and non-transposed cases to ensure correct fragment generation.
* Modified function signatures in `layout.h` and updated all relevant calls in `gemm.cc` to accommodate the new parameter (see the sketch after this list).
* Added a new `matmul_rs` function in the test suite to validate the behavior of the updated fragment generation with transposition support.
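
A similarly hedged sketch for the A operand, following the new declaration in `layout.h` where `transposed` defaults to `false`; the wrapper name, include path, and parameter values are placeholders:

// Hedged sketch only: follows the new makeGemmFragmentA signature declared
// in layout.h. Because `transposed` defaults to false, existing callers that
// never pass the flag keep the previous behaviour. Names here are illustrative.
#include "layout/layout.h"

Fragment InferFragmentLayoutForA(int M, int N, int K, int warp_m, int warp_n,
                                 int element_bits, bool trans_A) {
  // Passing trans_A through lets an A operand held in registers use the
  // transposed 8x8 base fragment instead of being rejected by an ICHECK.
  return makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, element_bits,
                           trans_A);
}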

* Improve error messaging in layout equality checks

* Added line breaks to the layout-equality check's error output so the current and previous layouts are printed on separate lines.
* When two layouts are structurally unequal, their debug output is now displayed distinctly, which makes the mismatch easier to read while debugging.
parent 39ae28e4
@@ -139,7 +139,8 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size) {
const int warp_n, const int element_size,
bool transposed) {
// assume not transposed
ICHECK(block_m % warp_m == 0);
ICHECK(block_n % warp_n == 0);
@@ -148,23 +149,58 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
// Only support 8-bit and 16-bit
ICHECK(element_size == 8 || element_size == 16)
<< "element bitwidth=" << element_size;
if (element_size == 8) {
auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
if (transposed) {
auto base_layout =
makeGemmFragment8x8Transposed()->Repeat({2, 2}, false, true);
auto warp_layout = base_layout->Repeat({1, block_m / warp_m}, true, false)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
warp_layout->Repeat({block_k / 16, warp_m / 16}, false, true);
return block_layout;
} else if (element_size == 16) {
auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
} else {
if (element_size == 8) {
auto base_layout = makeGemmFragment8x16()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 32}, false, false);
return block_layout;
} else if (element_size == 16) {
auto base_layout = makeGemmFragment8x8()->Repeat({2, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
->Replicate(block_n / warp_n);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
return block_layout;
} else {
ICHECK(0);
return Fragment();
}
}
}
Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, bool transposed) {
// transposed
ICHECK(warp_n % 8 == 0);
ICHECK(block_k % 16 == 0);
if (transposed) {
auto base_layout = makeGemmFragment8x8()->Repeat({1, 2}, false, false);
auto warp_layout = base_layout->Repeat({block_n / warp_n, 1}, true, true)
->Replicate(block_m / warp_m);
auto block_layout =
warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
warp_layout->Repeat({warp_n / 8, block_k / 16}, false, false);
return block_layout;
} else {
ICHECK(0);
return Fragment();
auto base_layout =
makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto warp_layout = base_layout->Replicate(block_m / warp_m)
->Repeat({1, block_n / warp_n}, true);
auto block_layout =
warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
return block_layout;
}
}
@@ -198,21 +234,6 @@ Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
}
}
Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n) {
// transposed
ICHECK(warp_n % 8 == 0);
ICHECK(block_k % 16 == 0);
auto base_layout =
makeGemmFragment8x8Transposed()->Repeat({2, 1}, false, false);
auto warp_layout = base_layout->Replicate(block_m / warp_m)
->Repeat({1, block_n / warp_n}, true);
auto block_layout =
warp_layout->Repeat({block_k / 16, warp_n / 8}, false, true);
return block_layout;
}
Fragment makeGemmFragment32x32(int element_size) {
IterVar i = make_itervar("i", 32);
IterVar j = make_itervar("j", 32);
@@ -145,10 +145,11 @@ Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
const int element_size);
Fragment makeGemmFragmentA(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n, const int element_size);
const int warp_n, const int element_size,
bool transposed = false);
Fragment makeGemmFragmentB(const int block_m, const int block_n,
const int block_k, const int warp_m,
const int warp_n);
const int warp_n, bool transposed = false);
Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
const int block_k, const int warp_m,
@@ -207,9 +207,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
A->dtype.bits(), trans_A ? 1 : 2));
} else if (A.scope() == "local.fragment") {
ICHECK(trans_A == false);
auto fragment =
makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n, A->dtype.bits());
auto fragment = makeGemmFragmentA(M, N, K, M / warp_m, N / warp_n,
A->dtype.bits(), trans_A);
results.Set(A, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
@@ -222,9 +221,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
makeGemmABLayout(mat_stride, mat_continuous, mat_continuous,
B->dtype.bits(), trans_B ? 2 : 1));
} else if (B.scope() == "local.fragment") {
ICHECK(trans_B == false) << "B is local.fragment, trans_B must be false, "
"please raise an issue if you see this";
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
@@ -295,7 +293,8 @@ LayoutMap Gemm::InferLayout(const LayoutInferArgs &T, InferLevel level) {
results.Set(B, shared_layout);
} else if (B.scope() == "local.fragment") {
auto fragment = makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n);
auto fragment =
makeGemmFragmentB(M, N, K, M / warp_m, N / warp_n, trans_B);
results.Set(B, fragment->BindThreadRange(thread_range));
} else {
ICHECK(0);
@@ -288,8 +288,8 @@ public:
// If already in map, ensure they are structurally equal
ICHECK(StructuralEqual()(layout, layout_map[buffer]))
<< "Get different layout for " << buffer
<< " current layout: " << layout->DebugOutput()
<< " previous layout: " << layout_map[buffer]->DebugOutput();
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
@@ -291,5 +291,238 @@ def test_pad_f16f16f32_nn():
)
def matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
B_local = T.alloc_fragment(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
T.copy(B[bx * block_N, k * block_K], B_local)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.copy(B[k * block_K, bx * block_N], B_local)
T.gemm(A_shared, B_local, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_sr(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_sr(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_sr():
run_gemm_sr(
512,
1024,
768,
False,
True,
"float16",
"float16",
"float16",
128,
256,
32,
2,
)
def matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
accum_dtype,
num_stages,
threads,
):
A_shape = (K, M) if trans_A else (M, K)
B_shape = (N, K) if trans_B else (K, N)
A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)
import tilelang.language as T
@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, in_dtype),
C: T.Tensor((M, N), out_dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
A_local = T.alloc_fragment(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, in_dtype)
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
T.clear(C_local)
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
if trans_A:
T.copy(A[k * block_K, by * block_M], A_shared)
T.copy(A[k * block_K, by * block_M], A_local)
else:
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(A[by * block_M, k * block_K], A_local)
if trans_B:
T.copy(B[bx * block_N, k * block_K], B_shared)
else:
T.copy(B[k * block_K, bx * block_N], B_shared)
T.gemm(A_local, B_shared, C_local, trans_A, trans_B)
T.copy(C_local, C[by * block_M, bx * block_N])
return main
def run_gemm_rs(
M,
N,
K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
block_M,
block_N,
block_K,
num_stages=3,
num_threads=128,
):
program = matmul_rs(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A,
trans_B,
in_dtype,
out_dtype,
dtypeAccum,
num_stages,
num_threads,
)
kernel = tilelang.compile(program, out_idx=[2])
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
def ref_program(A, B):
import torch
if trans_A:
A = A.T
if trans_B:
B = B.T
C = torch.matmul(A.to(torch.float), B.to(torch.float))
C = C.to(torch.__getattribute__(out_dtype))
return C
profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)
def test_gemm_f16f16f16_rs():
run_gemm_rs(
512,
1024,
768,
True,
False,
"float16",
"float16",
"float16",
128,
128,
32,
0,
)
if __name__ == "__main__":
tilelang.testing.main()