"...composable_kernel_rocm.git" did not exist on "f42ebfdf0b35cf5437ebb3743d81e5b9b64705d9"
Commit 225aca61 authored by Yu Cheng, committed by LeiWang1999

[Feature] Support persistent kernels and add persistent GEMM examples (#559)

* [Enhancement] Fix multi-version buffer index in nested-loop

* [Feature] Support persistent kernels and add persistent GEMM example

* lint fix

* lint fix

* [CI] Remove test_tilelang_transform_annotate_device_regions.py
parent 8cc8db52
import tilelang
import tilelang.language as T
from tilelang.carver.arch import driver
import argparse


def matmul_non_persistent(M,
                          N,
                          K,
                          block_M,
                          block_N,
                          block_K,
                          threads,
                          num_stages,
                          dtype="float16",
                          accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            T.use_swizzle(10)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[bx * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, by * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[bx * block_M, by * block_N])

    return main
def matmul_persistent(M,
                      N,
                      K,
                      block_M,
                      block_N,
                      block_K,
                      threads,
                      num_stages,
                      dtype="float16",
                      accum_dtype="float"):
    sm_num = driver.get_num_sms()
    m_blocks = T.ceildiv(M, block_M)
    n_blocks = T.ceildiv(N, block_N)
    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
    group_size = 8

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # Launch exactly one persistent block per SM; each block iterates over
        # `waves` output tiles instead of exiting after a single tile.
        with T.Kernel(sm_num, threads=threads) as (block_id):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            for w in T.serial(waves):
                tile_id = sm_num * w + block_id
                # Rasterized tile mapping: `group_size` consecutive tile ids share
                # the same row block bx and cover adjacent column blocks by.
                bx = (tile_id // group_size) % m_blocks
                by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
                # Guard against tile ids that fall outside the output grid.
                if bx * block_M < M and by * block_N < N:
                    T.clear(C_local)
                    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                        T.copy(A[bx * block_M, k * block_K], A_shared)
                        T.copy(B[k * block_K, by * block_N], B_shared)
                        T.gemm(A_shared, B_shared, C_local)
                    T.copy(C_local, C_shared)
                    T.copy(C_shared, C[bx * block_M, by * block_N])

    return main
def ref_program(A, B):
    return A @ B


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--M', type=int, default=8192, help='M dimension')
    parser.add_argument('--N', type=int, default=8192, help='N dimension')
    parser.add_argument('--K', type=int, default=8192, help='K dimension')
    args = parser.parse_args()
    M, N, K = args.M, args.N, args.K
    total_flops = 2 * M * N * K

    BLOCK_M = 128
    BLOCK_N = 256
    BLOCK_K = 64
    threads = 256
    num_stages = 3

    persistent_program = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages)
    persistent_kernel = tilelang.compile(persistent_program, out_idx=-1)
    persistent_profiler = persistent_kernel.get_profiler(
        tensor_supply_type=tilelang.TensorSupplyType.Randn)
    persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Persistent GEMM: All checks passed.")
    persistent_latency = persistent_profiler.do_bench(warmup=500)
    print(f"Persistent GEMM Latency: {persistent_latency} ms")
    print(f"Persistent GEMM TFlops: {total_flops / persistent_latency * 1e-9} TFlops")

    non_persistent_program = matmul_non_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads,
                                                   num_stages)
    non_persistent_kernel = tilelang.compile(non_persistent_program, out_idx=-1)
    non_persistent_profiler = non_persistent_kernel.get_profiler(
        tensor_supply_type=tilelang.TensorSupplyType.Randn)
    non_persistent_profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("Non-Persistent GEMM: All checks passed.")
    non_persistent_latency = non_persistent_profiler.do_bench(warmup=500)
    print(f"Non-Persistent GEMM Latency: {non_persistent_latency} ms")
    print(f"Non-Persistent GEMM TFlops: {total_flops / non_persistent_latency * 1e-9} TFlops")
    print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}")


if __name__ == "__main__":
    main()
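A note on the tile scheduling in matmul_persistent: each of the sm_num persistent blocks walks through `waves` tile IDs, and the group_size arithmetic rasterizes consecutive IDs so that eight neighbouring IDs share the same row block bx. The mapping can be replayed on the host. The sketch below is plain Python with no tilelang dependency; the sm_num of 132 is an assumed, illustrative value, and the check holds for the default 8192x8192 problem, where n_blocks is a multiple of group_size.

from collections import Counter
from math import ceil

# Illustrative values only: sm_num is assumed here, not queried from a GPU.
sm_num, group_size = 132, 8
m_blocks, n_blocks = 64, 32  # M = N = 8192 with block_M = 128, block_N = 256
waves = ceil(m_blocks * n_blocks / sm_num)

counts = Counter()
for block_id in range(sm_num):        # one persistent block per SM
    for w in range(waves):            # each block processes `waves` tile ids
        tile_id = sm_num * w + block_id
        bx = (tile_id // group_size) % m_blocks
        by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size
        if bx < m_blocks and by < n_blocks:   # same guard as the kernel (full tiles here)
            counts[(bx, by)] += 1

# Every tile of the 64 x 32 output grid is produced exactly once.
assert len(counts) == m_blocks * n_blocks
assert max(counts.values()) == 1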
@@ -218,9 +218,13 @@ private:
   }
   Stmt VisitStmt_(const ForNode *op) final {
+    loop_stack_.emplace_back(op->loop_var, op->extent);
     auto num_stages_anno = op->annotations.Get("num_stages");
-    if (!num_stages_anno.defined())
-      return StmtExprMutator::VisitStmt_(op);
+    if (!num_stages_anno.defined()) {
+      auto for_node = StmtExprMutator::VisitStmt_(op);
+      loop_stack_.pop_back();
+      return for_node;
+    }
     ICHECK(num_stages_anno.as<IntImmNode>());
     int num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
@@ -244,8 +248,14 @@ private:
       Buffer new_buffer = RewriteAllocBuffer(buffer, num_stages);
       buffer_remap_.Set(buffer, new_buffer);
     }
-    version_index_ = FloorMod(op->loop_var - op->min, num_stages);
+    PrimExpr linear_index = loop_stack_[0].first;
+    for (size_t i = 1; i < loop_stack_.size(); ++i) {
+      linear_index =
+          linear_index * loop_stack_[i].second + loop_stack_[i].first;
+    }
+    version_index_ = FloorMod(linear_index, num_stages);
     auto for_node = StmtExprMutator::VisitStmt_(op);
+    loop_stack_.pop_back();
     return for_node;
   }
@@ -315,6 +325,7 @@ private:
   }
   PrimExpr version_index_;
+  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
   Map<Var, Buffer> buffer_data_to_buffer_;
   Map<Buffer, Optional<Stmt>> buffer_lca_;
   Map<Buffer, Buffer> buffer_remap_;
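The hunks above implement the first bullet of the commit message: previously the multi-version buffer index was derived from the pipelined loop's own variable only, so nesting the pipelined loop inside an outer loop, as the persistent GEMM does with its wave loop, restarted the version index at 0 on every outer iteration. Linearizing the whole loop_stack_ keeps the versions rotating across iterations. A small numeric sketch (plain Python, made-up extents) illustrates the difference:

num_stages = 3
outer_extent, inner_extent = 2, 4   # e.g. a wave loop wrapped around a pipelined K loop

old, new = [], []
for w in range(outer_extent):
    for k in range(inner_extent):
        old.append(k % num_stages)        # index from the inner loop variable only
        linear = w * inner_extent + k     # linearized over the loop stack
        new.append(linear % num_stages)

print(old)  # [0, 1, 2, 0, 0, 1, 2, 0] -> version 0 reused back-to-back at the wave boundary
print(new)  # [0, 1, 2, 0, 1, 2, 0, 1] -> versions keep cycling across waves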
@@ -629,6 +629,7 @@ private:
       num_stages = static_cast<int>(num_stages_anno.as<IntImmNode>()->value);
       ICHECK(num_stages_ == 1) << "Nested pipeline not supported.";
     }
+    loop_stack_.emplace_back(op->loop_var, op->extent);
     Array<Array<Integer>> group_info_array;
     Array<Integer> order_info_array;
@@ -661,10 +662,14 @@ private:
     num_stages_ = num_stages;
     pipeline_info_ = pipeline_info;
-    stage_ = FloorMod(op->loop_var - op->min, num_stages);
-    parity_ = FloorMod(parity_before * op->extent +
-                           FloorDiv(op->loop_var - op->min, num_stages),
-                       2);
+    PrimExpr linear_index = loop_stack_[0].first;
+    for (size_t i = 1; i < loop_stack_.size(); ++i) {
+      linear_index =
+          linear_index * loop_stack_[i].second + loop_stack_[i].first;
+    }
+    stage_ = FloorMod(linear_index, num_stages);
+    parity_ = FloorMod(
+        parity_before * op->extent + FloorDiv(linear_index, num_stages), 2);
     auto result = FilterByRole(op);
@@ -692,10 +697,13 @@ private:
       }
       if (is_emitting_producer_ || !group_anno.defined() ||
           group_info_array.size() == 0) {
+        loop_stack_.pop_back();
         return for_node;
       }
+      loop_stack_.pop_back();
       return grouped_for_node;
     }
+    loop_stack_.pop_back();
     return result;
   }
@@ -908,6 +916,7 @@ private:
   PrimExpr parity_ = 0;
   PrimExpr stage_ = 0;
   int num_stages_ = 1;
+  std::vector<std::pair<Var, PrimExpr>> loop_stack_;
   Var thread_var_;
   bool mbarrier_only_ = false;
   PipelineInfo pipeline_info_;
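The same linearization feeds the stage and mbarrier parity expressions above. Purely to see what the new expressions evaluate to, the sketch below (plain Python, made-up extents, parity_before = 0) prints stage_ and parity_ for a 2x4 nest with num_stages = 3; the parity flips each time the stage index wraps, now counted across the whole nest rather than per inner loop.

num_stages, parity_before = 3, 0
outer_extent, inner_extent = 2, 4

for w in range(outer_extent):
    for k in range(inner_extent):
        linear = w * inner_extent + k
        stage = linear % num_stages                                    # FloorMod(linear_index, num_stages)
        parity = (parity_before * inner_extent + linear // num_stages) % 2
        print(f"w={w} k={k} stage={stage} parity={parity}")
# stage : 0 1 2 0 1 2 0 1
# parity: 0 0 0 1 1 1 0 0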
import tilelang
import tilelang.testing
from tvm.script import tir as T


class BaseCompare(tilelang.testing.CompareBeforeAfter):
    transform = tilelang.transform.AnnotateDeviceRegions()


class TestAnnotateThreadExtent(BaseCompare):
    """Annotation inserted at the "thread_extent" attribute"""

    def before(A: T.Buffer(16, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        i = T.launch_thread("threadIdx.x", 16)
        A[i] = 0.0

    def expected(A: T.Buffer(16, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(T.target("cuda"), "target", 0)
        i = T.launch_thread("threadIdx.x", 16)
        A[i] = 0.0


class TestAnnotateDeviceScope(BaseCompare):
    """Annotation inserted at the "device_scope" attribute"""

    def before(A: T.Buffer(1, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(0, "device_scope", 0)
        A[0] = 0.0

    def expected(A: T.Buffer(1, "float32")):
        T.func_attr({"target": T.target("cuda", host="llvm")})
        T.attr(T.target("cuda"), "target", 0)
        T.attr(0, "device_scope", 0)
        A[0] = 0.0


if __name__ == "__main__":
    tilelang.testing.main()