Commit 41bc15cb authored by Yu Cheng, committed by LeiWang1999
Browse files

[Enhancement] Add warp specialization attribute handling in IR and rewriter (#518)

* Introduced an `AttrFrame` for warp specialization in the IR, enhancing the handling of warp-specific optimizations.
* Refactored the `VisitStmt_` method in `warp_specialized_rewriter.cc` to check for the new warp specialization attribute, improving the detection of warp specialization conditions.
* Removed outdated code related to condition checks in `IfThenElseNode`, streamlining the specialization logic.
parent 62a8d7f0
......@@ -285,8 +285,10 @@ WarpSpecializeFrame WarpSpecialize(Array<IntImm> warp_group_ids,
}
}
IfFrame if_frame = If(condition);
AttrFrame attr_frame = Attr(Integer(0), "warp_specialize", Integer(1));
n->frames.push_back(if_frame);
n->frames.push_back(Then());
n->frames.push_back(attr_frame);
return WarpSpecializeFrame(n);
}
......
......@@ -1160,44 +1160,11 @@ private:
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const IfThenElseNode *op) final {
// do not visit the body of the if-then-else statement
// because we only care about the condition
auto cond = op->condition;
// assert cond is a binary expression
PostOrderVisit(cond, [this](const ObjectRef &node) {
bool is_cmp_op = false;
if (const auto *lt = node.as<LTNode>()) {
is_cmp_op = true;
} else if (const auto *le = node.as<LENode>()) {
is_cmp_op = true;
} else if (const auto *gt = node.as<GTNode>()) {
is_cmp_op = true;
} else if (const auto *ge = node.as<GENode>()) {
is_cmp_op = true;
}
if (is_cmp_op) {
bool has_thread_var = false;
bool has_warp_group_size = false;
// check if has thread_var_ in lt->a or lt->b
PostOrderVisit(node, [this, &has_thread_var,
&has_warp_group_size](const ObjectRef &node_) {
if (node_.as<VarNode>() == thread_var_->var.get()) {
has_thread_var = true;
} else if (const auto *imm = node_.as<IntImmNode>()) {
// 128 is the warp group size of nvidia gpus
has_warp_group_size = imm->value % 128 == 0;
}
});
if (has_thread_var && has_warp_group_size) {
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "warp_specialize" &&
op->value.as<IntImmNode>()->value == 1) {
has_warp_specialization_ = true;
}
}
});
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment