Unverified commit 8b005226, authored by Yu Cheng, committed by GitHub
Browse files

[Refactor] Update TVM subproject and refactor BlockNode handling in...

[Refactor] Update TVM subproject and refactor BlockNode handling in warp_specialized_rewriter.cc (#812)

* [Feature] Introduce custom warp specialization attribute and enhance warp group register allocation

- Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework.
- Updated the `Collect` method in `SetMaxNRegCollector` to handle cases where warp specialization is detected, returning an empty array accordingly.
- Enhanced the `SetMaxNRegInjector` to skip processing when no registers are needed, improving efficiency.
- Modified the `WarpSpecialized` pass to include the new attribute in the function body when warp specialization is enabled, ensuring proper handling in transformations.

* lint

* lint
parent 0b3683bf
......@@ -25,6 +25,8 @@ namespace attr {
static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope";
// Attribute key marking a function body as using custom warp specialization.
// The WarpSpecialized pass wraps the body in an AttrStmt with this key when it
// leaves the function un-rewritten, and SetMaxNRegCollector checks for it
// while visiting AttrStmt nodes.
static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization";
} // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations =
......
......@@ -17,6 +17,9 @@ public:
static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector;
collector(f->body);
if (collector.warp_specialized_) {
return Array<IntImm>({});
}
return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)})
......@@ -43,21 +46,27 @@ private:
StmtExprVisitor::VisitStmt_(op);
}
// Visits an attribute statement. If its key is the custom warp-specialization
// marker, remember that on the collector; in every case continue with the
// default traversal of the node.
void VisitStmt_(const AttrStmtNode *op) final {
  const bool is_custom_ws_marker =
      (op->attr_key == attr::kCustomWarpSpecialization);
  if (is_custom_ws_marker) {
    warp_specialized_ = true;
  }
  StmtExprVisitor::VisitStmt_(op);
}
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false;
bool warp_specialized_ = false;
};
class SetMaxNRegInjector : public StmtExprMutator {
public:
// Rewrites the body of `f` with a SetMaxNRegInjector seeded from the values
// gathered by SetMaxNRegCollector. The function is returned unchanged when
// either guard below fires.
static PrimFunc Inject(PrimFunc f) {
  // Guard 1: the body is detected as warp specialized. set_max_nreg is
  // expected to be handled by the hand-written warp-specialized code, so
  // nothing is rewritten here.
  if (WarpSpecializedDetector::Detect(f->body)) {
    return f;
  }

  auto injector = SetMaxNRegInjector();
  injector.nreg_ = SetMaxNRegCollector::Collect(f);

  // Guard 2: the collector produced an empty array, so the mutator is not
  // run.
  if (!injector.nreg_.empty()) {
    f.CopyOnWrite()->body = injector(f->body);
  }
  return f;
}
......
......@@ -1283,8 +1283,12 @@ tvm::transform::Pass WarpSpecialized() {
if (!warp_specialized) {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect);
} else {
ObjectRef node = String("default");
f.CopyOnWrite()->body =
AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
return f;
}
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment