Commit 0ff81755 authored by Lei Wang, committed by LeiWang1999

[Enhancement] Support tf32 gemm_rs (#607)

- Added a line break in `quickstart.py` for better readability.
- Simplified the JIT kernel compilation in `quickstart.py` by removing the unused execution backend option.
- Modified `example_elementwise_add.py` to disable the `tilelang` cache and optimized the element-wise addition kernel by staging the input tensors in shared memory, improving performance.
- Updated the default matrix dimensions and block sizes in the argument parser to enhance usability.
- Added a `makeGemmFragment8x4` layout and a 32-bit element path in `makeGemmFragmentA` to support tf32 gemm_rs.
parent 8df45c9d
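For context: in tilelang, "gemm_rs" denotes a GEMM whose A operand is consumed from registers (a fragment) rather than shared memory, which is why this commit adds register-fragment layouts for 32-bit elements. Below is a minimal sketch of a float32 kernel that would exercise the new tf32 path, following the standard tilelang GEMM idiom; the function name, shapes, and block sizes are illustrative assumptions, not taken from this commit:

```python
import tilelang
import tilelang.language as T


def tf32_matmul(M, N, K, block_M=128, block_N=128, block_K=32):

    @T.prim_func
    def main(A: T.Tensor((M, K), "float32"), B: T.Tensor((K, N), "float32"),
             C: T.Tensor((M, N), "float32")):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float32")
            A_local = T.alloc_fragment((block_M, block_K), "float32")
            B_shared = T.alloc_shared((block_K, block_N), "float32")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                # moving A into a fragment makes this the register-source ("rs") path
                T.copy(A_shared, A_local)
                T.gemm(A_local, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```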
@@ -5,6 +5,8 @@ import tilelang
 import tilelang.language as T
 from tilelang.autotuner import AutoTuner
 
+tilelang.disable_cache()
+
 def ref_program(x, y):
     return x + y
@@ -17,12 +19,17 @@ def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads):
     def elem_add(A: T.Tensor((M, N), in_dtype), B: T.Tensor((M, N), in_dtype), C: T.Tensor(
             (M, N), out_dtype)):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
-            start_x = bx * block_N
-            start_y = by * block_M
+            A_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            B_shared = T.alloc_shared((block_M, block_N), in_dtype)
+            C_local = T.alloc_fragment((block_M, block_N), out_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), out_dtype)
+
+            T.copy(A[by * block_M, bx * block_N], A_shared)
+            T.copy(B[by * block_M, bx * block_N], B_shared)
             for (local_y, local_x) in T.Parallel(block_M, block_N):
-                y = start_y + local_y
-                x = start_x + local_x
-                C[y, x] = A[y, x] + B[y, x]
+                C_local[local_y, local_x] = A_shared[local_y, local_x] + B_shared[local_y, local_x]
+            T.copy(C_local, C_shared)
+            T.copy(C_shared, C[by * block_M, bx * block_N])
     return elem_add
@@ -54,7 +61,7 @@ def get_best_config(M, N):
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--m", type=int, default=512)
+    parser.add_argument("--m", type=int, default=1024)
     parser.add_argument("--n", type=int, default=1024)
     parser.add_argument("--use_autotune", action="store_true", default=False)
     args, _ = parser.parse_known_args()
@@ -68,9 +75,8 @@ def main():
         kernel = result.kernel
     else:
         # Default config
-        config = {"block_M": 128, "block_N": 256, "threads": 128}
+        config = {"block_M": 128, "block_N": 128, "threads": 128}
         kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32")
     out = kernel(a, b)
     torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2)
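Staging the tiles through `A_shared`/`B_shared` lets the compiler emit coalesced, vectorized block-level copies instead of independent per-thread global loads. The effect is easy to measure with tilelang's built-in profiler; a minimal sketch, assuming the kernel object returned by `elementwise_add` exposes `get_profiler()` as in the quickstart:

```python
kernel = elementwise_add(1024, 1024, block_M=128, block_N=128,
                         in_dtype="float32", out_dtype="float32", threads=128)
profiler = kernel.get_profiler()
latency_ms = profiler.do_bench()  # average kernel latency in milliseconds
print(f"elementwise add: {latency_ms:.3f} ms")
```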
...
 import tilelang
 import tilelang.language as T
+
 # `make_mma_swizzle_layout` is a python defined layout function
 # specifically designed for MMA operations
 # which ensures the consistency with the nvidia CUTLASS Library.
@@ -71,8 +72,7 @@ func = matmul(M, N, K, block_M, block_N, block_K)
 # out_idx specifies the index of the output buffer in the argument list
 # if out_idx is specified, the tensor will be created during runtime
 # target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="cython")
-# jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", execution_backend="dlpack")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 # 3. Test the kernel in Python with PyTorch data
 import torch
...
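The elided test section follows the usual quickstart pattern; a sketch of that idiom (the sizes and dtypes here are assumptions, not part of this diff):

```python
import torch

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)

# out_idx=[2] means the output tensor is allocated by the kernel wrapper
c = jit_kernel(a, b)
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
```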
@@ -18,6 +18,15 @@ static IterVar make_itervar(std::string name, PrimExpr dom) {
   return IterVar(Range(0, dom), var, IterVarType::kDataPar);
 }
 
+Fragment makeGemmFragment8x4() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 4);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 1) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 1);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
 Fragment makeGemmFragment8x8() {
   IterVar i = make_itervar("i", 8);
   IterVar j = make_itervar("j", 8);
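The new 8x4 fragment assigns exactly one element per lane: `FloorMod(j, 1)` is always 0, and `forward_thread = FloorDiv(j, 1) + 4 * i` reduces to `j + 4 * i`, spreading the 8x4 tile across all 32 lanes of a warp. A quick Python check of that mapping (illustrative only, not part of the codebase):

```python
# thread = j + 4*i covers lanes 0..31 exactly once over the 8x4 tile,
# and each lane holds a single element (the local index is always 0).
lanes = [j + 4 * i for i in range(8) for j in range(4)]
assert sorted(lanes) == list(range(32))
assert all(j % 1 == 0 for j in range(4))  # FloorMod(j, 1) == 0
```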
@@ -26,6 +35,25 @@ Fragment makeGemmFragment8x8() {
   PrimExpr index = FloorMod(j->var, 2);
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
+
+Fragment makeGemmFragment8x16() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 16);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
+  PrimExpr index = FloorMod(j->var, 4);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
+
+Fragment makeGemmFragment8x8Transposed() {
+  IterVar i = make_itervar("i", 8);
+  IterVar j = make_itervar("j", 8);
+  IterVar rep = make_itervar("rep", 1);
+  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
+  PrimExpr index = FloorMod(i->var, 2);
+  return Fragment({i, j}, {index}, forward_thread, rep);
+}
 
 /*
 From https://github.com/RadeonOpenCompute/amd_matrix_instruction_calculator
 ./matrix_calculator.py --architecture cdna1 --instruction v_mfma_f32_16x16x16f16
@@ -58,24 +86,6 @@ Fragment makeGemmFragmentC16x16CDNA() {
   return Fragment({i, j}, {index}, forward_thread, rep);
 }
 
-Fragment makeGemmFragment8x8Transposed() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 8);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(i->var, 2) + 4 * j;
-  PrimExpr index = FloorMod(i->var, 2);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
-
-Fragment makeGemmFragment8x16() {
-  IterVar i = make_itervar("i", 8);
-  IterVar j = make_itervar("j", 16);
-  IterVar rep = make_itervar("rep", 1);
-  PrimExpr forward_thread = FloorDiv(j->var, 4) + 4 * i;
-  PrimExpr index = FloorMod(j->var, 4);
-  return Fragment({i, j}, {index}, forward_thread, rep);
-}
 
 Fragment makeGemmFragmentC_F64(const int block_m, const int block_n,
                                const int warp_m, const int warp_n) {
   ICHECK(block_m % warp_m == 0);
@@ -147,8 +157,8 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
   ICHECK(warp_m % 16 == 0);
   ICHECK(block_k % 16 == 0);
-  // Only support 8-bit and 16-bit
-  ICHECK(element_size == 8 || element_size == 16)
-      << "element bitwidth=" << element_size;
+  // Only support 8-bit, 16-bit and 32-bit element types
+  ICHECK(element_size == 8 || element_size == 16 || element_size == 32)
+      << "unsupported element bitwidth=" << element_size;
   if (transposed) {
     auto base_layout =
@@ -173,6 +183,13 @@ Fragment makeGemmFragmentA(const int block_m, const int block_n,
     auto block_layout =
         warp_layout->Repeat({warp_m / 16, block_k / 16}, false, false);
     return block_layout;
+  } else if (element_size == 32) {
+    auto base_layout = makeGemmFragment8x4()->Repeat({2, 2}, false, false);
+    auto warp_layout = base_layout->Repeat({block_m / warp_m, 1}, true)
+                           ->Replicate(block_n / warp_n);
+    auto block_layout =
+        warp_layout->Repeat({warp_m / 16, block_k / 8}, false, false);
+    return block_layout;
   } else {
     ICHECK(0);
     return Fragment();
...
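The `else if (element_size == 32)` branch mirrors the 16-bit path but builds on the new 8x4 base tile and steps K in units of 8 (`block_k / 8`), matching the m16n8k8 shape of the tf32 MMA instruction, which covers K=8 tf32 elements per issue instead of K=16 halves. A small Python sketch of the resulting per-warp mapping, assuming `Repeat({2, 2}, false, false)` tiles the base layout over a 16x8 extent:

```python
from collections import defaultdict

def base_8x4(i, j):
    # the new makeGemmFragment8x4: one element per lane
    return j + 4 * i  # forward_thread; the local index is always 0

# Repeat({2, 2}): tile the 8x4 base over a 16x8 (m16n8k8 A operand) extent
elems = defaultdict(list)
for i in range(16):
    for j in range(8):
        elems[base_8x4(i % 8, j % 4)].append((i, j))

# every lane of the warp holds 4 tf32 values, as in the PTX A-fragment
assert len(elems) == 32 and all(len(v) == 4 for v in elems.values())
```

Each lane `t` ends up with the quartet `(t//4, t%4)`, `(t//4 + 8, t%4)`, `(t//4, t%4 + 4)`, `(t//4 + 8, t%4 + 4)`, which lines up with the a0..a3 register layout documented for `mma.m16n8k8` with tf32 operands.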