Unverified Commit bd1c7b39 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Refactor] Use `has_simt_copy` to decide whether to insert `set_max_nreg` (#982)

parent 8f001e02
...@@ -136,8 +136,6 @@ def mqa_attn_return_logits( ...@@ -136,8 +136,6 @@ def mqa_attn_return_logits(
cu_k_s_min = T.alloc_local([1], index_dtype) cu_k_s_min = T.alloc_local([1], index_dtype)
cu_k_e_max = T.alloc_local([1], index_dtype) cu_k_e_max = T.alloc_local([1], index_dtype)
T.no_set_max_nreg()
cu_k_s_min[0] = 2147483647 cu_k_s_min[0] = 2147483647
cu_k_e_max[0] = -2147483648 cu_k_e_max[0] = -2147483648
......
...@@ -59,6 +59,27 @@ private: ...@@ -59,6 +59,27 @@ private:
bool warp_specialized_ = false; bool warp_specialized_ = false;
}; };
/*!
 * \brief Visitor that reports whether a statement contains at least one
 *        BufferStore whose destination buffer lives outside the plain
 *        "global" storage scope.
 *
 * NOTE(review): such a store is presumably what the caller means by a
 * "SIMT copy" (a thread-level write into a non-global buffer) — confirm
 * against the pass that consumes this result.
 */
class SimtCopyDetector : public StmtExprVisitor {
public:
  /*!
   * \brief Walk \p stmt and tell whether any qualifying store was seen.
   * \param stmt The statement tree to inspect.
   * \return true if some BufferStore targets a buffer whose storage scope
   *         string differs from "global"; false otherwise.
   */
  static bool Detect(const Stmt &stmt) {
    SimtCopyDetector visitor;
    visitor.VisitStmt(stmt);
    return visitor.has_simt_copy_;
  }

private:
  void VisitStmt_(const BufferStoreNode *op) final {
    // Resolve the storage scope of the buffer being written and compare its
    // textual form against the plain global scope.
    const auto dst_scope =
        runtime::StorageScope::Create(GetPtrStorageScope(op->buffer->data));
    const bool writes_non_global = dst_scope.to_string() != "global";
    // The flag is sticky: once set it is never cleared by later stores.
    has_simt_copy_ = has_simt_copy_ || writes_non_global;
    // Keep descending so nested expressions are still visited.
    StmtExprVisitor::VisitStmt_(op);
  }

  /*! \brief Set once any store to a non-"global" buffer has been visited. */
  bool has_simt_copy_{false};
};
class SetMaxNRegInjector : public StmtExprMutator { class SetMaxNRegInjector : public StmtExprMutator {
public: public:
static PrimFunc Inject(PrimFunc f) { static PrimFunc Inject(PrimFunc f) {
...@@ -113,9 +134,7 @@ private: ...@@ -113,9 +134,7 @@ private:
auto dec_reg_stmt = Evaluate(0); auto dec_reg_stmt = Evaluate(0);
// Only inject if we have valid register hints and no SIMT copy // Only inject if we have valid register hints and no SIMT copy
// For now, we assume no SIMT copy detection is available here bool has_simt_copy = SimtCopyDetector::Detect(producer_body);
// TODO: Add SIMT copy detection if needed
bool has_simt_copy = false; // Placeholder
if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) { if (dec_reg >= 0 && inc_reg >= 0 && !has_simt_copy) {
auto inc_reg_num = auto inc_reg_num =
......
...@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -135,7 +135,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.MultiVersionBuffer()(mod) mod = tilelang.transform.MultiVersionBuffer()(mod)
mod = tilelang.transform.WarpSpecialized()(mod) mod = tilelang.transform.WarpSpecialized()(mod)
mod = tilelang.transform.InjectTmaBarrier()(mod) mod = tilelang.transform.InjectTmaBarrier()(mod)
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
# if tma is not enabled, we can also do pipeline planning # if tma is not enabled, we can also do pipeline planning
# to get better performance with async copy # to get better performance with async copy
mod = tilelang.transform.PipelinePlanning()(mod) mod = tilelang.transform.PipelinePlanning()(mod)
...@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -206,6 +205,8 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
# Inject PTX async copy must behind the thread sync pass # Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load # as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod) mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
if allow_tma_and_warp_specialized(pass_ctx=pass_ctx, target=target):
mod = tilelang.transform.AnnotateWarpGroupRegAlloc()(mod)
mod = tilelang.transform.MakePackedAPI()(mod) mod = tilelang.transform.MakePackedAPI()(mod)
mod = tilelang.transform.LowerDeviceKernelLaunch()(mod) mod = tilelang.transform.LowerDeviceKernelLaunch()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment