"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "f7d28f3e4b2992ff58169600e914bf0cc72bd756"
Unverified Commit 8b005226 authored by Yu Cheng's avatar Yu Cheng Committed by GitHub
Browse files

[Refactor] Update TVM subproject and refactor BlockNode handling in...

[Refactor] Update TVM subproject and refactor BlockNode handling in warp_specialized_rewriter.cc (#812)

* [Feature] Introduce custom warp specialization attribute and enhance warp group register allocation

- Added a new attribute `kCustomWarpSpecialization` to support custom warp specialization in the TileLang framework.
- Updated the `Collect` method in `SetMaxNRegCollector` to handle cases where warp specialization is detected, returning an empty array accordingly.
- Enhanced the `SetMaxNRegInjector` to skip processing when no registers are needed, improving efficiency.
- Modified the `WarpSpecialized` pass to include the new attribute in the function body when warp specialization is enabled, ensuring proper handling in transformations.

* lint

* lint
parent 0b3683bf
...@@ -25,6 +25,8 @@ namespace attr { ...@@ -25,6 +25,8 @@ namespace attr {
static constexpr const char *kPaddingMap = "padding_map"; static constexpr const char *kPaddingMap = "padding_map";
static constexpr const char *kWarpSpecializationScope = static constexpr const char *kWarpSpecializationScope =
"kWarpSpecializationScope"; "kWarpSpecializationScope";
static constexpr const char *kCustomWarpSpecialization =
"kCustomWarpSpecialization";
} // namespace attr } // namespace attr
static constexpr const char *kDebugMergeSharedMemoryAllocations = static constexpr const char *kDebugMergeSharedMemoryAllocations =
......
...@@ -17,6 +17,9 @@ public: ...@@ -17,6 +17,9 @@ public:
static Array<IntImm> Collect(const PrimFunc &f) { static Array<IntImm> Collect(const PrimFunc &f) {
SetMaxNRegCollector collector; SetMaxNRegCollector collector;
collector(f->body); collector(f->body);
if (collector.warp_specialized_) {
return Array<IntImm>({});
}
return collector.has_no_set_max_nreg_ return collector.has_no_set_max_nreg_
? Array<IntImm>({IntImm(DataType::Int(32), -1), ? Array<IntImm>({IntImm(DataType::Int(32), -1),
IntImm(DataType::Int(32), -1)}) IntImm(DataType::Int(32), -1)})
...@@ -43,21 +46,27 @@ private: ...@@ -43,21 +46,27 @@ private:
StmtExprVisitor::VisitStmt_(op); StmtExprVisitor::VisitStmt_(op);
} }
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == attr::kCustomWarpSpecialization) {
warp_specialized_ = true;
}
StmtExprVisitor::VisitStmt_(op);
}
Array<IntImm> nreg_{IntImm(DataType::Int(32), 0), Array<IntImm> nreg_{IntImm(DataType::Int(32), 0),
IntImm(DataType::Int(32), 0)}; IntImm(DataType::Int(32), 0)};
bool has_no_set_max_nreg_ = false; bool has_no_set_max_nreg_ = false;
bool warp_specialized_ = false;
}; };
class SetMaxNRegInjector : public StmtExprMutator { class SetMaxNRegInjector : public StmtExprMutator {
public: public:
static PrimFunc Inject(PrimFunc f) { static PrimFunc Inject(PrimFunc f) {
bool warp_specialized = WarpSpecializedDetector::Detect(f->body);
if (warp_specialized) {
// Should handle set_max_nreg when using hand-written warp specialized
return f;
}
auto T = SetMaxNRegInjector(); auto T = SetMaxNRegInjector();
T.nreg_ = SetMaxNRegCollector::Collect(f); T.nreg_ = SetMaxNRegCollector::Collect(f);
if (T.nreg_.empty()) {
return f;
}
f.CopyOnWrite()->body = T(f->body); f.CopyOnWrite()->body = T(f->body);
return f; return f;
} }
......
...@@ -1283,8 +1283,12 @@ tvm::transform::Pass WarpSpecialized() { ...@@ -1283,8 +1283,12 @@ tvm::transform::Pass WarpSpecialized() {
if (!warp_specialized) { if (!warp_specialized) {
return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized, return WarpSpecializedRewriter::Substitute(f, disable_warp_specialized,
disable_shuffle_elect); disable_shuffle_elect);
} else {
ObjectRef node = String("default");
f.CopyOnWrite()->body =
AttrStmt(node, attr::kCustomWarpSpecialization, 1, f->body);
return f;
} }
return f;
}; };
return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {}); return CreatePrimFuncPass(pass_func, 0, "tl.WarpSpecialized", {});
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment