Commit e7b97be2 authored by Yu Cheng, committed by LeiWang1999

[Feature] Introduce Persistent Loop and Update GEMM Example (#563)

* [Feature] Added Support for Synchronizing Grids and Persistent Threadblock Transformation

- Defined the sync_grid operation in builtin.cc and builtin.h, allowing synchronization of all threads within a grid.
- Implemented support for sync_grid in codegen_cuda.cc, ensuring proper handling of this operation in the generated CUDA code.
- Added the PersistThreadblock transformation, enabling the conversion of thread blocks to persistent thread blocks, enhancing support for persistent kernels (a distilled sketch of the resulting kernel pattern follows below).
- Updated relevant documentation and comments to reflect the addition of new features and usage instructions.

* [Example] Add MLA Decode With Persistent Threadblock Example

* [Feature] Introduce Persistent Loop and Update GEMM Example

- Added a new persistent loop construct in the TIR framework, enabling more efficient kernel execution.
- Updated the GEMM example to utilize the new persistent primitive, enhancing performance for matrix multiplication.
- Introduced a `loop_break` intrinsic for better control flow within persistent loops.
- Updated relevant files to support the new features, including changes in code generation and language interface.

* lint fix
parent e5e36dbf
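The MLA decode example that follows is the first user of the new pattern. Distilled to its control structure, a persistent kernel launches one block per SM, walks over "waves" of tiles in a serial loop, and uses T.sync_grid() as a grid-wide barrier between a producer phase and a consumer phase. A minimal sketch (not part of the commit; the buffers X, Y, Z, the tile shapes, and the per-tile work are placeholders):

    # Minimal sketch of the persistent-threadblock pattern; illustrative only.
    import tilelang
    import tilelang.language as T
    from tilelang.carver.arch import driver


    def copy_persistent(num_tiles, block_M, block_N, dtype="float16"):
        sm_num = driver.get_num_sms()

        @T.prim_func
        def main(
                X: T.Tensor([num_tiles * block_M, block_N], dtype),
                Y: T.Tensor([num_tiles * block_M, block_N], dtype),
                Z: T.Tensor([num_tiles * block_M, block_N], dtype),
        ):
            with T.Kernel(sm_num, threads=128) as (block_id):  # one block per SM
                buf = T.alloc_shared([block_M, block_N], dtype)
                # Phase 1: block `block_id` handles tiles block_id, block_id + sm_num, ...
                for w in T.serial(T.ceildiv(num_tiles, sm_num)):
                    tile_id = sm_num * w + block_id
                    if tile_id < num_tiles:
                        T.copy(X[tile_id * block_M, 0], buf)
                        T.copy(buf, Y[tile_id * block_M, 0])
                # Grid-wide barrier: all of Y is visible to every block afterwards.
                T.sync_grid()
                # Phase 2: consume what other blocks produced in phase 1.
                for w in T.serial(T.ceildiv(num_tiles, sm_num)):
                    tile_id = sm_num * w + block_id
                    if tile_id < num_tiles:
                        T.copy(Y[tile_id * block_M, 0], buf)
                        T.copy(buf, Z[tile_id * block_M, 0])

        return main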
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from tilelang.carver.arch import driver
from einops import rearrange, einsum
import argparse


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"
    sm_num = driver.get_num_sms()

    @T.prim_func
    def main_split_persistent(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        # Persistent kernel: one threadblock per SM, each walking over waves of tiles.
        with T.Kernel(sm_num, threads=256) as (block_id):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            # O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_local([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                # O_shared: tilelang.layout.make_swizzled_layout(O_shared),
                S_shared: tilelang.layout.make_swizzled_layout(S_shared),
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })
            T.use_swizzle(10)

            # Phase 1: each block computes flash-attention partial results for its tiles.
            total_tiles = batch * (heads // min(block_H, kv_group_num)) * num_split
            waves = T.ceildiv(total_tiles, sm_num)
            for w in T.serial(waves):
                tile_id = sm_num * w + block_id
                bid = tile_id // ((heads // min(block_H, kv_group_num)) * num_split)
                hid = tile_id // num_split % (heads // min(block_H, kv_group_num))
                sid = tile_id % num_split
                cur_kv_head = hid // (kv_group_num // block_H)
                if bid < batch and hid * VALID_BLOCK_H < heads and sid < num_split:
                    T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
                    T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
                    T.fill(acc_o, 0)
                    T.fill(logsum, 0)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
                    for k in T.Pipelined(loop_range, num_stages=2):
                        kv_start = (seqlen_kv // num_split) * sid + k * block_N
                        kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
                        T.copy(KV[bid, kv_start:kv_end, cur_kv_head, :], KV_shared)
                        T.copy(K_pe[bid, kv_start:kv_end, cur_kv_head, :], K_pe_shared)
                        T.clear(acc_s)
                        T.gemm(
                            Q_shared,
                            KV_shared,
                            acc_s,
                            transpose_B=True,
                            policy=T.GemmWarpPolicy.FullCol)
                        T.gemm(
                            Q_pe_shared,
                            K_pe_shared,
                            acc_s,
                            transpose_B=True,
                            policy=T.GemmWarpPolicy.FullCol)
                        # Online softmax rescaling.
                        T.copy(scores_max, scores_max_prev)
                        T.fill(scores_max, -T.infinity(accum_dtype))
                        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                        for i in T.Parallel(block_H):
                            scores_scale[i] = T.exp2(scores_max_prev[i] * scale -
                                                     scores_max[i] * scale)
                        for i, j in T.Parallel(block_H, block_N):
                            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                        T.reduce_sum(acc_s, scores_sum, dim=1)
                        T.copy(acc_s, S_shared)
                        T.copy(S_shared, acc_s_cast)
                        for i in T.Parallel(block_H):
                            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                        for i, j in T.Parallel(block_H, dim):
                            acc_o[i, j] *= scores_scale[i]
                        T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
                    for i, j in T.Parallel(block_H, dim):
                        acc_o[i, j] /= logsum[i]
                    for i in T.Parallel(block_H):
                        logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
                    T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
                    # T.copy(acc_o, O_shared)
                    T.copy(
                        acc_o, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                              sid, :])

            # Grid-wide barrier: every split's partial output and LSE must be visible
            # to all blocks before the combine phase starts.
            T.sync_grid()

            # Phase 2: combine the num_split partial results for each (batch, head).
            waves = T.ceildiv(heads * batch, sm_num)
            for w in T.serial(waves):
                tile_id = sm_num * w + block_id
                hid = tile_id // batch
                bid = tile_id % batch
                if bid < batch and hid < heads:
                    T.clear(lse_logsum_local)
                    T.clear(o_accum_local)
                    lse_max_local[0] = -T.infinity(accum_dtype)
                    for k in T.serial(num_split):
                        lse_max_local[0] = T.max(lse_max_local[0], glse[bid, hid, k])
                    for k in T.Pipelined(num_split, num_stages=1):
                        lse_local_split[0] = glse[bid, hid, k]
                        lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
                    lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
                    for k in T.serial(num_split):
                        for i in T.Parallel(dim):
                            po_local[i] = Output_partial[bid, hid, k, i]
                        lse_local_split[0] = glse[bid, hid, k]
                        scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                        for i in T.Parallel(dim):
                            o_accum_local[i] += po_local[i] * scale_local[0]
                    for i in T.Parallel(dim):
                        Output[bid, hid, i] = o_accum_local[i]

    return main_split_persistent


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    # """
    # Inputs:
    # - q (Tensor): [batch, heads, dim]
    # - q_pe (Tensor): [batch, heads, pe_dim]
    # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    # - glse (Tensor): [batch, heads, num_split]
    # - Output_partial (Tensor): [batch, heads, num_split, dim]
    # Outputs:
    # - output (Tensor): [batch, heads, dim]
    # """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)
    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim

    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops

    BLOCK_N = 64
    BLOCK_H = 64
    num_split = 2

    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    kernel = tilelang.compile(program, out_idx=[6])
    print(kernel.get_kernel_source())
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")


if __name__ == "__main__":
    main()
@@ -50,7 +50,8 @@ def matmul_persistent(M,
                       threads,
                       num_stages,
                       dtype="float16",
-                      accum_dtype="float"):
+                      accum_dtype="float",
+                      use_persistent_primitive=True):
     sm_num = driver.get_num_sms()
     m_blocks = T.ceildiv(M, block_M)
@@ -85,7 +86,30 @@
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[bx * block_M, by * block_N])
 
-    return main
+    @T.prim_func
+    def main_persistent_primitive(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N), dtype),
+            C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(sm_num, threads=threads) as (block_id):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
+            for bx, by in T.Persistent(
+                    [T.ceildiv(M, block_M), T.ceildiv(N, block_N)], sm_num, block_id):
+                T.clear(C_local)
+                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                    T.copy(A[bx * block_M, k * block_K], A_shared)
+                    T.copy(B[k * block_K, by * block_N], B_shared)
+                    T.gemm(A_shared, B_shared, C_local)
+                T.copy(C_local, C_shared)
+                T.copy(C_shared, C[bx * block_M, by * block_N])
+
+    return main_persistent_primitive if use_persistent_primitive else main
 
 def ref_program(A, B):
...
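A hedged driver sketch for the updated GEMM example (not part of the commit). The keyword names mirror identifiers used inside the hunk above, but the full parameter list of matmul_persistent is elided in the diff, so treat the call signature and the import path as assumptions:

    # Hypothetical driver; the module path and parameter names beyond the visible hunk are assumed.
    import torch
    import tilelang
    from gemm_persistent_example import matmul_persistent  # hypothetical module path

    program = matmul_persistent(
        M=4096, N=4096, K=4096,
        block_M=128, block_N=128, block_K=64,
        threads=256, num_stages=3,
        use_persistent_primitive=True)               # select the T.Persistent code path
    kernel = tilelang.compile(program, out_idx=[2])  # C is returned as the output

    a = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    b = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
    c = kernel(a, b)
    torch.testing.assert_close(c, (a.float() @ b.float()).half(), rtol=1e-2, atol=1e-2)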
@@ -5,6 +5,7 @@
  */
 #include "./transform/common/attr.h"
+#include "op/builtin.h"
 #include <tvm/arith/analyzer.h>
 #include <tvm/script/ir_builder/tir/ir.h>
@@ -104,6 +105,70 @@ ForFrame PipelinedFor(PrimExpr start, PrimExpr stop, int num_stages,
   return ForFrame(n);
 }
 
+ForFrame PersistentFor(Array<PrimExpr> domain, PrimExpr wave_size,
+                       PrimExpr index, PrimExpr group_size) {
+  using namespace tvm::tir;
+  ICHECK(domain.size() > 0);
+  ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
+  n->vars.reserve(domain.size());
+  n->doms.reserve(domain.size());
+  PrimExpr domain_size = domain[0];
+  for (int i = 1; i < domain.size(); i++) {
+    domain_size *= domain[i];
+  }
+  auto waves = ceildiv(domain_size, wave_size);
+  auto loop_var = Var("w", waves.dtype());
+  group_size = min(group_size, domain[domain.size() - 1]);
+  Array<Var> coord_vars;
+  for (int i = 0; i < domain.size(); ++i) {
+    DataType dtype = domain[i].dtype();
+    Var coord("v" + std::to_string(i), dtype);
+    coord_vars.push_back(coord);
+    n->vars.push_back(coord);
+    n->doms.push_back(Range(make_const(dtype, 0), domain[i]));
+  }
+  Array<PrimExpr> grouped_domain;
+  grouped_domain.push_back(truncdiv(domain[domain.size() - 1], group_size));
+  for (int i = 0; i < domain.size() - 1; ++i) {
+    grouped_domain.push_back(domain[i]);
+  }
+  grouped_domain.push_back(group_size);
+  n->f_make_for_loop = [=](Array<Var> vars, Array<Range> doms,
+                           Stmt body) -> Stmt {
+    ICHECK_EQ(vars.size(), doms.size());
+    Map<String, ObjectRef> anno;
+    Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
+    PrimExpr rem = loop_var * wave_size + index;
+    for (int i = grouped_domain.size() - 1; i >= 1; --i) {
+      idxs.Set(i, truncmod(rem, grouped_domain[i]));
+      rem = truncdiv(rem, grouped_domain[i]);
+    }
+    idxs.Set(0, rem);
+    auto out_if = tvm::tir::IfThenElse(
+        domain_size <= (loop_var * wave_size + index),
+        tvm::tir::Evaluate(
+            tvm::tir::Call(DataType::Handle(), tvm::tl::loop_break(), {})),
+        Stmt());
+    Stmt outer = For(loop_var, 0, waves, ForKind::kSerial,
+                     SeqStmt({out_if, body}), NullOpt, anno);
+    for (int i = 0; i < vars.size() - 1; ++i) {
+      outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
+    }
+    outer = tvm::tir::LetStmt(vars[vars.size() - 1],
+                              idxs[0] * group_size + idxs[vars.size()], outer);
+    return outer;
+  };
+  return ForFrame(n);
+}
+
 /*!
  * \brief A frame that represents a kernel launch.
  *
@@ -202,11 +267,11 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
   }
   if (attrs.defined()) {
-    auto empty_block = Block(MainBlockName);
+    auto empty_block = tvm::script::ir_builder::tir::Block(MainBlockName);
     empty_block->annotations = attrs;
     n->frames.push_back(empty_block);
   } else {
-    n->frames.push_back(Block(MainBlockName));
+    n->frames.push_back(tvm::script::ir_builder::tir::Block(MainBlockName));
   }
   return KernelLaunchFrame(n);
@@ -216,6 +281,7 @@ TVM_REGISTER_NODE_TYPE(KernelLaunchFrameNode);
 TVM_REGISTER_GLOBAL("tl.Parallel").set_body_typed(ParallelFor);
 TVM_REGISTER_GLOBAL("tl.Pipelined").set_body_typed(PipelinedFor);
+TVM_REGISTER_GLOBAL("tl.Persistent").set_body_typed(PersistentFor);
 TVM_REGISTER_GLOBAL("tl.KernelLaunch").set_body_typed(KernelLaunch);
 class WarpSpecializeFrameNode : public TIRFrameNode {
...
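For readers tracing the index arithmetic: PersistentFor regroups the last domain axis into chunks of group_size, makes the group id the slowest-varying coordinate, and decomposes the flat tile id w * wave_size + index over that regrouped domain, breaking out of the wave loop once the id runs past the domain. A pure-Python sketch of the schedule it lowers to (illustrative only, not part of the commit):

    from math import prod


    def persistent_tiles(domain, wave_size, index, group_size):
        """Yield the tile coordinates visited by one block, mirroring PersistentFor above."""
        n = len(domain)
        domain_size = prod(domain)
        waves = -(-domain_size // wave_size)          # ceildiv
        group_size = min(group_size, domain[-1])
        grouped = [domain[-1] // group_size] + list(domain[:-1]) + [group_size]

        for w in range(waves):
            linear = w * wave_size + index
            if linear >= domain_size:                 # lowered to tl.loop_break()
                break
            rem, idxs = linear, [0] * len(grouped)
            for i in range(len(grouped) - 1, 0, -1):  # last axis varies fastest
                idxs[i] = rem % grouped[i]
                rem //= grouped[i]
            idxs[0] = rem
            coords = [idxs[i + 1] for i in range(n - 1)]
            coords.append(idxs[0] * group_size + idxs[n])
            yield tuple(coords)


    # e.g. which (bx, by) tiles block 0 of 4 visits over an 8x8 tile grid:
    # print(list(persistent_tiles([8, 8], wave_size=4, index=0, group_size=2)))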
@@ -119,5 +119,13 @@ TIR_DEFINE_TL_BUILTIN(wait_wgmma)
 TIR_DEFINE_TL_BUILTIN(pack_b16).set_num_inputs(2).set_attr<TCallEffectKind>(
     "TCallEffectKind", Integer(CallEffectKind::kPure));
 
+TIR_DEFINE_TL_BUILTIN(sync_grid).set_num_inputs(0).set_attr<TCallEffectKind>(
+    "TCallEffectKind", Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(loop_break)
+    .set_num_inputs(0)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 } // namespace tl
 } // namespace tvm
\ No newline at end of file
@@ -208,6 +208,22 @@ const Op &no_set_max_nreg();
  */
 const Op &wait_wgmma();
 
+/*!
+ * \brief Synchronize all threads in a grid
+ *
+ * sync_grid()
+ *
+ */
+const Op &sync_grid();
+
+/*!
+ * \brief tvm intrinsic for breaking out of a loop
+ *
+ * loop_break()
+ *
+ */
+const Op &loop_break();
+
 /*!
  * \brief tvm intrinsic for amd matrix core mfma instructions.
  *
...
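Both ops are registered under the tl. namespace, so they can also be constructed directly as TIR intrinsic calls from Python (a small sketch; in practice T.sync_grid() wraps the first call and tl.loop_break is only emitted by T.Persistent itself):

    # Sketch: constructing the new intrinsics as TIR calls. Importing tilelang
    # loads the extension library that registers the tl.* ops.
    import tilelang  # noqa: F401
    from tvm import tir

    sync = tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))   # lowers to grid.sync()
    brk = tir.call_intrin("handle", tir.op.Op.get("tl.loop_break"))   # lowers to break;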
@@ -119,6 +119,10 @@ std::string CodeGenTileLangCUDA::Finish() {
     decl_stream << "#include <math_constants.h>\n";
   }
 
+  if (need_cooperative_groups_) {
+    decl_stream << "#include <cooperative_groups.h>\n";
+  }
+
   decl_stream << "#include <tl_templates/cuda/gemm.h>\n";
   decl_stream << "#include <tl_templates/cuda/copy.h>\n";
   decl_stream << "#include <tl_templates/cuda/reduce.h>\n";
@@ -891,6 +895,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
   } else if (op->op.same_as(tl::pack_b16())) {
     os << "__pack_half2(" << this->PrintExpr(op->args[0]) << ", "
        << this->PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::sync_grid())) {
+    this->need_cooperative_groups_ = true;
+    this->PrintIndent();
+    this->stream << "cooperative_groups::grid_group grid = "
+                    "cooperative_groups::this_grid();\n";
+    this->PrintIndent();
+    this->stream << "grid.sync();\n";
+  } else if (op->op.same_as(tl::loop_break())) {
+    this->PrintIndent();
+    this->stream << "break;\n";
   } else if (op->op.same_as(builtin::tvm_fill_fragment())) {
     need_mma_h_ = true;
     ICHECK_EQ(op->args.size(), 6U);
...
@@ -96,6 +96,8 @@ private:
   bool need_mma_h_{false};
   // whether need cast_smem_ptr_to_int helper function
   bool need_cast_smem_ptr_to_int_{false};
+  // whether need cooperative_groups.h
+  bool need_cooperative_groups_{false};
   // Op attribute map
   OpAttrMap<bool> op_need_warp_shuffle_ =
       Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
...
/*!
 * \file lower_l2_persistent_annotation.cc
 * \brief Transform thread blocks into persistent thread blocks
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../op/builtin.h"
#include "../op/bulk_copy.h"
#include "../runtime/runtime.h"

namespace tvm {
namespace tl {

namespace attr {
// PrimFunc attribute marking kernels that must be launched via cooperative groups
constexpr const char *kUseCooperativeGroups = "use_cooperative_groups";
} // namespace attr

using namespace tir;

class PersistThreadblock : public StmtExprMutator {
public:
  static PrimFunc Substitute(PrimFunc &f) {
    PrimFuncNode *fptr = f.CopyOnWrite();
    PersistThreadblock substituter;
    // Scan the kernel body for tl.sync_grid calls
    fptr->body = substituter.VisitStmt(f->body);
    if (substituter.has_sync_grid_) {
      f = WithAttr(std::move(f), attr::kUseCooperativeGroups,
                   IntImm(DataType::Int(32), 1));
    }
    return f;
  }

  Stmt VisitStmt_(const EvaluateNode *op) final {
    if (const auto *call = op->value.as<CallNode>()) {
      if (call->op.same_as(sync_grid())) {
        has_sync_grid_ = true;
      }
    }
    return StmtExprMutator::VisitStmt_(op);
  }

private:
  PersistThreadblock() = default;

  bool has_sync_grid_ = false;
};

using namespace tir::transform;

tvm::transform::Pass PersistThreadblock() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return PersistThreadblock::Substitute(f);
  };
  return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
}

TVM_REGISTER_GLOBAL("tl.transform.PersistThreadblock")
    .set_body_typed(PersistThreadblock);

} // namespace tl
} // namespace tvm
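The pass itself only marks functions; the wrapper changes further below read the use_cooperative_groups attribute to switch the launch to cudaLaunchCooperativeKernel. A hedged sketch of inspecting that attribute after the pass has run (the IRModule is assumed to come out of the lowering pipeline shown in the next hunk):

    # Sketch, assuming `mod` is an IRModule produced by the lowering pipeline.
    import tilelang


    def cooperative_kernels(mod):
        """Return the names of kernels tagged by PersistThreadblock."""
        mod = tilelang.transform.PersistThreadblock()(mod)
        return [
            gvar.name_hint for gvar, func in mod.functions.items()
            if func.attrs is not None and "use_cooperative_groups" in func.attrs
        ]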
@@ -104,6 +104,9 @@ public:
         role = Role::kProducer;
         has_bulk_copy_ = true;
       }
+      if (call->op.same_as(loop_break())) {
+        role = Role::kBoth;
+      }
     }
     SetRole(op, role);
   }
...
@@ -165,5 +165,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
     mod = tilelang.transform.MakePackedAPI()(mod)
     mod = tir.transform.LowerDeviceKernelLaunch()(mod)
+    # Transform threadblock to persistent threadblock
+    mod = tilelang.transform.PersistThreadblock()(mod)
 
     return mod
@@ -309,7 +309,6 @@ class TLCUDASourceWrapper(object):
         # Identify the start of the function body to insert arguments
         index = code.index("{", index)
 
-        call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
         block_str = "dim3({}, {}, {})".format(
             legalize_c(block_info[0]),
@@ -321,9 +320,20 @@
         smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
         init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
         kernel_launch_code += init_l2_persistent_map
-        kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
-            function_name, grid_str, block_str, smem_str, call_args)
-        kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
+
+        if self.use_cooperative_groups[function_name]:
+            args_list = func_call_args(declaration, function_args, desc_name_map)
+            args_array = [f"(void*)&{arg}" for arg in args_list]
+            call_args = f"\tvoid* {function_name}_args[] = {{{', '.join(args_array)}}};\n"
+            kernel_launch_code += call_args
+            # Using cudaLaunchCooperativeKernel to launch the kernel
+            kernel_launch_code += "\tTILELANG_CHECK(cudaLaunchCooperativeKernel((void*){}, {}, {}, {}, {}, stream));\n".format(
+                function_name, grid_str, block_str, function_name + "_args", smem_str)
+        else:
+            call_args = ", ".join(func_call_args(declaration, function_args, desc_name_map))
+            kernel_launch_code += "\t{}<<<{}, {}, {}, stream>>>({});\n".format(
+                function_name, grid_str, block_str, smem_str, call_args)
+        kernel_launch_code += "\tTILELANG_CHECK_LAST_ERROR(\"{}\");\n".format(function_name)
         if has_l2_persistent_map:
             kernel_launch_code += L2_PERSISTENT_MAP_RESET_HANDLE
@@ -427,6 +437,7 @@
         grid_info_map = {}
         dynamic_smem_buf_map = {}
         function_names = []
+        use_cooperative_groups_map = {}
         for g_var, func in self.device_mod.functions.items():
             # Default block and grid configurations
             block_info = [1, 1, 1]
@@ -434,6 +445,9 @@
             function_name = g_var.name_hint
             attrs = func.attrs
             dynamic_smem_buf = None
+            use_cooperative_groups = False
+            if "use_cooperative_groups" in attrs:
+                use_cooperative_groups = attrs["use_cooperative_groups"]
             if "dyn_shared_memory_buf" in attrs:
                 dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"])
             if "thread_extent" in attrs:
@@ -448,12 +462,14 @@
             block_info_map[function_name] = block_info
             grid_info_map[function_name] = grid_info
             dynamic_smem_buf_map[function_name] = dynamic_smem_buf
+            use_cooperative_groups_map[function_name] = use_cooperative_groups
             function_names.append(function_name)
 
         # Store the mappings for use in code generation
         self.block_info = block_info_map
         self.grid_info = grid_info_map
         self.dynamic_smem_buf = dynamic_smem_buf_map
+        self.use_cooperative_groups = use_cooperative_groups_map
 
         function_names_index = {}
         for _, func in self.host_mod.functions.items():
...
@@ -23,6 +23,7 @@ from .proxy import (
 )
 from .parallel import Parallel  # noqa: F401
 from .pipeline import Pipelined  # noqa: F401
+from .persistent import Persistent  # noqa: F401
 from .frame import has_let_value, get_let_value  # noqa: F401
 from .kernel import (
     Kernel,  # noqa: F401
...
@@ -322,3 +322,9 @@ def sync_global():
     print(tx, ty, tz, ex, ey, ez)
     args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez]
     return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))
+
+
+def sync_grid():
+    """Synchronize all threads in a grid.
+    """
+    return tir.call_intrin("handle", tir.op.Op.get("tl.sync_grid"))
"""The language interface for tl programs."""
from typing import List, Optional
from tvm import tir
from tilelang import _ffi_api
def Persistent(
domain: List[tir.PrimExpr],
wave_size: tir.PrimExpr,
index: tir.PrimExpr,
group_size: Optional[tir.PrimExpr] = 8,
):
"""Tools to construct persistent for loop.
Parameters
----------
domain : List[tir.PrimExpr]
The list of dominators.
wave_size : int
The wave size.
index : int
The tile index in one wave.
group_size : tir.PrimExpr
The group size.
"""
return _ffi_api.Persistent(domain, wave_size, index, group_size)
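A minimal, self-contained sketch of the primitive in use (not part of the commit; the tiled-copy kernel and its shapes are placeholders). `domain` is the 2-D grid of output tiles, `wave_size` is the number of launched blocks (here one per SM), `index` is this block's id, and `group_size` swizzles the last domain axis so that consecutive tile ids stay within a band of group_size tiles:

    # Placeholder kernel showing the T.Persistent parameters in context.
    import tilelang
    import tilelang.language as T
    from tilelang.carver.arch import driver


    def tiled_copy(M, N, block_M=64, block_N=64, dtype="float16"):
        sm_num = driver.get_num_sms()

        @T.prim_func
        def main(
                A: T.Tensor((M, N), dtype),
                B: T.Tensor((M, N), dtype),
        ):
            with T.Kernel(sm_num, threads=128) as (block_id):
                buf = T.alloc_shared((block_M, block_N), dtype)
                for bx, by in T.Persistent(
                        [T.ceildiv(M, block_M), T.ceildiv(N, block_N)],  # domain of tiles
                        sm_num,                                          # wave_size
                        block_id):                                       # index
                    T.copy(A[bx * block_M, by * block_N], buf)
                    T.copy(buf, B[bx * block_M, by * block_N])

        return main


    # program = tiled_copy(4096, 4096)
    # kernel = tilelang.compile(program, out_idx=[1])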
@@ -348,3 +348,9 @@ def LowerL2Persistent():
     """LowerL2Persistent
     """
     return _ffi_api.LowerL2Persistent()  # type: ignore
+
+
+def PersistThreadblock():
+    """Transform thread blocks into persistent thread blocks.
+    """
+    return _ffi_api.PersistThreadblock()  # type: ignore
\ No newline at end of file