Commit 76435ca8, authored by Yu Cheng, committed by LeiWang1999
Browse files

[Feature] Introduce NoSetMaxNReg for warp specialization (#289)

- Added NoSetMaxNReg as a new TIR built-in to indicate no register hint for warp-specialized branches.
- Updated the warp specialization rewriter to handle the new NoSetMaxNReg operation: when it is present, no SetMaxNReg register hints are emitted for the producer and consumer branches.
- Enhanced the Python interface to include NoSetMaxNReg for consistency with TIR operations.
parent eee45f17
......@@ -103,6 +103,11 @@ TIR_DEFINE_TL_BUILTIN(SetMaxNReg)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Registers the NoSetMaxNReg built-in op: it takes zero inputs and its call
// effect kind is kOpaque. The op indicates that no register hint should be
// set for warp-specialized branches.
TIR_DEFINE_TL_BUILTIN(NoSetMaxNReg)
.set_num_inputs(0)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
// Registers the WaitWgmma built-in op: it takes one input and its call effect
// kind is kOpaque.
TIR_DEFINE_TL_BUILTIN(WaitWgmma).set_num_inputs(1).set_attr<TCallEffectKind>(
"TCallEffectKind", Integer(CallEffectKind::kOpaque));
......
......@@ -160,6 +160,14 @@ const Op &TMAStoreWait();
*/
const Op &SetMaxNReg();
/*!
* \brief Request that no register hint be set for warp-specialized branches
*
* NoSetMaxNReg()
*
* Takes no arguments.
*
*/
const Op &NoSetMaxNReg();
/*!
* \brief Wait the previous wgmma to finish
*
......
......@@ -1039,7 +1039,10 @@ public:
/*!
 * \brief Collect the register hints requested in the body of a PrimFunc.
 *
 * \param f The function whose body is scanned.
 * \return A two-element array {dec_reg, inc_reg}. Both entries are -1 when the
 *         body contains a NoSetMaxNReg call, which tells the rewriter to emit
 *         no SetMaxNReg at all. Otherwise the collected hints are returned,
 *         where 0 means "not specified" and lets the rewriter pick its
 *         default.
 */
static Array<IntImm> Collect(const PrimFunc &f) {
  SetMaxNRegCollector collector;
  collector(f->body);
  // NoSetMaxNReg overrides any collected hint. This check must run before the
  // plain hints are returned, otherwise the opt-out is never observed.
  if (collector.has_no_set_max_nreg_) {
    return Array<IntImm>({IntImm(DataType::Int(32), -1),
                          IntImm(DataType::Int(32), -1)});
  }
  return collector.nreg_;
}
private:
......@@ -1055,6 +1058,8 @@ private:
// producer should decrease register hint while consumer should increase
// register hint
nreg_.Set(is_inc, IntImm(DataType::Int(32), reg_hint));
} else if (call->op.same_as(NoSetMaxNReg())) {
has_no_set_max_nreg_ = true;
}
}
StmtExprVisitor::VisitStmt_(op);
......@@ -1062,6 +1067,7 @@ private:
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
};
class WarpSpecializedRewriter : public StmtExprMutator {
......@@ -1107,7 +1113,7 @@ private:
Stmt VisitStmt_(const EvaluateNode *op) final {
if (const CallNode *call = op->value.as<CallNode>()) {
if (call->op.same_as(SetMaxNReg())) {
if (call->op.same_as(SetMaxNReg()) || call->op.same_as(NoSetMaxNReg())) {
return Evaluate(0);
}
}
......@@ -1164,10 +1170,14 @@ private:
int dec_reg = nreg_[0].as<IntImmNode>()->value;
int inc_reg = nreg_[1].as<IntImmNode>()->value;
auto inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
auto dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
auto inc_reg_stmt = Evaluate(0);
auto dec_reg_stmt = Evaluate(0);
if (dec_reg >= 0 && inc_reg >= 0) {
inc_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
{inc_reg == 0 ? 240 : inc_reg, 1}));
dec_reg_stmt = Evaluate(Call(DataType::Handle(), SetMaxNReg(),
{dec_reg == 0 ? 24 : dec_reg, 0}));
}
producer_code = SeqStmt({dec_reg_stmt, producer_code});
consumer_code = SeqStmt({inc_reg_stmt, consumer_code});
......
......@@ -35,6 +35,10 @@ def SetMaxNReg(*args):
return tir.call_intrin("handle", tir.op.Op.get("tl.SetMaxNReg"), *args)
def NoSetMaxNReg(*args):
    """Build a call to the ``tl.NoSetMaxNReg`` intrinsic.

    Every positional argument in ``args`` is forwarded unchanged to the
    intrinsic call, whose result dtype is ``"handle"``.
    """
    no_set_max_nreg_op = tir.op.Op.get("tl.NoSetMaxNReg")
    return tir.call_intrin("handle", no_set_max_nreg_op, *args)
def MBarrierWaitParity(*args):
    """Build a call to the ``tl.MBarrierWaitParity`` intrinsic.

    Every positional argument in ``args`` is forwarded unchanged to the
    intrinsic call, whose result dtype is ``"handle"``.
    """
    mbarrier_wait_parity_op = tir.op.Op.get("tl.MBarrierWaitParity")
    return tir.call_intrin("handle", mbarrier_wait_parity_op, *args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment