Unverified Commit bd1c7b39 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Refactor] Use `has_simt_copy` to decide whether to insert `set_max_nreg` (#982)

parent 8f001e02
......@@ -136,8 +136,6 @@ def mqa_attn_return_logits(
cu_k_s_min = T.alloc_local([1], index_dtype)
cu_k_e_max = T.alloc_local([1], index_dtype)
T.no_set_max_nreg()
cu_k_s_min[0] = 2147483647
cu_k_e_max[0] = -2147483648
......
......@@ -59,6 +59,27 @@ private:
bool warp_specialized_ = false;
};
// Visitor that scans a statement for BufferStore nodes whose destination
// buffer lives in a storage scope other than "global". Any such store is
// treated as a SIMT copy.
class SimtCopyDetector : public StmtExprVisitor {
public:
  // Traverses `stmt` and returns true when at least one BufferStore to a
  // non-"global" buffer was encountered, false otherwise.
  static bool Detect(const Stmt &stmt) {
    SimtCopyDetector visitor;
    visitor.VisitStmt(stmt);
    return visitor.has_simt_copy_;
  }

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    // Resolve the storage scope of the buffer being written to and compare
    // its string form against "global".
    const auto scope_name =
        runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data))
            .to_string();
    // The flag is sticky: once raised it is never cleared again.
    has_simt_copy_ = has_simt_copy_ || (scope_name != "global");
    // Keep descending so the default traversal of the store still happens.
    StmtExprVisitor::VisitStmt_(op);
  }

  // Set as soon as a store to a non-"global" buffer has been visited.
  bool has_simt_copy_ = false;
};
class SetMaxNRegInjector : public StmtExprMutator {
public:
static PrimFunc Inject(PrimFunc f) {
......@@ -113,9 +134,7 @@ private:
auto dec_reg_stmt = Evaluate(0);
// Only inject if we have valid register hints and no SIMT copy
// For now, we assume no SIMT copy detection is available here
// TODO: Add SIMT copy detection if needed
bool has_simt_copy = false; // Placeholder
bool has_simt_copy = SimtCopyDetector::Detect(producer_body);
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
auto inc_reg_num =
......
......@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectTmaBarrier()(mod)
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
# if tma is not enabled, we can also do pipeline planning
# to get better performance with async copy
mod = tilelang.transform.PipelinePlanning()(mod)
......@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment