Unverified commit 7a80b6df, authored by Lei Wang and committed by GitHub
Browse files

[Bugfix] Enable code lowering with producer‑copy‑only program (#1168)

* bugfix

* lint fix

* Enhance warp group register allocation to handle missing consumer bodies gracefully. Updated logic to annotate producer side when consumer is absent, ensuring robustness in degenerate warp-specialized patterns.

* Refactor VisitExpr_ method in inject_tma_barrier.cc for improved readability. Adjusted formatting and spacing for clarity in barrier handling logic.

* Update barrier handling in inject_tma_barrier.cc to accommodate newly appended entries. Adjusted the size of the replace vector to ensure it covers the full needed length, and modified the logic for appending barriers based on the updated replace conditions.
parent 10911e28
......@@ -124,7 +124,9 @@ private:
}
auto producer_body = if_then_else->then_case;
Optional<Stmt> consumer_body = if_then_else->else_case;
ICHECK(consumer_body.defined()) << "Consumer body is undefined";
// In some degenerate warp-specialized patterns (e.g., producer-only),
// the consumer body may be absent. Handle gracefully by only annotating
// the producer side when consumer is missing.
auto dec_reg = nreg_[0].as<IntImmNode>()->value;
auto inc_reg = nreg_[1].as<IntImmNode>()->value;
......@@ -150,15 +152,20 @@ private:
producer_stmts.push_back(producer_body);
auto new_producer_body = SeqStmt(producer_stmts);
Stmt new_if_stmt;
if (consumer_body.defined()) {
Array<Stmt> consumer_stmts;
consumer_stmts.push_back(inc_reg_stmt);
consumer_stmts.push_back(consumer_body.value());
auto new_consumer_body = SeqStmt(consumer_stmts);
auto new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body,
new_consumer_body);
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
} else {
// No consumer branch; keep the if-then form.
new_if_stmt = IfThenElse(if_then_else->condition, new_producer_body);
}
auto new_attr = AttrStmt(op->node, op->attr_key, op->value, new_if_stmt);
return new_attr;
} else {
return StmtExprMutator::VisitStmt_(op);
......
......@@ -295,14 +295,15 @@ public:
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)]
.as<CallNode>()
->args[0];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (tma_op_to_barrier_id_.count(call_ref)) {
PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
int_sets_.push_back(int_set);
expect_tx_count_ += 1;
}
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
sequence.push_back(1);
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
......@@ -337,32 +338,61 @@ public:
class BarrierCreationRewriter : public StmtExprMutator {
public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent)
PrimExpr producer_thread_extent,
int ensure_min_count = 0,
PrimExpr default_barrier_thread_count = 1)
: restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(std::move(producer_thread_extent)) {}
producer_thread_extent_(std::move(producer_thread_extent)),
ensure_min_count_(ensure_min_count),
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) {
std::vector<bool> tmp_(op->args.size(), false);
Array<PrimExpr> new_args;
size_t cur_n = op->args.size();
size_t need_n =
std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std::vector<bool> replace(need_n, false);
for (auto &id : restore_barrier_ids_) {
tmp_[id] = true;
if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
replace[id] = true;
}
}
for (size_t i{0}; i < op->args.size(); ++i) {
if (tmp_[i]) {
Array<PrimExpr> new_args;
new_args.reserve(need_n);
// Preserve/override existing entries
for (size_t i{0}; i < cur_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(op->args[i]);
}
}
// Append additional barriers if required
for (size_t i = cur_n; i < need_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(default_barrier_thread_count_);
}
}
return Call(op->dtype, op->op, new_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_;
int ensure_min_count_{0};
PrimExpr default_barrier_thread_count_{1};
};
// we trust mbarrier_wait_parity to be correct
......@@ -399,8 +429,31 @@ public:
collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
// Visitor that records the largest constant index passed to a
// single-argument get_mbarrier(...) call anywhere in the visited IR.
// `max_idx` stays at -1 when no such call is found.
struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
  int max_idx{-1};

  void VisitExpr_(const CallNode *op) final {
    const bool single_arg_get_mbarrier =
        op->op.same_as(get_mbarrier()) && op->args.size() == 1;
    if (single_arg_get_mbarrier) {
      // Only integer-immediate indices are tracked; other index
      // expressions do not affect max_idx.
      const auto *literal_idx = op->args[0].as<IntImmNode>();
      if (literal_idx != nullptr) {
        const int idx = static_cast<int>(literal_idx->value);
        if (idx > max_idx) {
          max_idx = idx;
        }
      }
    }
    // Always continue into the call's children so nested calls are seen.
    StmtExprVisitor::VisitExpr_(op);
  }
};
GetMbarrierMaxIdxCollector max_idx_collector;
max_idx_collector(f->body);
int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto barrier_creation_rewriter = BarrierCreationRewriter(
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
ensure_min_count, Integer(1));
f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
return f;
}
......@@ -453,10 +506,27 @@ private:
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
// For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and mapping is empty.
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load && op->args.size() >= 3) {
if (const auto *imm = op->args[2].as<IntImmNode>()) {
Array<PrimExpr> new_args = op->args;
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32),
static_cast<int>(imm->value))}));
return Call(op->dtype, op->op, new_args);
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
......@@ -469,9 +539,11 @@ private:
}
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(tvm::ffi::GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[tvm::ffi::GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
new_args.Set(0, barrier_id);
if (!has_warp_specialization_)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment