Unverified Commit bbbf4207 authored by guchaoyang's avatar guchaoyang Committed by GitHub
Browse files

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
...@@ -249,7 +249,6 @@ private: ...@@ -249,7 +249,6 @@ private:
new_args.push_back(dst_node); new_args.push_back(dst_node);
new_args.push_back(value_node); new_args.push_back(value_node);
} }
new_args.push_back(memory_order); new_args.push_back(memory_order);
Call new_call = Call new_call =
...@@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) { ...@@ -284,4 +283,4 @@ For VectorizeAtomicAdd(const For &for_node, int compute_capability) {
} }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
\ No newline at end of file
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h> #include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
namespace tvm { namespace tvm {
namespace tir { namespace tir {
...@@ -66,7 +68,8 @@ public: ...@@ -66,7 +68,8 @@ public:
} }
if (mem_reuse_max > 0) { if (mem_reuse_max > 0) {
std::string tag_str = cluster_tag; // Convert to std::string std::string tag_str =
static_cast<std::string>(cluster_tag); // Convert to std::string
if (tag_str.rfind("blockIdx", 0) == 0) { if (tag_str.rfind("blockIdx", 0) == 0) {
// starts with "blockIdx" // starts with "blockIdx"
tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx")); tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
...@@ -74,7 +77,7 @@ public: ...@@ -74,7 +77,7 @@ public:
// Unexpected format — maybe just prefix // Unexpected format — maybe just prefix
tag_str = "clusterIdx" + tag_str; tag_str = "clusterIdx" + tag_str;
} }
cluster_tag = tvm::ffi::String(tag_str); // Convert back cluster_tag = String(tag_str); // Convert back
return WithAttr(f, cluster_tag, Integer(cluster_size_)); return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else { } else {
return f; return f;
...@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() { ...@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning); refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
}); }
} // namespace transform } // namespace transform
} // namespace tir } // namespace tir
......
...@@ -41,7 +41,7 @@ public: ...@@ -41,7 +41,7 @@ public:
return StmtMutator::VisitStmt_(op); return StmtMutator::VisitStmt_(op);
// Collect loop variables and ranges // Collect loop variables and ranges
auto for_node = GetRef<For>(op); auto for_node = tvm::ffi::GetRef<For>(op);
Array<Var> loop_vars; Array<Var> loop_vars;
Array<PrimExpr> loop_extents; Array<PrimExpr> loop_extents;
Stmt body = op->body; Stmt body = op->body;
...@@ -81,7 +81,7 @@ public: ...@@ -81,7 +81,7 @@ public:
// post order visit the index // post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) { PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) { if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v)); used_vars.insert(tvm::ffi::GetRef<Var>(v));
} }
}); });
if (used_vars.empty()) { if (used_vars.empty()) {
......
...@@ -211,7 +211,7 @@ public: ...@@ -211,7 +211,7 @@ public:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector(); bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
...@@ -265,7 +265,7 @@ public: ...@@ -265,7 +265,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final { PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) { if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return !(a); return !(a);
} }
...@@ -306,10 +306,10 @@ public: ...@@ -306,10 +306,10 @@ public:
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) { if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Broadcast(op->value, op->lanes); return Broadcast(op->value, op->lanes);
} }
...@@ -321,7 +321,7 @@ public: ...@@ -321,7 +321,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value); PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) && if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) { f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor(); int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
...@@ -339,7 +339,7 @@ public: ...@@ -339,7 +339,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) { if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor( return Cast(op->dtype.with_scalable_vscale_factor(
...@@ -352,20 +352,20 @@ public: ...@@ -352,20 +352,20 @@ public:
} }
PrimExpr VisitExpr_(const FloatImmNode *op) final { PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const StringImmNode *op) final { PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
// Variable // Variable
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op); Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) { if (var.same_as(var_)) {
return ramp_; return ramp_;
...@@ -382,13 +382,13 @@ public: ...@@ -382,13 +382,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) { if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr t = this->VisitExpr(op->args[1]); PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]); PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) { f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor(); int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor(); int f_lanes = f.dtype().get_lanes_or_vscale_factor();
...@@ -410,7 +410,7 @@ public: ...@@ -410,7 +410,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret())); ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]); PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) { if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int lanes = value.dtype().get_lanes_or_vscale_factor(); int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) { if (value.dtype().is_scalable_vector()) {
...@@ -455,12 +455,12 @@ public: ...@@ -455,12 +455,12 @@ public:
auto new_arg = this->VisitExpr(arg); auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) { if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true; need_scalarize_ = true;
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
new_args.push_back(new_arg); new_args.push_back(new_arg);
} }
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} }
...@@ -469,7 +469,7 @@ public: ...@@ -469,7 +469,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane); Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path. // normal code path.
if (op->args.same_as(new_args)) { if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Call(op->dtype.with_lanes(lane), op->op, new_args); return Call(op->dtype.with_lanes(lane), op->op, new_args);
} }
...@@ -477,7 +477,7 @@ public: ...@@ -477,7 +477,7 @@ public:
} }
// BufferLoad // BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final { PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op); auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -514,7 +514,7 @@ public: ...@@ -514,7 +514,7 @@ public:
let_binding_[op->var] = op->var; let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body); PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
return Let(op->var, value, body); return Let(op->var, value, body);
} }
...@@ -522,7 +522,7 @@ public: ...@@ -522,7 +522,7 @@ public:
} }
// BufferStore // BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op); auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) { auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index); return this->VisitExpr(index);
...@@ -585,11 +585,11 @@ public: ...@@ -585,11 +585,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent); PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) { if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) { if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return For(op->loop_var, op->min, extent, op->kind, body, return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations); op->thread_binding, op->annotations);
...@@ -600,7 +600,7 @@ public: ...@@ -600,7 +600,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector()); ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition); PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
Stmt then_case = this->VisitStmt(op->then_case); Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt; Optional<Stmt> else_case = std::nullopt;
...@@ -609,7 +609,7 @@ public: ...@@ -609,7 +609,7 @@ public:
} }
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) { else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return IfThenElse(condition, then_case, else_case); return IfThenElse(condition, then_case, else_case);
} }
...@@ -634,7 +634,7 @@ public: ...@@ -634,7 +634,7 @@ public:
let_binding_[op->var] = op->var; let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body); Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) { if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} else { } else {
return LetStmt(op->var, value, body); return LetStmt(op->var, value, body);
} }
...@@ -647,7 +647,7 @@ public: ...@@ -647,7 +647,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) { if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
// Mutate the extents // Mutate the extents
...@@ -657,7 +657,7 @@ public: ...@@ -657,7 +657,7 @@ public:
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) { if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint; << op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op)); return Scalarize(tvm::ffi::GetRef<Stmt>(op));
} }
extents.push_back(new_ext); extents.push_back(new_ext);
} }
...@@ -738,7 +738,7 @@ private: ...@@ -738,7 +738,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
...@@ -754,7 +754,7 @@ private: ...@@ -754,7 +754,7 @@ private:
PrimExpr a = this->VisitExpr(op->a); PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b); PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) { if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} else { } else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor(); int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor(); int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......
...@@ -38,7 +38,7 @@ protected: ...@@ -38,7 +38,7 @@ protected:
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) { if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(_index_bitwidth_), op->value); return IntImm(DataType::Int(_index_bitwidth_), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
...@@ -88,23 +88,23 @@ private: ...@@ -88,23 +88,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op)); return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value); return IntImm(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value); return cast(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
...@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() { ...@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {}); return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth", refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth); ConfigIndexBitwidth);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -35,9 +35,7 @@ public: ...@@ -35,9 +35,7 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final { Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "thread_extent") { if (op->attr_key == "thread_extent") {
const VarNode *var = nullptr; if (const auto *var = op->node.as<VarNode>()) {
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
if (var->name_hint == "threadIdx.x") { if (var->name_hint == "threadIdx.x") {
thread_extent_ = op; thread_extent_ = op;
} }
...@@ -82,7 +80,7 @@ public: ...@@ -82,7 +80,7 @@ public:
} }
Stmt VisitStmt_(const ForNode *op) final { Stmt VisitStmt_(const ForNode *op) final {
PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) { PostOrderVisit(tvm::ffi::GetRef<For>(op), [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) { if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) || if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) || call->op.same_as(mbarrier_wait_parity()) ||
...@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() { ...@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier", refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier); EliminateStorageSyncForMBarrier);
}); }
} // namespace transform } // namespace transform
} // namespace tl } // namespace tl
......
...@@ -75,23 +75,23 @@ private: ...@@ -75,23 +75,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final { PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op)); return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const IntImmNode *op) final { PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value); return IntImm(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
PrimExpr VisitExpr_(const CastNode *op) final { PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) { if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value); return cast(DataType::Int(64), op->value);
} }
return GetRef<PrimExpr>(op); return tvm::ffi::GetRef<PrimExpr>(op);
} }
Stmt VisitStmt_(const BufferStoreNode *op) final { Stmt VisitStmt_(const BufferStoreNode *op) final {
...@@ -115,7 +115,7 @@ private: ...@@ -115,7 +115,7 @@ private:
<< "All MatchBufferRegion should be removed in " << "All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer."; "tir.transform.LowerMatchBuffer.";
Block block = GetRef<Block>(op); Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers; Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply( alloc_buffers.MutateByApply(
...@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() { ...@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer); refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -89,10 +89,10 @@ Pass LetInline() { ...@@ -89,10 +89,10 @@ Pass LetInline() {
return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LetInline", LetInline); refl::GlobalDef().def("tl.transform.LetInline", LetInline);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -33,7 +33,7 @@ private: ...@@ -33,7 +33,7 @@ private:
auto then_case = VisitStmt(op->then_case); auto then_case = VisitStmt(op->then_case);
Optional<Stmt> else_case = op->else_case; Optional<Stmt> else_case = op->else_case;
if (else_case.defined()) { if (else_case.defined()) {
return GetRef<Stmt>(op); return tvm::ffi::GetRef<Stmt>(op);
} }
ICHECK(then_case.defined()) << "then_case must be defined"; ICHECK(then_case.defined()) << "then_case must be defined";
ICHECK(!else_case.defined()) << "else_case must be undefined"; ICHECK(!else_case.defined()) << "else_case must be undefined";
...@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() { ...@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {}); return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding); refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() { ...@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes); refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
}); }
} // namespace tvm::tl } // namespace tvm::tl
...@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) { ...@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
return false; return false;
} }
return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) || return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
call->op.same_as(initialize_descriptor()); call->op.same_as(initialize_wgmma_descriptor()) ||
call->op.same_as(initialize_tcgen05_descriptor());
} }
ProxyKind ProxyFromAttrValue(const ObjectRef &value) { ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
...@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() { ...@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() {
{}); {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy); refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -37,9 +37,14 @@ ...@@ -37,9 +37,14 @@
namespace tvm { namespace tvm {
namespace tl { namespace tl {
using namespace tir; using namespace tir;
using namespace ffi;
namespace software_pipeline { namespace software_pipeline {
struct LetWrapper {
Var var;
PrimExpr value;
};
/*! /*!
* \brief Create a block and infer the access region with the given body. * \brief Create a block and infer the access region with the given body.
* *
...@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator { ...@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
public: public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer, PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs, const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info) const For &pipeline_loop, const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop), pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {} pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}
Stmt BuildPipeline() { Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the // Step 1: Analyze accesses to the buffers in the pipeline and compute the
...@@ -459,7 +466,8 @@ private: ...@@ -459,7 +466,8 @@ private:
* \return The resized buffer. * \return The resized buffer.
*/ */
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) { Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get())); ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions)); new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) { if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size()); ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
...@@ -676,6 +684,20 @@ private: ...@@ -676,6 +684,20 @@ private:
new_block = Downcast<Block>(Substitute( new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
// If there were Let-wrappers outside the original pipeline body that
// depended on the pipeline loop var, push them into each rewritten
// block with the correct per-block substitution.
if (!loop_var_let_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (const auto &lw : loop_var_let_wrappers_) {
PrimExpr substituted = Substitute(
lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
inner = LetStmt(lw.var, substituted, inner);
}
n->body = inner;
}
if (pipeline_info_[block].async) { if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage]; auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index; local_state.producer_head = normalized_access_index;
...@@ -737,6 +759,7 @@ private: ...@@ -737,6 +759,7 @@ private:
Map<Buffer, Buffer> buffer_remap_; Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_; Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states; std::map<int, AsyncStateGlobal> async_states;
std::vector<LetWrapper> loop_var_let_wrappers_;
}; };
/*! /*!
...@@ -864,8 +887,9 @@ private: ...@@ -864,8 +887,9 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr; const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns; std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
std::vector<LetWrapper> loop_var_let_wrappers;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) { auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node; Any node = attr->node;
String attr_key = attr->attr_key; String attr_key = attr->attr_key;
PrimExpr value = attr->value; PrimExpr value = attr->value;
Span span = attr->span; Span span = attr->span;
...@@ -896,14 +920,25 @@ private: ...@@ -896,14 +920,25 @@ private:
continue; continue;
} }
if (const auto *let_stmt = current.as<LetStmtNode>()) { if (const auto *let_stmt = current.as<LetStmtNode>()) {
Var var = let_stmt->var; // If this Let value uses the pipeline loop var, record it and push
PrimExpr value = let_stmt->value; // inside each rewritten block later so the loop var can be
Span span = let_stmt->span; // substituted with the correct per-iteration index. Otherwise, keep
rewrap_fns.emplace_back([var = std::move(var), // it as a normal wrapper.
value = std::move(value), bool uses_loop_var = UsesVar(
span](Stmt body) -> Stmt { let_stmt->value,
return LetStmt(var, value, body, span); [v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
}); if (uses_loop_var) {
loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
} else {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
rewrap_fns.emplace_back([var = std::move(var),
value = std::move(value),
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
}
current = let_stmt->body; current = let_stmt->body;
continue; continue;
} }
...@@ -981,7 +1016,8 @@ private: ...@@ -981,7 +1016,8 @@ private:
// Step 4: Rewrite the pipeline body. // Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs, Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info) tvm::ffi::GetRef<For>(op), pipeline_info,
loop_var_let_wrappers)
.BuildPipeline(); .BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) { auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) { for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
...@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() { ...@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline", refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline); InjectSoftwarePipeline);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() { ...@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy); refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -204,9 +204,9 @@ private: ...@@ -204,9 +204,9 @@ private:
void VisitStmt_(const EvaluateNode *op) final { void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) { if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) { if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call)); pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) { } else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call)); pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) { } else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0]; PrimExpr barrier_id = call->args[0];
for (const auto &tma_call : pending_tma_ops_) { for (const auto &tma_call : pending_tma_ops_) {
...@@ -295,13 +295,15 @@ public: ...@@ -295,13 +295,15 @@ public:
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) { if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e = auto call_ref = tvm::ffi::GetRef<Call>(op);
tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0]; if (tma_op_to_barrier_id_.count(call_ref)) {
auto int_set = arith::EvalSet(e, var_int_set_); PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
expect_.push_back(if_depth_ == 1); auto int_set = arith::EvalSet(e, var_int_set_);
sequence.push_back(0); expect_.push_back(if_depth_ == 1);
int_sets_.push_back(int_set); sequence.push_back(0);
expect_tx_count_ += 1; int_sets_.push_back(int_set);
expect_tx_count_ += 1;
}
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) { } else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
sequence.push_back(1); sequence.push_back(1);
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) { } else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
...@@ -336,32 +338,61 @@ public: ...@@ -336,32 +338,61 @@ public:
class BarrierCreationRewriter : public StmtExprMutator { class BarrierCreationRewriter : public StmtExprMutator {
public: public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids, BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent) PrimExpr producer_thread_extent,
int ensure_min_count = 0,
PrimExpr default_barrier_thread_count = 1)
: restore_barrier_ids_(std::move(restore_barrier_ids)), : restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(std::move(producer_thread_extent)) {} producer_thread_extent_(std::move(producer_thread_extent)),
ensure_min_count_(ensure_min_count),
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
}
PrimExpr VisitExpr_(const CallNode *op) { PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) { if (op->op.same_as(create_list_of_mbarrier())) {
std::vector<bool> tmp_(op->args.size(), false); size_t cur_n = op->args.size();
Array<PrimExpr> new_args; size_t need_n =
std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std::vector<bool> replace(need_n, false);
for (auto &id : restore_barrier_ids_) { for (auto &id : restore_barrier_ids_) {
tmp_[id] = true; if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
replace[id] = true;
}
} }
for (size_t i{0}; i < op->args.size(); ++i) { Array<PrimExpr> new_args;
if (tmp_[i]) { new_args.reserve(need_n);
// Preserve/override existing entries
for (size_t i{0}; i < cur_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_); new_args.push_back(producer_thread_extent_);
} else { } else {
new_args.push_back(op->args[i]); new_args.push_back(op->args[i]);
} }
} }
// Append additional barriers if required
for (size_t i = cur_n; i < need_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(default_barrier_thread_count_);
}
}
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} else { } else {
return StmtExprMutator::VisitExpr_(op); return StmtExprMutator::VisitExpr_(op);
} }
} }
private:
std::vector<int> restore_barrier_ids_; std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_; PrimExpr producer_thread_extent_;
int ensure_min_count_{0};
PrimExpr default_barrier_thread_count_{1};
}; };
// we trust mbarrier_wait_parity to be correct // we trust mbarrier_wait_parity to be correct
...@@ -398,15 +429,38 @@ public: ...@@ -398,15 +429,38 @@ public:
collector.barrier_id_to_range(), collector.barrier_id_to_range(),
has_create_list_of_mbarrier); has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body); f.CopyOnWrite()->body = rewriter(f->body);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
int max_idx{-1};
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(get_mbarrier())) {
if (op->args.size() == 1) {
if (const auto *imm = op->args[0].as<IntImmNode>()) {
max_idx = std::max(max_idx, static_cast<int>(imm->value));
}
}
}
StmtExprVisitor::VisitExpr_(op);
}
};
GetMbarrierMaxIdxCollector max_idx_collector;
max_idx_collector(f->body);
int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto barrier_creation_rewriter = BarrierCreationRewriter( auto barrier_creation_rewriter = BarrierCreationRewriter(
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_); rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
ensure_min_count, Integer(1));
f.CopyOnWrite()->body = barrier_creation_rewriter(f->body); f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
return f; return f;
} }
private: private:
Stmt VisitStmt_(const BlockNode *op) { Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op); auto block = tvm::ffi::GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() && if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
op->name_hint == MainBlockName) { op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier."; ICHECK(false) << "Please declare create_list_of_mbarrier.";
...@@ -452,10 +506,27 @@ private: ...@@ -452,10 +506,27 @@ private:
PrimExpr VisitExpr_(const CallNode *op) { PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) { if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_ auto call_ref = tvm::ffi::GetRef<Call>(op);
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op))) if (!tma_op_to_barrier_id_.count(call_ref)) {
<< "tma_load must be in the tma_op_to_barrier_id_"; // For 1D TMA loads, promote raw integer barrier id to get_mbarrier(id)
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)]; // so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and mapping is empty.
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load && op->args.size() >= 3) {
if (const auto *imm = op->args[2].as<IntImmNode>()) {
Array<PrimExpr> new_args = op->args;
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32),
static_cast<int>(imm->value))}));
return Call(op->dtype, op->op, new_args);
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args; auto new_args = op->args;
auto arg0 = op->args[0].as<Call>(); auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load = auto is_1d_tma_load =
...@@ -468,9 +539,11 @@ private: ...@@ -468,9 +539,11 @@ private:
} }
return Call(op->dtype, op->op, new_args); return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) { } else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op))) auto call_ref = tvm::ffi::GetRef<Call>(op);
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_"; if (!tma_op_to_barrier_id_.count(call_ref)) {
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)]; return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args; auto new_args = op->args;
new_args.Set(0, barrier_id); new_args.Set(0, barrier_id);
if (!has_warp_specialization_) if (!has_warp_specialization_)
...@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() { ...@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {}); return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier); refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
#include <tvm/tir/utils.h> #include <tvm/tir/utils.h>
#include <algorithm>
#include <queue> #include <queue>
#include "../layout/utils.h" #include "../layout/utils.h"
...@@ -105,20 +106,60 @@ public: ...@@ -105,20 +106,60 @@ public:
"required for layout inference."; "required for layout inference.";
// Run InferLayout // Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates = auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob}, &analyzer_, buffer_oob},
level); level);
// Process the returned updates // Process the returned updates
for (const auto &[buffer, layout] : updates) { for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';
// Basic validity checks // Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer."; ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout."; ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
// Helper: propagate inferred layout to alias buffers (same data Var)
auto propagate_alias = [&](const Buffer &src_buffer,
const Layout &src_layout) {
if (!buffer_data_to_buffers_.count(src_buffer->data))
return;
const auto &siblings = buffer_data_to_buffers_[src_buffer->data];
for (const auto &sib : siblings) {
if (sib.same_as(src_buffer))
continue;
bool shapes_equal =
src_layout->InputShape().size() == sib->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < src_layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(src_layout->InputShape()[i],
sib->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout target_layout =
shapes_equal ? src_layout
: src_layout->Reshape(sib->shape, &analyzer_);
if (layout_map.count(sib)) {
ICHECK(target_layout->IsEqual(layout_map[sib].get()))
<< "Get different layout for alias buffer " << sib
<< " (data-shared with " << src_buffer
<< ")\n current: " << target_layout->DebugOutput()
<< "\n previous: " << layout_map[sib]->DebugOutput();
} else {
layout_map.Set(sib, target_layout);
if (update_queue && use_list_.count(sib)) {
for (int idx : use_list_[sib]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
}
}
}
}
};
if (layout_map.count(buffer)) { if (layout_map.count(buffer)) {
// If new layout contains the old one, update map // If new layout contains the old one, update map
if (buffer.scope() == "local.fragment" && if (buffer.scope() == "local.fragment" &&
...@@ -153,8 +194,8 @@ public: ...@@ -153,8 +194,8 @@ public:
if (ProveFragmentContains(src_layout, dst_layout, indices, indices, if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) { inner_analyzer)) {
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from " // Propagate to alias buffers as well
<< src_layout->DebugOutput() << ", accepted" << '\n'; propagate_alias(buffer, layout);
continue; continue;
} }
} }
...@@ -163,10 +204,13 @@ public: ...@@ -163,10 +204,13 @@ public:
<< "Get different layout for " << buffer << "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput() << "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput(); << "\n previous layout: " << layout_map[buffer]->DebugOutput();
// Ensure aliases are consistent too
propagate_alias(buffer, layout);
} else { } else {
// Otherwise, update map // Otherwise, update map
layout_map.Set(buffer, layout); layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n'; // Propagate to alias buffers (may enqueue their users)
propagate_alias(buffer, layout);
if (!update_queue) if (!update_queue)
continue; continue;
...@@ -272,6 +316,46 @@ public: ...@@ -272,6 +316,46 @@ public:
// step 3: relax constraints to free and re-run // step 3: relax constraints to free and re-run
InferInFreeMode(layout_map, strict_layout_map); InferInFreeMode(layout_map, strict_layout_map);
// step 4: finalize alias layouts by Var
// For each storage var, if any buffer in the group has a layout,
// propagate (reshape if needed) to the rest to ensure completeness.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
// Find a representative with existing layout
Optional<Buffer> rep;
Optional<Layout> rep_layout;
for (const auto &buf : buffers) {
if (layout_map.count(buf)) {
rep = buf;
rep_layout = layout_map[buf];
break;
}
}
if (!rep_layout.defined())
continue;
for (const auto &buf : buffers) {
if (!layout_map.count(buf)) {
bool shapes_equal =
rep_layout.value()->InputShape().size() == buf->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < rep_layout.value()->InputShape().size();
++i) {
if (!analyzer_.CanProveEqual(rep_layout.value()->InputShape()[i],
buf->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout reshaped =
shapes_equal
? rep_layout.value()
: rep_layout.value()->Reshape(buf->shape, &analyzer_);
layout_map.Set(buf, reshaped);
}
}
}
// Check that all local.fragment buffers have inferred layouts // Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) { for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment") { if (buffer.scope() == "local.fragment") {
...@@ -314,7 +398,13 @@ public: ...@@ -314,7 +398,13 @@ public:
void Collect(const PrimFunc &f) { void Collect(const PrimFunc &f) {
for (const auto &[_, buffer] : f->buffer_map) { for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer); if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
} }
auto target = f->GetAttr<Target>(tvm::attr::kTarget); auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) ICHECK(target.defined())
...@@ -324,13 +414,25 @@ public: ...@@ -324,13 +414,25 @@ public:
} }
private: private:
Map<Var, Buffer> GetBufferMap() const {
Map<Var, Buffer> buffer_map;
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
// Use the first buffer for each var
// TODO(lei): phaseout buffer_map in future.
if (!buffers.empty()) {
buffer_map.Set(var, buffers[0]);
}
}
return buffer_map;
}
void VisitExpr_(const CallNode *op) final { void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op); IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analysis the call node to the global function. // Do not analysis the call node to the global function.
if (op->op.as<GlobalVarNode>()) if (op->op.as<GlobalVarNode>())
return; return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_); auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), GetBufferMap());
if (p.defined()) { if (p.defined()) {
for (const auto &arg : op->args) { for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) { if (auto buffer = getBufferFromAccessPtr(arg)) {
...@@ -381,7 +483,7 @@ private: ...@@ -381,7 +483,7 @@ private:
} }
// Add the tile operator to infer_list_ // Add the tile operator to infer_list_
infer_list_stmt_.push_back(GetRef<ObjectRef>(op)); infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p)); infer_list_.push_back(std::move(p));
} }
} }
...@@ -394,12 +496,18 @@ private: ...@@ -394,12 +496,18 @@ private:
if (call->op.same_as(builtin::tvm_access_ptr())) { if (call->op.same_as(builtin::tvm_access_ptr())) {
auto var_opt = call->args[1].as<Var>(); auto var_opt = call->args[1].as<Var>();
if (!var_opt.has_value()) { if (!var_opt.has_value()) {
DLOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: " LOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
<< call->args[1]->GetTypeKey(); << call->args[1]->GetTypeKey();
return std::nullopt; return std::nullopt;
} }
const auto &var = var_opt.value(); const auto &var = var_opt.value();
return buffer_data_to_buffer_[var]; if (buffer_data_to_buffers_.count(var)) {
const auto &buffers = buffer_data_to_buffers_[var];
if (!buffers.empty()) {
return buffers[0]; // Return the first buffer
}
}
return std::nullopt;
} else if (call->op.same_as(RegionOp::Get())) { } else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer; return call->args[0].as<BufferLoadNode>()->buffer;
} }
...@@ -416,11 +524,11 @@ private: ...@@ -416,11 +524,11 @@ private:
void VisitStmt_(const ForNode *op) final { void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) { if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op)); auto infer = ParallelOp(tvm::ffi::GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) { for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer); addToUseList(buffer);
} }
infer_list_stmt_.push_back(GetRef<ObjectRef>(op)); infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer)); infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_); thread_var_vec_.push_back(thread_var_);
if (thread_var_.defined() && if (thread_var_.defined() &&
...@@ -442,21 +550,55 @@ private: ...@@ -442,21 +550,55 @@ private:
void VisitStmt_(const BlockNode *op) final { void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) { for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer); if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
} }
// First, visit the block body to collect all buffers from
// BufferLoad/BufferStore
IRVisitorWithAnalyzer::VisitStmt_(op);
// After visiting, apply layouts to all collected buffers
if (op->annotations.count(attr::kLayoutMap)) { if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout> // Check if the layout map is Map<Var, Layout>
auto map = auto map =
op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value(); op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) { for (const auto &[var, layout] : map) {
ICHECK(buffer_data_to_buffer_.count(var)) ICHECK(buffer_data_to_buffers_.count(var))
<< "buffer " << var << " is not found in the block"; << "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var]; const auto &buffers = buffer_data_to_buffers_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape)); ICHECK(!buffers.empty()) << "buffer list for " << var << " is empty";
annotated_layout_map_.Set(buffer, layout); // Apply layout to all buffers associated with this var
for (const auto &buffer : buffers) {
// Reshape the layout to match the buffer's shape
// Check if shapes are structurally equal
bool shapes_equal =
layout->InputShape().size() == buffer->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(layout->InputShape()[i],
buffer->shape[i])) {
shapes_equal = false;
break;
}
}
}
if (shapes_equal) {
annotated_layout_map_.Set(buffer, layout);
} else {
auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_);
annotated_layout_map_.Set(buffer, reshaped_layout);
}
}
} }
} }
IRVisitorWithAnalyzer::VisitStmt_(op);
} }
void VisitStmt_(const AttrStmtNode *op) final { void VisitStmt_(const AttrStmtNode *op) final {
...@@ -470,7 +612,67 @@ private: ...@@ -470,7 +612,67 @@ private:
IRVisitorWithAnalyzer::VisitStmt_(op); IRVisitorWithAnalyzer::VisitStmt_(op);
} }
Map<Var, Buffer> buffer_data_to_buffer_; void VisitExpr_(const BufferLoadNode *op) final {
// Collect buffer from BufferLoad
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
// Collect buffer from BufferStore
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
Map<Var, Array<Buffer>> buffer_data_to_buffers_;
std::vector<ObjectRef> infer_list_stmt_; std::vector<ObjectRef> infer_list_stmt_;
std::vector<TileOperator> infer_list_; std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual> std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
...@@ -513,12 +715,33 @@ private: ...@@ -513,12 +715,33 @@ private:
if (infer_indices.empty()) if (infer_indices.empty())
continue; continue;
// Union all infer_list_ indices that share the same buffer // Union all infer_list_ indices that share the same Buffer object
int first_idx = infer_indices[0]; int first_idx = infer_indices[0];
for (size_t i = 1; i < infer_indices.size(); i++) { for (size_t i = 1; i < infer_indices.size(); i++) {
uf.Union(first_idx, infer_indices[i]); uf.Union(first_idx, infer_indices[i]);
} }
} }
// Additionally, union across buffers that share the same underlying
// buffer->data (Var). This handles cases like reshape where multiple
// Buffer objects alias the same storage.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
std::vector<int> merged;
for (const auto &buf : buffers) {
auto it = use_list_.find(buf);
if (it != use_list_.end()) {
const auto &vec = it->second;
merged.insert(merged.end(), vec.begin(), vec.end());
}
}
if (merged.size() > 1) {
std::sort(merged.begin(), merged.end());
merged.erase(std::unique(merged.begin(), merged.end()), merged.end());
int first = merged[0];
for (size_t i = 1; i < merged.size(); ++i) {
uf.Union(first, merged[i]);
}
}
}
std::unordered_map<int, std::vector<int>> components; std::unordered_map<int, std::vector<int>> components;
for (int i = 0; i < infer_list_.size(); i++) { for (int i = 0; i < infer_list_.size(); i++) {
int root = uf.Find(i); int root = uf.Find(i);
...@@ -597,7 +820,9 @@ private: ...@@ -597,7 +820,9 @@ private:
} }
} }
// Update the best plan if this one uses fewer registers // Update the best plan if this one uses fewer registers
if (reg_num < min_reg_num) { if (reg_num < min_reg_num ||
(reg_num == min_reg_num &&
attempt_infer_root < min_reg_num_infer_root)) {
best_infer_list = best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_ BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map; best_layout_map = tmp_layout_map;
...@@ -711,8 +936,8 @@ private: ...@@ -711,8 +936,8 @@ private:
.value(); .value();
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op)); For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) { if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
auto root = GetRef<For>(op); auto root = tvm::ffi::GetRef<For>(op);
// This check is a workaround to support T.Parallel for local buffers. // This check is a workaround to support T.Parallel for local buffers.
// For example: // For example:
// for i in T.Parallel(1024): // for i in T.Parallel(1024):
...@@ -787,7 +1012,18 @@ private: ...@@ -787,7 +1012,18 @@ private:
} }
}); });
if (has_non_local && !has_reducer) { // If a cast operation exists, vectorization may still be required
bool has_cast_operations = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if this is a non-reducer store with Cast operation
if (store->value.as<CastNode>()) {
has_cast_operations = true;
}
}
});
if ((has_non_local || has_cast_operations) && !has_reducer) {
for_node = VectorizeLoop(for_node); for_node = VectorizeLoop(for_node);
} }
...@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() { ...@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference); refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "../layout/layout.h" #include "../layout/layout.h"
#include "../op/fill.h" #include "../op/fill.h"
#include "../op/finalize_reducer.h" #include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h" #include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h" #include "layout_reducer.h"
...@@ -275,17 +276,34 @@ private: ...@@ -275,17 +276,34 @@ private:
auto op = op_ref.CopyOnWrite(); auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) { if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty()); ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>(); if (auto arg0_call = op->args[0].as<Call>()) {
arg0_call && // Case 1: tl.region(...) — extract buffer var from its first arg
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) { if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(arg0_call.value()->args.size() > 1); ICHECK(!arg0_call.value()->args.empty());
if (auto var = arg0_call.value()->args[1].as<Var>(); if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
var && reducer_info_map_.count(var.value())) { Var var = bl->buffer->data;
ICHECK(inside_reducer_range_.count(var.value()) == 0) if (reducer_info_map_.count(var)) {
<< "T.fill on reducer must be enclosed with a T.finalize_reducer " ICHECK(inside_reducer_range_.count(var) == 0)
"before next."; << "T.fill on reducer must be enclosed with a "
inside_reducer_range_.Set(var.value(), "T.finalize_reducer "
reducer_info_map_.Get(var.value()).value()); "before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(
var.value(), reducer_info_map_.Get(var.value()).value());
}
} }
} }
} else if (op->op.same_as(FinalizeReducerOp::Get())) { } else if (op->op.same_as(FinalizeReducerOp::Get())) {
...@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() { ...@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {}); return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
} }
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer); refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object { ...@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
ReducerInfoNode() = default; ReducerInfoNode() = default;
ReducerInfoNode(const String &op_str, const String &rep_str); ReducerInfoNode(const String &op_str, const String &rep_str);
static constexpr const char *_type_key = "tl.ReducerInfo"; TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
}; };
struct ReducerInfo : ObjectRef { struct ReducerInfo : ObjectRef {
public: public:
TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) { TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
data_ = make_object<ReducerInfoNode>(op_str, rep_str); data_ = tvm::ffi::make_object<ReducerInfoNode>(op_str, rep_str);
} }
TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode); TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef,
ReducerInfoNode);
}; };
namespace attr { namespace attr {
......
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRVisitorWithAnalyzer;
// Sign classification of one buffer-load index expression (per axis).
enum class IndexSignState { kNonNegative, kNegative, kUnknown };
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
explicit NegativeIndexAnalyzer(
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
: result_(result) {}
void VisitExpr_(const BufferLoadNode *op) final {
auto load = tvm::ffi::GetRef<BufferLoad>(op);
std::vector<IndexSignState> states;
states.reserve(op->indices.size());
bool needs_record = false;
for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue;
}
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
IndexSignState vec_state = IndexSignState::kUnknown;
if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guard against overflow is not strictly necessary here because
// bounds may be +/-inf represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (upper < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
} else if (const auto *bc = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(bc->value);
if (analyzer_.CanProve(v >= 0)) {
vec_state = IndexSignState::kNonNegative;
} else if (analyzer_.CanProve(v < 0)) {
vec_state = IndexSignState::kNegative;
} else {
// Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
}
}
if (vec_state == IndexSignState::kNonNegative) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ").";
}
if (needs_record) {
(*result_)[op] = std::move(states);
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result_;
};
/*!
 * \brief Mutator that wraps the negative indices recorded by the analyzer:
 *        a provably-negative index `idx` on axis `i` becomes
 *        `buffer->shape[i] + idx`.
 */
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
  /*!
   * \brief Run the rewrite over \p func using the recorded \p states.
   * \return The (possibly copy-on-write mutated) function.
   */
  static PrimFunc
  Apply(PrimFunc func,
        const std::unordered_map<const BufferLoadNode *,
                                 std::vector<IndexSignState>> &states) {
    if (!func->body.defined()) {
      return func;
    }
    arith::Analyzer analyzer;
    NegativeIndexRewriter rewriter(&analyzer, states);
    PrimFuncNode *mutable_func = func.CopyOnWrite();
    mutable_func->body = rewriter.VisitStmt(mutable_func->body);
    return func;
  }

private:
  NegativeIndexRewriter(
      arith::Analyzer *analyzer,
      const std::unordered_map<const BufferLoadNode *,
                               std::vector<IndexSignState>> &states)
      : arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    // Mutate children first; the lookup key is the ORIGINAL node pointer,
    // which is what the analyzer recorded.
    BufferLoad updated =
        Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
    const auto entry = states_.find(op);
    if (entry == states_.end()) {
      return updated;
    }
    const auto &axis_states = entry->second;
    auto new_indices = updated->indices;
    ICHECK_EQ(axis_states.size(), new_indices.size())
        << "State vector size mismatch for buffer load "
        << updated->buffer->name;
    bool rewritten = false;
    for (size_t axis = 0; axis < new_indices.size(); ++axis) {
      if (axis_states[axis] != IndexSignState::kNegative) {
        continue;
      }
      // Wrap the negative index around the axis extent.
      PrimExpr extent = updated->buffer->shape[axis];
      new_indices.Set(axis, analyzer_->Simplify(extent + new_indices[axis]));
      rewritten = true;
    }
    if (rewritten) {
      return BufferLoad(updated->buffer, new_indices);
    }
    return updated;
  }

  const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      &states_;
};
/*!
 * \brief Legalize negative buffer-load indices in \p func.
 *
 * Runs the read-only analyzer first; the function is returned untouched
 * when every index is already provably non-negative.
 */
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
  if (!func->body.defined()) {
    return func;
  }
  std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
      sign_states;
  NegativeIndexAnalyzer collector(&sign_states);
  collector(func->body);
  if (sign_states.empty()) {
    return func;
  }
  return NegativeIndexRewriter::Apply(std::move(func), sign_states);
}
/*!
 * \brief Create the "tl.LegalizeNegativeIndex" PrimFunc-level pass.
 */
tvm::transform::Pass LegalizeNegativeIndexPass() {
  using namespace tir::transform;
  auto rewrite = [](PrimFunc f, const IRModule &, PassContext) {
    return LegalizeNegativeIndex(std::move(f));
  };
  return CreatePrimFuncPass(rewrite, 0, "tl.LegalizeNegativeIndex", {});
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK() {
  namespace refl = tvm::ffi::reflection;
  refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
                        LegalizeNegativeIndexPass);
}
} // namespace tl
} // namespace tvm
...@@ -38,7 +38,7 @@ private: ...@@ -38,7 +38,7 @@ private:
StmtVisitor::VisitStmt(op->body); StmtVisitor::VisitStmt(op->body);
if (!has_child_for_) { if (!has_child_for_) {
leaf_for_nodes.push_back(GetRef<For>(op)); leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
} }
parent_has_child_for_ = parent_has_child_for; parent_has_child_for_ = parent_has_child_for;
...@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() { ...@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess", refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess); LegalizeSafeMemoryAccess);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
...@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() { ...@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
} }
// Register the pass globally so it can be used in the compilation pipeline // Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({ TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection; namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop", refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop); LegalizeVectorizedLoop);
}); }
} // namespace tl } // namespace tl
} // namespace tvm } // namespace tvm
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment