Commit f41c467c authored by Lei Wang, committed by LeiWang1999
Browse files

[Refactor] Enhance TMA barrier validation and support for additional architectures (#463)

* Updated the TMA barrier validation in `inject_tma_barrier.cc` to check for non-empty `barrier_id_to_range_` before raising an error for missing `create_list_of_mbarrier`.
* Refactored architecture checks in `phase.py` to utilize a new constant `SUPPORTED_TMA_ARCHS`, allowing for easier updates and improved readability in the target architecture validation logic.
parent d946d1d4
......@@ -32,10 +32,10 @@
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include "./common/attr.h"
#include "./common/collector.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "./common/collector.h"
#include "./common/attr.h"
namespace tvm {
namespace tl {
......@@ -192,13 +192,15 @@ private:
tma_op_to_barrier_id_.Set(tma_call, barrier_id);
}
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
auto extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
pending_tma_ops_.clear();
} else if (call->op.same_as(builtin::ptx_wait_barrier())) {
PrimExpr barrier_id = call->args[0];
auto const_int_bound = analyzer_.const_int_bound(thread_var_);
auto extent = const_int_bound->max_value - const_int_bound->min_value + 1;
auto extent =
const_int_bound->max_value - const_int_bound->min_value + 1;
UpdateBarrierRange(barrier_id, IntImm(DataType::Int(32), extent));
}
}
......@@ -237,25 +239,25 @@ public:
TmaBarrierCollector collector;
collector(f->body);
bool has_create_list_of_mbarrier = false;
PostOrderVisit(f->body, [&](const ObjectRef& node) {
if (const auto* call = node.as<CallNode>()) {
PostOrderVisit(f->body, [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier())) {
has_create_list_of_mbarrier = true;
}
}
});
TmaBarrierRewriter rewriter(analyzer, collector.tma_op_to_barrier_id(),
collector.barrier_id_to_range(), has_create_list_of_mbarrier);
collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
return f;
}
private:
Stmt VisitStmt_(const BlockNode *op){
Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && op->name_hint == MainBlockName) {
if (!has_create_list_of_mbarrier_ && barrier_id_to_range_.size() > 0 &&
op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier.";
}
return IRMutatorWithAnalyzer::VisitStmt_(op);
......
......@@ -4,12 +4,14 @@ import tilelang
from tilelang.transform import PassContext
from typing import Optional
SUPPORTED_TMA_ARCHS = {"sm_90", "sm_90a"}
def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
target: Optional[Target] = None) -> bool:
if pass_ctx is None:
pass_ctx = tilelang.transform.get_pass_context()
if target.arch not in {"sm_90", "sm_90a"}:
if target.arch not in SUPPORTED_TMA_ARCHS:
return False
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
disable_tma_lower = pass_ctx.config.get("tl.disable_tma_lower", False)
......@@ -18,7 +20,7 @@ def allow_tma_and_warp_specialized(pass_ctx: Optional[PassContext] = None,
def allow_fence_proxy(target: Optional[Target] = None) -> bool:
return target.arch in {"sm_90", "sm_90a"}
return target.arch in SUPPORTED_TMA_ARCHS
def allow_vectorize(pass_ctx: Optional[PassContext] = None) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment