"...composable_kernel.git" did not exist on "f7e05f9efcbd777dc47bed6c3ca4ebef3dea9b47"
Commit f41c467c authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.
parent d946d1d4
...@@ -32,10 +32,10 @@ ...@@ -32,10 +32,10 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../op/builtin.h" #include "../op/builtin.h"
#include "./common/attr.h"
#include "./common/collector.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include "./common/collector.h"
#include "./common/attr.h"
namespace tvm { namespace tvm {
namespace tl { namespace tl {
...@@ -192,13 +192,15 @@ private: ...@@ -192,13 +192,15 @@ private:
tma_op_to_barrier_id_.Set(tma_call, barrier_id); tma_op_to_barrier_id_.Set(tma_call, barrier_id);
} }
auto const_int_bound = analyzer_.const_int_bound(thread_var_); auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1; auto extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
pending_tma_ops_.clear(); pending_tma_ops_.clear();
} else if (call->op.same_as(builtin::ptx_wait_barrier())) { } else if (call->op.same_as(builtin::ptx_wait_barrier())) {
PrimExpr barrier_id = call->args[0]; PrimExpr barrier_id = call->args[0];
auto const_int_bound = analyzer_.const_int_bound(thread_var_); auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1; auto extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent)); UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
} }
} }
...@@ -237,25 +239,25 @@ public: ...@@ -237,25 +239,25 @@ public:
TmaBarrierCollector collector; TmaBarrierCollector collector;
collector(f->body); collector(f->body);
bool has_create_list_of_mbarrier = false; bool has_create_list_of_mbarrier = false;
PostOrderVisit(f->body, [&](const ObjectRef& node) { PostOrderVisit(f->body, [&](const ObjectRef &node) {
if (const auto* call = node.as<CallNode>()) { if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier())) { if (call->op.same_as(create_list_of_mbarrier())) {
has_create_list_of_mbarrier = true; has_create_list_of_mbarrier = true;
} }
} }
}); });
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(), TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(),
collector.barrier_id_to_range(), has_create_list_of_mbarrier); collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body); f.CopyOnWrite()->body = rewriter(f->body);
return f; return f;
} }
private: private:
Stmt VisitStmt_(const BlockNode *op) {
Stmt VisitStmt_(const BlockNode *op){
auto block = GetRef<Block>(op); auto block = GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && op->name_hint == MainBlockName) { if (!has_create_list_of_mbarrier_ && barrier_id_to_range_.size() > 0 &&
op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier."; ICHECK(false) << "Please declare create_list_of_mbarrier.";
} }
return IRMutatorWithAnalyzer::VisitStmt_(op); return IRMutatorWithAnalyzer::VisitStmt_(op);
......
...@@ -4,12 +4,14 @@ import tilelang ...@@ -4,12 +4,14 @@ import tilelang
from tilelang.transform import PassContext from tilelang.transform import PassContext
from typing import Optional from typing import Optional
SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool: target: Optional[Target] = None) -> bool:
if pass_ctx is None: if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context() pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in {"sm_90", "sm_90a"}: if target.arch not in SUPPORTED_TMA_ARCHS:
return False return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False) disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
...@@ -18,7 +20,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None, ...@@ -18,7 +20,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
def allow_fence_proxy(target: Optional[Target] = None) -> bool: def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return target.arch in {"sm_90", "sm_90a"} return target.arch in SUPPORTED_TMA_ARCHS
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool: def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment