Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -173,7 +173,7 @@ private:
if (as_const_int(analyzer->Simplify(node->extent)) == nullptr) {
return StmtExprMutator::VisitStmt_(node);
}
For new_for = GetRef<For>(node);
For new_for = tvm::ffi::GetRef<For>(node);
auto for_ptr = new_for.CopyOnWrite();
for_ptr->annotations.Set(tir::attr::pragma_unroll_explicit, Bool(false));
for_ptr->kind = ForKind::kUnrolled;
......
......@@ -240,8 +240,9 @@ int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
// 1. if var doesn't exist, it is independent
bool used_var = UsesVar(
expr, [&](const VarNode *v) { return GetRef<Var>(v).same_as(var); });
bool used_var = UsesVar(expr, [&](const VarNode *v) {
return tvm::ffi::GetRef<Var>(v).same_as(var);
});
if (!used_var) {
return true;
}
......
/*!
* \file loop_vectorize_dynamic.cc
* \brief A tool to automatically vectorize a for loop with dynamic shape
* \note Reference: loop_vectorize.cc and vectorize_loop.cc
*/
#include <cstdint>
#include <tvm/arith/iter_affine_map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <numeric>
#include <utility>
#include "../layout/layout.h"
#include "../layout/utils.h"
#include "../op/builtin.h"
#include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRMutatorWithAnalyzer;
struct VectorizePlanResult {
int vector_size;
bool dynamic;
PrimExpr condition;
};
bool IndiceCanVectorizeDynamic(const PrimExpr &expr, Var var,
const PrimExpr &iter_var_size,
int target_vectorized_size,
arith::Analyzer *analyzer) {
ICHECK(target_vectorized_size >= 1);
if (target_vectorized_size == 1)
return true;
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_vectorized_size),
0))
return false;
Var v0("v0"), v1("v1");
analyzer->Bind(v0, Range(0, target_vectorized_size));
analyzer->Bind(v1, Range(0, FloorDiv(iter_var_size, target_vectorized_size)));
PrimExpr expr_transformed = analyzer->Simplify(
Substitute(expr, {{var, v0 + v1 * target_vectorized_size}}));
Vectorizer vectorizer(v0, IntImm(v0->dtype, target_vectorized_size));
PrimExpr expr_vectorized = vectorizer.VisitExpr(expr_transformed);
auto ramp_node = expr_vectorized.as<RampNode>();
if (!ramp_node) {
// Broadcast value
if (expr_vectorized.dtype().lanes() == 1)
return true;
else
return false;
} else {
return is_one(ramp_node->stride);
}
}
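// Illustrative standalone sketch (not part of this patch): the essence of the
// ramp check above, written against plain integers instead of TIR. We
// substitute i = v1 * vs + v0 and ask whether the lane values expr(i) for
// v0 = 0..vs-1 form a unit-stride run; the helper name and demo expressions
// below are hypothetical.
#include <cstdint>
#include <functional>
#include <iostream>

static bool LaneStrideIsOne(const std::function<int64_t(int64_t)> &expr,
                            int64_t v1, int vs) {
  // Evaluate expr at i = v1 * vs + v0 for each lane v0 and compare neighbours.
  for (int v0 = 1; v0 < vs; ++v0) {
    int64_t prev = expr(v1 * vs + (v0 - 1));
    int64_t cur = expr(v1 * vs + v0);
    if (cur - prev != 1)
      return false; // not a unit-stride ramp, cannot vectorize this access
  }
  return true;
}

int main() {
  auto contiguous = [](int64_t i) { return i; };  // A[i]
  auto strided = [](int64_t i) { return 2 * i; }; // A[2 * i]
  std::cout << LaneStrideIsOne(contiguous, 0, 4) << "\n"; // 1: vectorizable
  std::cout << LaneStrideIsOne(strided, 0, 4) << "\n";    // 0: stride 2
  return 0;
}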
class VectorizePlannerDynamic : public arith::IRVisitorWithAnalyzer {
public:
VectorizePlannerDynamic(int dynamic_alignment,
bool disable_dynamic_tail_split)
: dynamic_alignment_(dynamic_alignment),
disable_dynamic_tail_split_(disable_dynamic_tail_split),
vector_load_bits_max_(128) {
if (disable_dynamic_tail_split_) {
vector_size_ = dynamic_alignment_;
} else {
vector_size_ = vector_load_bits_max_;
}
}
int Plan(const For &node) {
this->operator()(node);
    // Always enable vectorization
// if (!has_nonlocal_memory_access_) return 1;
return vector_size_;
}
bool GetDynamic() { return dynamic_; }
PrimExpr GetCondition() { return condition_; }
private:
void VisitStmt_(const ForNode *node) final {
inner_for_ = node;
iter_map_.Set(node->loop_var, Range(node->min, node->extent));
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
if (node->buffer->shape.size() == 1) {
      // TODO(lei): This should be improved, as this is a constant buffer
      // that tl hacks to use as a local register.
auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
if (boundary_check && boundary_check->value == 1) {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
}
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitStmt_(const IfThenElseNode *node) final {
CheckConditionVectorized(node->condition);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const CallNode *node) final {
if (node->op == builtin::if_then_else()) {
CheckConditionVectorized(node->args[0]);
} else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls
vector_size_ = 1;
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void CheckConditionVectorized(const PrimExpr &cond) {
// TODO: may perform some checks here
}
void UpdateVectorSize(const Array<PrimExpr> &indices, const Buffer &buffer) {
if (!inner_for_)
return;
auto extent_ptr = inner_for_->extent.as<IntImmNode>();
if (!extent_ptr)
return;
const DataType &access_type = buffer->dtype;
    int max_vector_size = vector_load_bits_max_ / access_type.bits();
    // i // 2, i % 8 can also be vectorized as factor 16,
    // so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
auto last_dim = buffer->shape.back();
auto mod_set = analyzer_.modular_set(last_dim);
    // For a dynamic shape like [m, k], coeff = 1 and base = 0, so the GCD
    // analysis would block vectorization; conditionally tail-vectorize instead.
if (buffer->shape.back().as<IntImmNode>()) {
max_vector_size = arith::ZeroAwareGCD(max_vector_size, mod_set->coeff);
auto gcd_base = arith::ZeroAwareGCD(max_vector_size, mod_set->base);
// If gcd_base is equal to the last dimension,
// we should analyze the second-to-last dimension
// in relation to the last dimension.
if (gcd_base < Downcast<IntImm>(last_dim)->value) {
max_vector_size = gcd_base;
}
vector_size_ = arith::ZeroAwareGCD(max_vector_size, vector_size_);
PrimExpr elem_offset = 0;
PrimExpr stride = 1;
for (int i = indices.size() - 1; i >= 0; --i) {
elem_offset = elem_offset + indices[i] * stride;
stride = stride * buffer->shape[i];
}
while (!IndiceCanVectorizeDynamic(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_,
&analyzer_)) {
vector_size_ /= 2;
}
} else {
// dynamic shape load: get the vectorization condition
dynamic_ = true;
if (!disable_dynamic_tail_split_ &&
vector_size_ >= vector_load_bits_max_ / buffer->dtype.bits()) {
vector_size_ = vector_load_bits_max_ / buffer->dtype.bits();
}
PrimExpr offset = buffer.OffsetOf(indices).back();
// condition for alignment, maybe useless
condition_ = (FloorMod(offset, vector_size_) == 0);
}
}
// Use dynamic alignment from pass config
int vector_load_bits_max_;
int dynamic_alignment_;
bool disable_dynamic_tail_split_;
int vector_size_;
const ForNode *inner_for_{};
Map<Var, Range> iter_map_;
bool has_nonlocal_memory_access_ = false;
// conditionally vectorize
bool dynamic_ = false;
PrimExpr condition_;
};
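// Illustrative standalone sketch (not part of this patch): UpdateVectorSize
// flattens multi-dimensional indices into a single element offset by walking
// the shape from the innermost dimension outwards (offset += idx[i] * stride;
// stride *= shape[i]). A plain-integer version of the same fold:
#include <cstdint>
#include <iostream>
#include <vector>

static int64_t FlattenOffset(const std::vector<int64_t> &indices,
                             const std::vector<int64_t> &shape) {
  int64_t offset = 0, stride = 1;
  for (int i = static_cast<int>(indices.size()) - 1; i >= 0; --i) {
    offset += indices[i] * stride; // innermost dimension has stride 1
    stride *= shape[i];
  }
  return offset;
}

int main() {
  // For a [128, 64] buffer, element (3, 5) sits at 3 * 64 + 5 = 197.
  std::cout << FlattenOffset({3, 5}, {128, 64}) << "\n"; // 197
  return 0;
}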
class VectorizedBodyMutator : public StmtExprMutator {
public:
VectorizedBodyMutator(Var inner_var, int vector_size,
std::vector<PrimExpr> conditions)
: inner_var_(std::move(inner_var)), vector_size_(vector_size),
conditions_(std::move(conditions)) {}
private:
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::if_then_else())) {
      // TODO: Currently not a ramp; only the "then" branch is kept (because
      // the conditions are moved outside this vectorized loop)
PrimExpr ifexpr = op->args[0];
PrimExpr thenexpr = op->args[1];
bool flag = false;
for (auto &cond : conditions_) {
if (ifexpr.get() == cond.get()) {
flag = true;
}
}
if (flag) {
return thenexpr;
} else {
return GetRef<PrimExpr>(op);
}
} else {
return GetRef<PrimExpr>(op);
}
}
Var inner_var_;
int vector_size_;
std::vector<PrimExpr> conditions_;
};
class VectorizedConditionExtractor : public StmtExprVisitor {
public:
VectorizedConditionExtractor() = default;
std::vector<PrimExpr> GetConditions(const Stmt &body) {
this->VisitStmt(body);
return conditions_;
}
private:
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::if_then_else())) {
PrimExpr cond = op->args[0];
conditions_.emplace_back(cond);
}
StmtExprVisitor::VisitExpr_(op);
}
void VisitStmt_(const IfThenElseNode *node) final {
conditions_.emplace_back(node->condition);
StmtExprVisitor::VisitStmt_(node);
}
std::vector<PrimExpr> conditions_;
};
// backward-compatibility: extracter -> extractor
using VectorizedConditionExtracter = VectorizedConditionExtractor;
class NestedLoopChecker : public StmtExprVisitor {
public:
NestedLoopChecker() : loop_num_(0) {}
int GetNestLoopNum(const Stmt &body) {
this->VisitStmt(body);
return loop_num_;
}
private:
void VisitStmt_(const ForNode *node) final {
loop_num_++;
StmtExprVisitor::VisitStmt_(node);
}
int loop_num_;
};
// Modify every subexpression in the condition
class VectorizedConditionMutator : public StmtExprMutator {
public:
VectorizedConditionMutator(Var inner_var, int extent)
: inner_var_(std::move(inner_var)), vector_size_(extent) {}
private:
PrimExpr VisitExpr_(const GENode *node) final {
PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
auto span = node->span;
Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
vmap_lhs.Set(inner_var_, 0);
PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
vmap_rhs.Set(inner_var_, vector_size_ - 1);
PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
return GE(lhs_bound, rhs_bound, span);
}
PrimExpr VisitExpr_(const GTNode *node) final {
PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
auto span = node->span;
Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
vmap_lhs.Set(inner_var_, 0);
PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
vmap_rhs.Set(inner_var_, vector_size_ - 1);
PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
return GT(lhs_bound, rhs_bound, span);
}
PrimExpr VisitExpr_(const LENode *node) final {
PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
auto span = node->span;
Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
vmap_lhs.Set(inner_var_, vector_size_ - 1);
PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
vmap_rhs.Set(inner_var_, 0);
PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
return LE(lhs_bound, rhs_bound, span);
}
PrimExpr VisitExpr_(const LTNode *node) final {
PrimExpr lhs = StmtExprMutator::VisitExpr(node->a);
PrimExpr rhs = StmtExprMutator::VisitExpr(node->b);
auto span = node->span;
Map<Var, PrimExpr> vmap_lhs, vmap_rhs;
vmap_lhs.Set(inner_var_, vector_size_ - 1);
PrimExpr lhs_bound = Substitute(lhs, vmap_lhs);
vmap_rhs.Set(inner_var_, 0);
PrimExpr rhs_bound = Substitute(rhs, vmap_rhs);
return LT(lhs_bound, rhs_bound, span);
}
Var inner_var_;
int vector_size_;
};
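// Illustrative standalone sketch (not part of this patch): the mutator above
// strengthens a lane-dependent predicate into one that covers every lane by
// substituting the worst-case lane index (0 or vector_size - 1, depending on
// the comparison direction), assuming both sides are non-decreasing in the
// lane variable. For `i * vs + v < N` with v in [0, vs), the strengthened
// guard is `i * vs + (vs - 1) < N`; the helper names below are hypothetical.
#include <cstdint>
#include <iostream>

static bool AllLanesLess(int64_t i, int vs, int64_t N) {
  for (int v = 0; v < vs; ++v)
    if (!(i * vs + v < N))
      return false;
  return true;
}

static bool WorstLaneLess(int64_t i, int vs, int64_t N) {
  return i * vs + (vs - 1) < N; // substitute v = vs - 1 on the LHS of `<`
}

int main() {
  for (int64_t N : {7, 8, 9})
    for (int64_t i : {0, 1, 2})
      std::cout << (AllLanesLess(i, 4, N) == WorstLaneLess(i, 4, N)) << " ";
  std::cout << "\n"; // prints all 1s: the two checks agree on every case
  return 0;
}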
class VectorizeRewriterDynamic : public StmtExprMutator {
public:
VectorizeRewriterDynamic(const VectorizePlanResult &plan,
bool disable_dynamic_tail_split)
: vector_size_(plan.vector_size), condition_(plan.condition),
dynamic_(plan.dynamic),
disable_dynamic_tail_split_(disable_dynamic_tail_split) {}
private:
Stmt VisitStmt_(const ForNode *node) final {
// Get pass config `tl.disable_dynamic_tail_split`
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_dynamic_tail_split =
ctxt->GetConfig(kDisableDynamicTailSplit, Optional<Bool>());
bool disable_dynamic_tail_split =
opt_disable_dynamic_tail_split.value_or(Bool(false));
inner_for_ = node;
auto ret = StmtExprMutator::VisitStmt_(node);
if (inner_for_ != node) {
return ret;
}
For fnode = ret.as<For>().value();
auto old_var = fnode->loop_var;
if (!fnode->extent.as<IntImmNode>()) {
return ret;
}
int extent = Downcast<IntImm>(fnode->extent)->value;
if (!dynamic_) {
return fnode;
}
if (!disable_dynamic_tail_split) {
      // To handle the fact that cp.async only supports addresses aligned
      // with the access size
vector_size_ = 1;
}
ICHECK(extent % vector_size_ == 0)
<< "extent: " << extent << " vector_size_: " << vector_size_;
ICHECK(is_zero(fnode->min));
Var inner_var = Var("vec");
Var outer_var = Var(old_var->name_hint);
Map<Var, PrimExpr> vmap;
vmap.Set(fnode->loop_var, outer_var * vector_size_ + inner_var);
Stmt body = Substitute(fnode->body, vmap);
VectorizedConditionExtractor extractor;
std::vector<PrimExpr> conditions = extractor.GetConditions(body);
VectorizedConditionMutator condition_mutator(inner_var, vector_size_);
// Adaptively set vectorized variable to the min/max value of the extent
PrimExpr condition_bound;
if (!conditions.empty()) {
condition_bound = condition_mutator(conditions[0]);
for (int i = 1; i < conditions.size(); ++i) {
condition_bound = condition_bound && condition_mutator(conditions[i]);
}
}
if (!disable_dynamic_tail_split) {
      // If dynamic tail split is enabled, vectorize the loop guarded by
      // if-then-else conditions and modify the body of the vectorized loop
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body);
// add condition ifthenelse here
For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
For serial_for = For(inner_var, 0, vector_size_, ForKind::kSerial, body);
if (!conditions.empty()) {
body = IfThenElse(condition_bound, vectorize_for, serial_for);
} else {
body = vectorize_for;
}
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
} else {
      // If dynamic tail split is disabled, directly vectorize the loop
      // without tail splitting or if_then_else guards, which may lead to errors
VectorizedBodyMutator mutator(inner_var, vector_size_, conditions);
Stmt vectorize_body = mutator(body);
For vectorize_for =
For(inner_var, 0, vector_size_, ForKind::kVectorized, vectorize_body);
body =
For(outer_var, 0, extent / vector_size_, fnode->kind, vectorize_for,
fnode->thread_binding, fnode->annotations, fnode->span);
return body;
}
}
const ForNode *inner_for_{};
int vector_size_;
const PrimExpr condition_;
const bool dynamic_;
const bool disable_dynamic_tail_split_;
};
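// Illustrative standalone sketch (not part of this patch): the rewriter above
// replaces `for (i = 0; i < extent; ++i) body(i)` with an outer loop over
// extent / vs chunks; the inner vs lanes are emitted as kVectorized when the
// strengthened guard holds for the whole chunk, and fall back to a serial,
// per-lane-predicated loop otherwise. A scalar analogue of that control flow
// (SplitLoop and its parameters are hypothetical names):
#include <cstdint>
#include <functional>

static void SplitLoop(int64_t extent, int vs, int64_t N,
                      const std::function<void(int64_t)> &body) {
  for (int64_t outer = 0; outer < extent / vs; ++outer) {
    if (outer * vs + (vs - 1) < N) {
      // "Vectorized" chunk: every lane is proven in bounds, so the per-lane
      // if_then_else can be dropped (what VectorizedBodyMutator does).
      for (int v = 0; v < vs; ++v)
        body(outer * vs + v);
    } else {
      // Serial fallback keeps the per-lane bound check.
      for (int v = 0; v < vs; ++v)
        if (outer * vs + v < N)
          body(outer * vs + v);
    }
  }
}

int main() {
  int64_t sum = 0;
  SplitLoop(/*extent=*/8, /*vs=*/4, /*N=*/6, [&](int64_t i) { sum += i; });
  return sum == 15 ? 0 : 1; // only i = 0..5 execute: 0+1+2+3+4+5 = 15
}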
VectorizePlanResult
GetVectorizePlanResultDynamic(const For &loop, int dynamic_alignment,
bool disable_dynamic_tail_split) {
VectorizePlannerDynamic planner(dynamic_alignment,
disable_dynamic_tail_split);
int vector_size = planner.Plan(loop);
bool dynamic = planner.GetDynamic();
PrimExpr condition = planner.GetCondition();
return {vector_size, dynamic, condition};
}
class LoopVectorizerDynamic : public IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt, bool disable_dynamic_tail_split,
int dynamic_alignment) {
arith::Analyzer analyzer;
LoopVectorizerDynamic substituter(&analyzer, disable_dynamic_tail_split,
dynamic_alignment);
stmt = substituter.VisitStmt(stmt);
return stmt;
}
private:
LoopVectorizerDynamic(arith::Analyzer *analyzer,
bool disable_dynamic_tail_split, int dynamic_alignment)
: arith::IRMutatorWithAnalyzer(analyzer),
disable_dynamic_tail_split_(disable_dynamic_tail_split),
dynamic_alignment_(dynamic_alignment) {}
Stmt VisitStmt_(const ForNode *op) final {
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
VectorizePlanResult res{vector_load_bits_max_, false, 0};
res = GetVectorizePlanResultDynamic(for_node, dynamic_alignment_,
disable_dynamic_tail_split_);
NestedLoopChecker checker;
int nest_num = checker.GetNestLoopNum(for_node);
if (nest_num > 1 ||
for_node->kind == ForKind::kVectorized) { // only rewrite the innermost
// non-vectorized loop
return for_node;
}
auto rewriter = VectorizeRewriterDynamic(res, disable_dynamic_tail_split_);
return Downcast<For>(rewriter(for_node));
}
const int vector_load_bits_max_ = 128;
int dynamic_alignment_;
bool disable_dynamic_tail_split_;
};
class VectorizeSkipperDynamic : public StmtMutator {
public:
Stmt VisitStmt_(const ForNode *op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<ForNode>();
if (op->kind == ForKind::kVectorized) {
return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body);
} else {
return stmt;
}
}
};
tvm::transform::Pass LoopVectorizeDynamic() {
using namespace tir::transform;
auto pass_func = [=](PrimFunc f, const IRModule &m, PassContext ctx) {
bool disable_dynamic_tail_split =
ctx->GetConfig<Bool>(kDisableDynamicTailSplit, Bool(true)).value();
int dynamic_alignment =
(int)(ctx->GetConfig<Integer>(kDynamicAlignment, Integer(8))
.value_or(Integer(8))
->value);
// Ensure tl.dynamic_alignment is a power of 2
if (disable_dynamic_tail_split &&
((dynamic_alignment & (dynamic_alignment - 1)) != 0)) {
LOG(FATAL) << "tl.dynamic_alignment must be a power of 2, but got "
<< dynamic_alignment;
}
auto *n = f.CopyOnWrite();
n->body = LoopVectorizerDynamic::Substitute(
std::move(n->body), disable_dynamic_tail_split, dynamic_alignment);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tl.LoopVectorizeDynamic", {});
}
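// Illustrative standalone sketch (not part of this patch): the pass rejects a
// tl.dynamic_alignment that is not a power of two using the classic
// `x & (x - 1)` trick (valid for x > 0). The same check in isolation:
#include <iostream>

static bool IsPowerOfTwo(int x) { return x > 0 && (x & (x - 1)) == 0; }

int main() {
  std::cout << IsPowerOfTwo(8) << " "   // 1
            << IsPowerOfTwo(12) << " "  // 0
            << IsPowerOfTwo(1) << "\n"; // 1
  return 0;
}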
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LoopVectorizeDynamic",
LoopVectorizeDynamic);
});
} // namespace tl
} // namespace tvm
......@@ -36,7 +36,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
namespace {
struct KernelInfo {
// The device on which the PrimFunc runs
......@@ -372,8 +372,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.RewriteKernelLaunchSite(gvar, GetRef<PrimFunc>(ptr));
auto prim_func = mutator.RewriteKernelLaunchSite(
gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
......@@ -388,8 +388,8 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
IRModule updates;
for (const auto &[gvar, base_func] : mod->functions) {
if (auto *ptr = base_func.as<PrimFuncNode>()) {
auto prim_func =
mutator.UpdateKernelAttributes(gvar, GetRef<PrimFunc>(ptr));
auto prim_func = mutator.UpdateKernelAttributes(
gvar, tvm::ffi::GetRef<PrimFunc>(ptr));
if (!prim_func.same_as(base_func)) {
updates->Add(gvar, prim_func);
}
......@@ -407,11 +407,11 @@ tvm::transform::Pass LowerDeviceKernelLaunch() {
"tl.LowerDeviceKernelLaunch", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceKernelLaunch",
LowerDeviceKernelLaunch);
});
}
} // namespace transform
} // namespace tl
......
......@@ -45,7 +45,7 @@ public:
Stmt VisitStmt_(const AllocateNode *op) final {
auto scope = StorageScope::Create(GetPtrStorageScope(op->buffer_var));
if (!scope.tag.empty() && scope.tag != ".dyn" && scope.tag != ".var" &&
scope.tag != ".barrier" && scope.tag != ".descriptor") {
scope.tag != ".barrier" && scope.tag.find(".descriptor") != 0) {
auto info = GetMemoryInfo(GetPtrStorageScope(op->buffer_var));
ICHECK(info.defined())
<< "Cannot find memory info of " << scope.to_string();
......@@ -143,11 +143,11 @@ Pass LowerDeviceStorageAccessInfo() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerDeviceStorageAccessInfo",
LowerDeviceStorageAccessInfo);
});
}
} // namespace transform
} // namespace tl
......
......@@ -113,14 +113,14 @@ public:
if (call->op.same_as(create_tma_descriptor()) ||
call->op.same_as(create_tma_im2col_descriptor())) {
Var var;
auto iter = desc_map_.find(GetRef<Call>(call));
auto iter = desc_map_.find(tvm::ffi::GetRef<Call>(call));
if (iter != desc_map_.end()) {
var = iter->second;
} else {
String name = call->args[2].as<Var>().value()->name_hint;
var = Var(name + "_desc",
PointerType(PrimType(cuTensorMapType()), "grid_constant"));
desc_map_[GetRef<Call>(call)] = var;
desc_map_[tvm::ffi::GetRef<Call>(call)] = var;
prefetch_calls_.push_back(
Evaluate(Call(DataType::Handle(), builtin::call_extern(),
{StringImm("tl::prefetch_tma_descriptor"), var})));
......@@ -161,10 +161,10 @@ tvm::transform::Pass LowerHopperIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerHopperIntrin", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerHopperIntrin", LowerHopperIntrin);
});
}
#endif // (CUDA_MAJOR_VERSION >= 12)
} // namespace tl
......
......@@ -37,6 +37,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer {
public:
......@@ -70,9 +71,9 @@ public:
PrimExpr VisitExpr_(const CallNode *op) final {
if (auto *ptr_op = op->op.as<OpNode>()) {
for (const auto &f_attr_map : attr_maps_) {
FLowerGeneral f = f_attr_map.get(GetRef<Op>(ptr_op), nullptr);
FLowerGeneral f = f_attr_map.get(tvm::ffi::GetRef<Op>(ptr_op), nullptr);
if (f != nullptr) {
PrimExpr e = GetRef<PrimExpr>(op);
PrimExpr e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr r = f(e);
ICHECK(r.defined()) << "intrinsic rule must always return valid Expr";
if (!r.same_as(e)) {
......@@ -99,7 +100,7 @@ public:
// We use floordiv for integer analysis,
// but will need to lower them to native truncdiv instructions
PrimExpr VisitExpr_(const FloorDivNode *op) final {
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<FloorDivNode>();
if (op == nullptr)
......@@ -305,7 +306,7 @@ public:
using namespace arith;
PVar<PrimExpr> x, y;
PVar<IntImm> c;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if (max(floordiv(x, y), c).Match(e) && c.Eval()->value >= 0 &&
analyzer_->CanProveGreaterEqual(y.Eval(), 0)) {
return max(VisitExpr(truncdiv(x, y).Eval()), c.Eval());
......@@ -316,7 +317,7 @@ public:
PrimExpr VisitExpr_(const EQNode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) == 0).Match(e)) {
return VisitExpr((truncmod(x, y) == 0).Eval());
}
......@@ -326,7 +327,7 @@ public:
PrimExpr VisitExpr_(const NENode *op) final {
using namespace arith;
PVar<PrimExpr> x, y;
auto e = GetRef<PrimExpr>(op);
auto e = tvm::ffi::GetRef<PrimExpr>(op);
if ((floormod(x, y) != 0).Match(e)) {
return VisitExpr((truncmod(x, y) != 0).Eval());
}
......@@ -413,10 +414,10 @@ tir::transform::Pass LowerIntrin() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerIntrin", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerIntrin", LowerIntrin);
});
}
} // namespace transform
......
......@@ -98,10 +98,10 @@ tvm::transform::Pass LowerL2Persistent() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerL2Persistent", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerL2Persistent", LowerL2Persistent);
});
}
} // namespace tl
} // namespace tvm
......@@ -119,7 +119,7 @@ private:
// Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) {
if (is_one(extent) && IsEffectivelyEmptyAnnotation(op->annotations)) {
// handling unit loop
unit_loop_vars_[op->loop_var] = min;
}
......@@ -135,7 +135,8 @@ private:
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) {
} else if (is_one(extent) &&
IsEffectivelyEmptyAnnotation(op->annotations)) {
// Case 2. Unit loop
return body;
} else {
......@@ -150,8 +151,25 @@ private:
return body;
}
// Treat annotations as empty if they are truly empty or contain only
// the unroll hint `pragma_unroll_explicit`. This allows unit-length
// loops produced by unroll pragmas to be simplified away.
bool
IsEffectivelyEmptyAnnotation(const Map<String, ffi::Any> &annotations) const {
if (annotations.empty()) {
return true;
}
if (annotations.size() == 1) {
auto it = annotations.find(tir::attr::pragma_unroll_explicit);
if (it != annotations.end()) {
return true;
}
}
return false;
}
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
Var var = tvm::ffi::GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return var;
......@@ -286,10 +304,10 @@ tir::transform::Pass LowerOpaqueBlock() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerOpaqueBlock", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerOpaqueBlock", LowerOpaqueBlock);
});
}
} // namespace tl
} // namespace tvm
......@@ -32,7 +32,7 @@ private:
: disable_shuffle_elect_(disable_shuffle_elect) {}
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
// Record the mapping from buffer data var to buffer for later lookup
......@@ -204,10 +204,10 @@ tvm::transform::Pass LowerSharedBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedBarrier", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedBarrier", LowerSharedBarrier);
});
}
} // namespace transform
} // namespace tl
......
......@@ -30,7 +30,7 @@ public:
private:
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.Get(attr::kLayoutMap);
......@@ -88,6 +88,8 @@ private:
Array<Var> new_data_vars;
for (auto buffer : tmem_buffers) {
auto data = buffer->data;
if (var_remap_.count(data))
continue;
auto new_data =
Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared"));
var_remap_.Set(data, new_data);
......@@ -107,6 +109,7 @@ private:
buffer->buffer_type);
new_buffers.push_back(new_buffer);
buffer_remap_.Set(buffer, new_buffer);
buffer_data_to_buffer_.Set(new_data, new_buffer);
}
// remove the tmem buffers
......@@ -255,7 +258,15 @@ private:
op->dtype, op->op,
{op->args[0], new_data, op->args[2], op->args[3], op->args[4]});
}
return StmtExprMutator::VisitExpr_(op);
auto expr = StmtExprMutator::VisitExpr_(op);
return expr;
}
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = tvm::ffi::GetRef<Var>(op);
if (var_remap_.count(var)) {
return var_remap_[var];
}
return var;
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
......@@ -300,10 +311,10 @@ tvm::transform::Pass LowerSharedTmem() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem);
});
}
} // namespace transform
} // namespace tl
......
......@@ -39,6 +39,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
using runtime::StorageRank;
using runtime::StorageScope;
......@@ -944,11 +945,11 @@ tvm::transform::Pass LowerThreadAllreduce() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerThreadAllreduce", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerThreadAllreduce",
LowerThreadAllreduce);
});
}
} // namespace transform
} // namespace tl
......
......@@ -10,6 +10,7 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <unordered_map>
#include <vector>
#include "../layout/layout.h"
#include "../layout/utils.h"
......@@ -103,55 +104,6 @@ private:
Map<Buffer, Layout> layout_remap_;
};
class BufferGemmCollector : public StmtExprVisitor {
public:
BufferGemmCollector() { Clear(); }
void Clear() { buffer_var_gemm_.clear(); }
void Collect(const Stmt &stmt) { VisitStmt(stmt); }
Array<Var> GetBufferVarGemm() { return buffer_var_gemm_; }
private:
void VisitStmt_(const EvaluateNode *op) {
const CallNode *call_node = op->value.as<CallNode>();
// Value of EvaluateNode may not be a call
if (!call_node) {
return;
}
auto call = Downcast<Call>(call_node);
if (call->op.same_as(Gemm::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
buffer_var_gemm_.push_back(srcA_buffer_var);
buffer_var_gemm_.push_back(srcB_buffer_var);
buffer_var_gemm_.push_back(dst_buffer_var);
} else if (call->op.same_as(GemmSP::Get())) {
auto srcA_buffer_access_ptr = Downcast<Call>(call->args[0]);
ICHECK(srcA_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcA_buffer_var = Downcast<Var>(srcA_buffer_access_ptr->args[1]);
auto srcB_buffer_access_ptr = Downcast<Call>(call->args[1]);
ICHECK(srcB_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto srcB_buffer_var = Downcast<Var>(srcB_buffer_access_ptr->args[1]);
auto dst_buffer_access_ptr = Downcast<Call>(call->args[2]);
ICHECK(dst_buffer_access_ptr->op.same_as(builtin::tvm_access_ptr()));
auto dst_buffer_var = Downcast<Var>(dst_buffer_access_ptr->args[1]);
buffer_var_gemm_.push_back(srcA_buffer_var);
buffer_var_gemm_.push_back(srcB_buffer_var);
buffer_var_gemm_.push_back(dst_buffer_var);
}
}
Array<Var> buffer_var_gemm_;
};
/*!
* \brief A class that rewrites buffer references in a statement based on a
......@@ -253,11 +205,6 @@ public:
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerTileOpPass: Require the target attribute";
substituter.target_ = target.value();
// For TMA 1D, we should collect the buffers which are not used in GEMM and
// do not need swizzle
BufferGemmCollector collector;
collector.Collect(f->body);
substituter.buffer_var_gemm_ = collector.GetBufferVarGemm();
PrimFuncNode *fptr = f.CopyOnWrite();
fptr->body = substituter.VisitStmt(f->body);
fptr->body =
......@@ -301,6 +248,9 @@ private:
layout_map_.Set(buffer, layout);
}
}
// Begin a new workspace collection frame for this block scope
workspace_stack_.emplace_back();
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
auto block_ptr = block.CopyOnWrite();
for (size_t i = 0; i < block->alloc_buffers.size(); i++) {
......@@ -309,9 +259,13 @@ private:
block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]);
}
}
for (const auto &buffer : workspaces_)
block_ptr->alloc_buffers.push_back(buffer);
workspaces_.clear();
// Attach any workspaces requested within this block to its alloc_buffers
if (!workspace_stack_.empty()) {
for (const auto &buffer : workspace_stack_.back()) {
block_ptr->alloc_buffers.push_back(buffer);
}
workspace_stack_.pop_back();
}
return block;
}
......@@ -435,7 +389,7 @@ private:
return expr;
}
if (const auto *var_node = expr.as<VarNode>()) {
Var var = GetRef<Var>(var_node);
Var var = tvm::ffi::GetRef<Var>(var_node);
auto it = let_bindings_.find(var);
if (it != let_bindings_.end()) {
return it->second;
......@@ -611,7 +565,7 @@ private:
let_bindings_.erase(op->var);
}
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->value = value;
......@@ -652,13 +606,22 @@ private:
if (call && call->op.as<GlobalVarNode>())
return Downcast<Evaluate>(IRMutatorWithAnalyzer::VisitStmt_(op));
auto tile_op = ParseOperator(GetRef<Stmt>(op), buffer_data_to_buffer_);
auto tile_op =
ParseOperator(tvm::ffi::GetRef<Stmt>(op), buffer_data_to_buffer_);
if (!tile_op.defined())
return IRMutatorWithAnalyzer::VisitStmt_(op);
AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) {
auto workspace =
decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn");
workspaces_.push_back(workspace);
// Record workspace under the innermost block scope so its lifetime
// covers the statements that requested it and does not sink into
// subsequently created inner blocks (e.g., GEMM macro blocks).
if (!workspace_stack_.empty()) {
workspace_stack_.back().push_back(workspace);
} else {
// Fallback: create a temporary frame (should be rare)
workspace_stack_.emplace_back(Array<Buffer>{workspace});
}
return workspace.access_ptr(2); // write
};
......@@ -676,10 +639,10 @@ private:
thread_bounds = Range::FromMinExtent(0, 1);
}
auto lowered = tile_op->Lower(
LowerArgs{target_, thread_bounds, thread_var_->var, callback,
layout_map_, buffer_remap_, buffer_var_gemm_},
analyzer_);
auto lowered =
tile_op->Lower(LowerArgs{target_, thread_bounds, thread_var_->var,
callback, layout_map_, buffer_remap_},
analyzer_);
return IRMutatorWithAnalyzer::VisitStmt(lowered);
}
......@@ -706,7 +669,8 @@ private:
IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"),
IterVarType::kDataPar);
size_t thread_block_size_ = 0;
Array<Buffer> workspaces_;
// Stack of per-Block workspace buffers gathered while visiting children
std::vector<Array<Buffer>> workspace_stack_;
// For ptx Node, we need to remap the buffer and indices
// By access CallNode instead of BufferLoad Node.
bool is_ptx_{false};
......@@ -716,7 +680,6 @@ private:
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Var, Var> var_remap_;
bool has_tma_{false};
Array<Var> buffer_var_gemm_;
};
namespace transform {
......@@ -730,10 +693,10 @@ tvm::transform::Pass LowerTileOp() {
return CreatePrimFuncPass(pass_func, 0, "tl.LowerTileOp", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerTileOp", LowerTileOp);
});
}
} // namespace transform
} // namespace tl
......
......@@ -42,6 +42,7 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
static constexpr const char *kDeviceContextVar = "device_api_context";
namespace {
......@@ -168,7 +169,7 @@ private:
auto node = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
if (auto *gvar_ptr = node->op.as<GlobalVarNode>()) {
auto gvar = GetRef<GlobalVar>(gvar_ptr);
auto gvar = tvm::ffi::GetRef<GlobalVar>(gvar_ptr);
if (auto symbol = packed_func_methods.Get(gvar)) {
Array<PrimExpr> cpacked_args;
cpacked_args.push_back(tir::StringImm(symbol.value()));
......@@ -220,7 +221,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
// Internal function calls do not need the PackedFunc API
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (!global_symbol.defined()) {
if (!global_symbol) {
return std::nullopt;
}
......@@ -229,7 +230,7 @@ Optional<String> RequiresPackedAPI(const PrimFunc &func) {
PrimFunc MakePackedAPI(PrimFunc func) {
auto global_symbol = RequiresPackedAPI(func);
if (!global_symbol.defined()) {
if (!global_symbol) {
return func;
}
std::string name_hint = global_symbol.value();
......@@ -406,7 +407,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
StringImm(name_hint + "_compute_"), body);
// Set device context
if (vmap.count(device_id.get())) {
ObjectRef node = String("default");
auto node = String("default");
seq_check.push_back(AttrStmt(node, tir::attr::device_id, device_id, nop));
seq_check.push_back(
AttrStmt(node, tir::attr::device_type, device_type, nop));
......@@ -432,7 +433,7 @@ PrimFunc MakePackedAPI(PrimFunc func) {
auto shape_vectorize_expr = [&]() -> PrimExpr {
PrimExpr result = IntImm(kv.second->DefaultIndexType(), 1);
result = result * vectorize_dim;
result = FloorMod(result, dynamic_alignment);
result = FloorMod(result, IntImm(result->dtype, dynamic_alignment));
return result;
}();
shape_checks.emplace_back(AssertStmt(
......@@ -513,11 +514,11 @@ tvm::transform::Pass MakePackedAPI() {
return tvm::transform::CreateModulePass(pass_func, 0, "tl.MakePackedAPI", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MakePackedAPI",
[]() { return MakePackedAPI(); });
});
}
} // namespace tl
} // namespace tvm
......@@ -98,10 +98,10 @@ tvm::transform::Pass MergeIfStmt() {
return CreatePrimFuncPass(pass_func, 0, "tl.MergeIfStmt", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeIfStmt", MergeIfStmt);
});
}
} // namespace tl
} // namespace tvm
......@@ -31,6 +31,12 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <optional>
#include <queue>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
......@@ -38,7 +44,6 @@
#include "../op/builtin.h"
#include "../target/utils.h"
#include "runtime/thread_storage_scope.h"
#include "support/arena.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/function.h"
......@@ -141,6 +146,8 @@ public:
void VisitStmt_(const AllocateNode *op) final {
size_t level = scope_.size();
const VarNode *buf = op->buffer_var.get();
// Record the allocation site and depth so liveness can reason about the
// original scope.
alloc_info_[buf].alloc = op;
alloc_info_[buf].level = level;
StmtExprVisitor::VisitStmt_(op);
......@@ -155,7 +162,7 @@ public:
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
ICHECK_LT(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
// set into scope_.size() - 1 for aggressive memory reuse
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
......@@ -194,17 +201,23 @@ public:
const VarNode *buf = op->buffer->data.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
// Allow buffer access at the same level or deeper scope
// Changed from < to <= to handle cases where buffer is accessed
// in expressions at the same scope level where it's allocated
// Earlier we required `alloc_level < scope_.size()`, assuming every load
// would occur strictly inside a nested scope. In practice the lowering
// pipeline may materialise reads in the very same frame that owns the
// allocation (e.g. when the buffer value is passed directly to a call),
// which used to trigger the CHECK. Treat same-level accesses as valid so
// the merged allocator can reason about their lifetime correctly.
ICHECK_LE(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
// When accessing at the same level, use that level
// When the access happens in the same scope frame as the allocation
// we attribute it to that frame instead of the outer parent. This
// keeps the liveness window tight while still accounting for nested
// scopes that legitimately touch the buffer deeper in the tree.
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
......@@ -216,14 +229,17 @@ public:
// Directly reference to the variable count as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
// Allow buffer access at the same level or deeper scope
// Same rationale as the BufferLoad path above: direct references can be
// emitted at the allocation level after flattening, so accept them and
// record the touch for liveness planning.
ICHECK_LE(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
if (IsAppropriateSharedMemory(tvm::ffi::GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
// When accessing at the same level, use that level
// Attribute same-level uses to the allocation frame, mirroring the
// BufferLoad handling to keep reuse decisions consistent.
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
......@@ -245,6 +261,8 @@ public:
scope_.pop_back();
int64_t end_index = static_cast<int64_t>(linear_seq_.size());
ICHECK_GT(end_index, begin_index);
// The paired entries serve as scope sentinels once we flatten the
// control-flow tree.
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e);
// record the pointer to end index.
......@@ -336,9 +354,30 @@ public:
}
private:
// Helper to record alignment for a shared/shared.dyn Var under alignment
// scope
void MarkSharedVarIfNeeded(const VarNode *op) {
if (!op || !under_alignment_scope_)
return;
auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (!ptr_type)
return;
auto scope = GetPtrStorageScope(tvm::ffi::GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
const int alignment = TargetIsHopper(target) ? 1024 : 16;
shmem_alignment_map_[op] = alignment;
}
}
void VisitExpr_(const CallNode *op) {
if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) ||
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store())) {
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) ||
op->op.same_as(tl::initialize_wgmma_descriptor()) ||
op->op.same_as(tl::initialize_tcgen05_descriptor())) {
// These intrinsics introduce stricter SMEM alignment requirements; mark
// the subtree.
under_alignment_scope_ = true;
StmtExprVisitor::VisitExpr_(op);
under_alignment_scope_ = false;
......@@ -348,15 +387,16 @@ private:
}
void VisitExpr_(const VarNode *op) {
auto ptr_type = op->type_annotation.as<PointerTypeNode>();
if (ptr_type && under_alignment_scope_) {
auto scope = GetPtrStorageScope(GetRef<Var>(op));
if (scope == "shared" || scope == "shared.dyn") {
auto target = Target::Current();
ICHECK(target.defined()) << "Target is not defined";
const int alignment = TargetIsHopper(target) ? 1024 : 16;
shmem_alignment_map_[op] = alignment;
}
MarkSharedVarIfNeeded(op);
StmtExprVisitor::VisitExpr_(op);
}
void VisitExpr_(const BufferLoadNode *op) {
// If we encounter address_of(BufferLoad(...)) or any direct BufferLoad
// within an alignment scope, make sure we mark the underlying shared var.
if (op && under_alignment_scope_) {
const VarNode *data_var = op->buffer->data.get();
MarkSharedVarIfNeeded(data_var);
}
StmtExprVisitor::VisitExpr_(op);
}
......@@ -394,6 +434,8 @@ public:
enable_aggressive_merge, verbose);
finder(stmt);
shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt);
// First compute liveness over the flattened schedule, then feed it into the
// arena packer.
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
}
......@@ -403,65 +445,6 @@ private:
if (op->attr_key == tir::attr::thread_extent && !allocated_) {
// Allocate one dynamic shared memory allocation at the beginning of
// thread scope
int max_layer_num = 0;
std::vector<const StorageEntry *> all_entry;
for (const auto &e : const_free_map_) {
all_entry.push_back(e.second);
}
for (const StorageEntry *e : sym_free_list_) {
all_entry.push_back(e);
}
// Sort the storage entries in descending order of their total allocation
// size (in bits). This ensures that larger allocations are placed first,
// which can help minimize fragmentation and improve memory packing
// efficiency when merging shared memory buffers.
std::sort(all_entry.begin(), all_entry.end(),
[](const StorageEntry *a, const StorageEntry *b) {
return a->const_nbits > b->const_nbits;
});
for (const StorageEntry *e : all_entry) {
max_layer_num =
std::max(max_layer_num, static_cast<int>(e->allocs.size()));
}
// calculate align for each layer of each storage entry.
std::vector<int> align(max_layer_num, 0);
for (const StorageEntry *e : all_entry) {
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
for (const VarNode *buffer : e->allocs[i]) {
const AllocateNode *alloc = shmem_allocs_[buffer];
align[i] =
std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes());
align[i] = std::max(align[i], align_bytes_);
}
}
}
for (const StorageEntry *e : all_entry) {
PrimExpr max_inner_offset = 0;
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
PrimExpr inner_offset = 0;
for (const VarNode *buffer : e->allocs[i]) {
const AllocateNode *alloc = shmem_allocs_[buffer];
auto alignment = align[i];
// Modern nvidia architecture performs hardware swizzling (hopper
// wgmma/tma for example) requires dynamic shared memory address to
// be aligned to 1024 bytes For other devices, we align to 16 bytes
if (shmem_alignment_map_.find(buffer) !=
shmem_alignment_map_.end()) {
alignment = std::max(align[i], shmem_alignment_map_[buffer]);
}
PrimExpr start_offset = merged_alloc_size_ + inner_offset;
PrimExpr aligned_offset =
indexdiv(start_offset + alignment - 1, alignment) * alignment;
buffer_byte_offsets_[buffer] = aligned_offset;
inner_offset =
aligned_offset - merged_alloc_size_ +
alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes();
}
max_inner_offset = max(max_inner_offset, inner_offset);
}
merged_alloc_size_ += max_inner_offset;
}
if (verbose_) {
......@@ -626,18 +609,199 @@ private:
using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;
struct StorageEntry {
// The constant size of the buffer in bits, only used if it is constant
uint64_t const_nbits{0};
// Allocs that shares this entry.
// The inner vector means a "layer"
  // For example, if we need to allocate C in the memory of A and B:
// | A: 4096 bytes | B: 4096 bytes |
// | C: 8192 bytes |
// Then the allocs = {{A, B}, {C}}
std::vector<std::vector<const VarNode *>> allocs;
// Metadata about a single shared-memory allocation prior to merging. This
// is used to build lifetimes, alignment requirements, and final offsets.
struct BufInfo {
const VarNode *var{nullptr};
std::string name;
PrimExpr size_expr;
std::optional<int64_t> const_size_bytes; // in bytes if compile-time known.
int alignment{0}; // required byte alignment.
int start{0}; // first statement index touching the buf.
int end{0}; // one-past-last statement index.
DataType size_dtype{DataType::Int(32)};
};
// Interval describing the liveness window of a (constant-sized) allocation.
struct Interval {
int start{0};
int end{0};
size_t size_bytes{0};
int alignment{0};
const VarNode *var{nullptr};
};
// Result of a linear-scan arena packing. Offsets contain the byte offset for
// each constant-sized buffer, arena_size is the total constant footprint.
struct ArenaPlan {
size_t arena_size{0};
std::unordered_map<const VarNode *, size_t> offsets;
};
static size_t AlignUpSize(size_t value, size_t alignment) {
if (alignment == 0) {
return value;
}
size_t remainder = value % alignment;
if (remainder == 0) {
return value;
}
return value + (alignment - remainder);
}
struct FreeBlock {
size_t offset{0};
size_t size{0};
};
class FreeList {
public:
std::optional<size_t> Allocate(size_t need, size_t alignment) {
// Best-fit search: pick the slot that wastes the least space after
// alignment.
int best = -1;
size_t best_waste = std::numeric_limits<size_t>::max();
for (int i = 0, n = static_cast<int>(blocks_.size()); i < n; ++i) {
size_t aligned = AlignUpSize(blocks_[i].offset, alignment);
size_t head = aligned - blocks_[i].offset;
if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) {
size_t waste = blocks_[i].size - head - need;
if (waste < best_waste) {
best_waste = waste;
best = i;
}
}
}
if (best < 0) {
return std::nullopt;
}
FreeBlock blk = blocks_[best];
size_t aligned = AlignUpSize(blk.offset, alignment);
size_t head = aligned - blk.offset;
size_t tail = blk.size - head - need;
blocks_.erase(blocks_.begin() + best);
if (head) {
blocks_.push_back({blk.offset, head});
}
if (tail) {
blocks_.push_back({aligned + need, tail});
}
Normalize();
return aligned;
}
void Free(size_t offset, size_t size) {
if (size == 0)
return;
blocks_.push_back({offset, size});
Normalize();
}
private:
void Normalize() {
if (blocks_.empty())
return;
std::sort(blocks_.begin(), blocks_.end(),
[](const FreeBlock &a, const FreeBlock &b) {
return a.offset < b.offset;
});
std::vector<FreeBlock> merged;
merged.reserve(blocks_.size());
for (const FreeBlock &blk : blocks_) {
if (merged.empty()) {
merged.push_back(blk);
continue;
}
FreeBlock &last = merged.back();
size_t last_end = last.offset + last.size;
if (blk.offset <= last_end) {
size_t blk_end = blk.offset + blk.size;
if (blk_end > last_end) {
last.size = blk_end - last.offset;
}
} else {
merged.push_back(blk);
}
}
blocks_ = std::move(merged);
}
std::vector<FreeBlock> blocks_;
};
struct ActiveInterval {
int end{0};
size_t offset{0};
size_t size{0};
const VarNode *var{nullptr};
bool operator>(const ActiveInterval &other) const {
return end > other.end;
}
};
static ArenaPlan LinearScanPack(std::vector<Interval> intervals) {
// Process intervals in program order so lifetimes correspond to the
// linearised CFG.
std::sort(intervals.begin(), intervals.end(),
[](const Interval &lhs, const Interval &rhs) {
if (lhs.start != rhs.start) {
return lhs.start < rhs.start;
}
if (lhs.size_bytes != rhs.size_bytes) {
return lhs.size_bytes > rhs.size_bytes;
}
return lhs.var < rhs.var;
});
std::priority_queue<ActiveInterval, std::vector<ActiveInterval>,
std::greater<ActiveInterval>>
active;
FreeList freelist;
size_t arena_top = 0;
std::unordered_map<const VarNode *, size_t> offsets;
// Expire intervals that end before or at program counter `pc`.
auto retire = [&](int pc) {
while (!active.empty() && active.top().end <= pc) {
const ActiveInterval top = active.top();
active.pop();
freelist.Free(top.offset, top.size);
}
};
for (const Interval &interval : intervals) {
retire(interval.start);
size_t offset = 0;
// Try to recycle previously freed memory first; fall back to bumping the
// arena.
if (auto slot =
freelist.Allocate(interval.size_bytes, interval.alignment)) {
offset = slot.value();
} else {
offset = AlignUpSize(arena_top, interval.alignment);
arena_top = offset + interval.size_bytes;
}
active.push(ActiveInterval{interval.end, offset, interval.size_bytes,
interval.var});
offsets[interval.var] = offset;
}
return ArenaPlan{arena_top, std::move(offsets)};
}
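  // Hand-traced example of LinearScanPack (illustrative values, not from the
  // patch): three 4 KiB buffers with 16-byte alignment and lifetimes A:[0,2),
  // B:[1,3), C:[2,4). A is placed at offset 0 and B at 4096 via the arena
  // bump. When C starts at statement 2, A has expired, its block [0, 4096)
  // returns to the free list, and the best-fit search re-issues offset 0 to C.
  // The constant arena footprint stays at 8192 bytes instead of the 12288
  // that sequential placement without reuse would need.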
PrimExpr AlignPrimExpr(const PrimExpr &value, int alignment) const {
if (alignment <= 1) {
return value;
}
DataType dtype = value.dtype();
ICHECK(dtype.is_int() || dtype.is_uint())
<< "Expected integer dtype for alignment, but got " << dtype;
PrimExpr align_expr = make_const(dtype, alignment);
PrimExpr adjust = make_const(dtype, alignment - 1);
return indexdiv(value + adjust, align_expr) * align_expr;
}
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
......@@ -905,173 +1069,228 @@ private:
void
PlanMemory(const std::vector<StmtEntry> &seq,
const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
std::unordered_set<const VarNode *> inplace_flag;
buffer_byte_offsets_.clear();
(void)stmt_attrs;
if (shmem_allocs_.empty()) {
merged_alloc_size_ = make_const(DataType::Int(64), 0);
return;
}
// Discover the first and last touch for every allocation.
std::unordered_map<const VarNode *, int> start_index;
std::unordered_map<const VarNode *, int> end_index;
for (size_t i = 0; i < seq.size(); ++i) {
auto it = event_map_.find(seq[i].stmt);
// scope_pair_offset <= 0 means it is either
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
auto is_leaf_alloc = [&](const VarNode *var) {
return seq[i].scope_pair_offset == 0 &&
std::find(it->second.gen.begin(), it->second.gen.end(), var) !=
it->second.gen.end();
};
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode *var : it->second.kill) {
if (!is_leaf_alloc(var))
this->Free(var);
}
if (it == event_map_.end())
continue;
for (const VarNode *var : it->second.gen) {
start_index.emplace(var, static_cast<int>(i));
}
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset < 0)
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
for (const VarNode *var : it->second.gen) {
ICHECK(shmem_allocs_.count(var));
const AllocateNode *alloc = shmem_allocs_[var];
StorageEntry *dst_entry = FindAlloc(alloc);
alloc_map_[var] = dst_entry;
}
for (const VarNode *var : it->second.kill) {
end_index[var] = std::max(end_index[var], static_cast<int>(i) + 1);
}
}
const int seq_len = static_cast<int>(seq.size());
for (const auto &kv : start_index) {
if (!end_index.count(kv.first)) {
end_index[kv.first] = seq_len;
}
}
std::vector<BufInfo> buf_infos;
buf_infos.reserve(shmem_allocs_.size());
// Build a BufInfo for all allocations that participate in liveness.
for (const auto &kv : shmem_allocs_) {
const VarNode *var = kv.first;
auto start_it = start_index.find(var);
if (start_it == start_index.end()) {
continue;
}
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode *var : it->second.kill) {
if (is_leaf_alloc(var))
this->Free(var);
BufInfo info;
info.var = var;
info.name = var->name_hint;
info.start = start_it->second;
info.end = std::max(end_index[var], info.start + 1);
info.alignment = align_bytes_;
auto align_it = shmem_alignment_map_.find(var);
if (align_it != shmem_alignment_map_.end()) {
info.alignment = std::max(info.alignment, align_it->second);
}
const AllocateNode *alloc = kv.second;
int64_t bytes_per_elem =
static_cast<int64_t>(alloc->dtype.bytes() * alloc->dtype.lanes());
DataType size_dtype = DataType::Int(32);
if (!alloc->extents.empty()) {
size_dtype = alloc->extents[0].dtype();
}
if (!size_dtype.is_int() && !size_dtype.is_uint()) {
size_dtype = DataType::Int(32);
}
PrimExpr size_expr = make_const(size_dtype, bytes_per_elem);
for (const PrimExpr &extent : alloc->extents) {
PrimExpr e = extent;
if (e.dtype() != size_dtype) {
e = cast(size_dtype, e);
}
size_expr = size_expr * e;
}
info.size_dtype = size_dtype;
info.size_expr = size_expr;
int64_t const_extent = alloc->ConstantAllocationSize();
if (const_extent >= 0) {
info.const_size_bytes = const_extent * bytes_per_elem;
}
buf_infos.push_back(std::move(info));
}
}
/*!
* \brief Allocate new storage entry.
* \param op the allocate node
 * \param const_nbits The size of the allocation in bits.
* \return the new storage entry
*/
StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
ICHECK(op != nullptr);
// Reuse not successful, allocate a new buffer.
StorageEntry *entry = arena_.make<StorageEntry>();
entry->allocs.push_back({op->buffer_var.get()});
entry->const_nbits = const_nbits;
return entry;
}
/*!
* @brief Locate or create a storage entry from free lists to satisfy an
* AllocateNode.
*
* Finds a reusable StorageEntry for the given AllocateNode (constant or
* symbolic size) using two-tiered strategies:
* - For constant-size allocations (>0): prefer a free entry that is >=
* required size; if none, coalesce smaller free constant-size entries until
* the sum meets the request and return a new StorageEntry representing the
* merged space. Very small constant allocations (<= 32 bits) are not reused
* and will allocate a fresh entry.
* - For symbolic-size (unknown at compile time): pick and remove an arbitrary
* entry from the symbolic free list.
*
* If no suitable free entry is found, a fresh StorageEntry is created via
* NewAlloc.
*
* @param op Pointer to the AllocateNode to satisfy. Must be non-null.
* @return StorageEntry* A storage entry that will hold the allocation (may be
* newly created).
*/
StorageEntry *FindAlloc(const AllocateNode *op) {
ICHECK(op != nullptr);
// skip plan for local variable,
// compiler can do a better job with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits =
static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
    // Disable reuse of small arrays; they will be lowered to registers in LLVM.
    // This rule only applies if we are using non-special memory.
if (const_nbits > 0 && const_nbits <= 32) {
return NewAlloc(op, const_nbits);
// Stable order so the later passes have deterministic behaviour.
std::sort(buf_infos.begin(), buf_infos.end(),
[](const BufInfo &a, const BufInfo &b) {
if (a.start != b.start)
return a.start < b.start;
if (a.end != b.end)
return a.end < b.end;
return a.name < b.name;
});
std::vector<Interval> intervals;
intervals.reserve(buf_infos.size());
for (const BufInfo &info : buf_infos) {
if (!info.const_size_bytes.has_value())
continue;
// Only constant-sized buffers participate in the arena packing because
// dynamic sizes must be placed sequentially later.
Interval interval;
interval.start = info.start;
interval.end = info.end;
interval.size_bytes = static_cast<size_t>(
std::max<int64_t>(0, info.const_size_bytes.value()));
interval.alignment = info.alignment;
interval.var = info.var;
intervals.push_back(interval);
}
if (const_nbits != 0) {
// constant allocation.
auto begin = const_free_map_.lower_bound(0);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// Start looking at the buffer that is bigger than the required size
// first. If we find one, directly allocate the buffer in its location and
// remove its entry in the free list
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
it->second->allocs.push_back({op->buffer_var.get()});
return e;
ArenaPlan plan = LinearScanPack(std::move(intervals));
size_t arena_size_const = plan.arena_size;
if (verbose_) {
LOG(DEBUG) << "ArenaPlan (constant buffers): arena_size="
<< arena_size_const;
for (const auto &kv : plan.offsets) {
const VarNode *var = kv.first;
LOG(DEBUG) << " " << var->name_hint << " -> offset=" << kv.second;
}
// Then start looking at smaller buffers.
// Keep collecting the buffer until the sum of their size exceeds the
// buffer to allocate and finally free all these entry in the free list
std::vector<std::multimap<uint64_t, StorageEntry *>::iterator> delete_it;
// the alloc list for the new entry
std::vector<std::vector<const VarNode *>> reuse_allocs;
uint64_t mem_ct = 0;
for (auto it = mid; it != begin;) {
--it;
delete_it.push_back(it);
mem_ct += it->second->const_nbits;
int n = it->second->allocs.size();
if (n > static_cast<int>(reuse_allocs.size())) {
reuse_allocs.resize(n, {});
}
for (int i = 0; i < n; i++) {
for (const VarNode *alloc : it->second->allocs[i]) {
reuse_allocs[i].push_back(alloc);
}
}
if (mem_ct >= const_nbits) {
break;
}
}
// Cursor tracks the running byte offset within the merged arena.
DataType offset_dtype =
buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype;
PrimExpr total_size = make_const(offset_dtype, 0);
PrimExpr cursor = AlignPrimExpr(
make_const(offset_dtype, static_cast<int64_t>(arena_size_const)),
align_bytes_);
auto CastToOffset = [&](PrimExpr expr) -> PrimExpr {
if (expr.dtype() == offset_dtype) {
return expr;
}
reuse_allocs.push_back({op->buffer_var.get()});
if (mem_ct != 0) {
StorageEntry *e = arena_.make<StorageEntry>();
e->const_nbits = std::max(const_nbits, mem_ct);
e->allocs = reuse_allocs;
for (auto it : delete_it) {
const_free_map_.erase(it);
}
return e;
return cast(offset_dtype, expr);
};
for (const BufInfo &info : buf_infos) {
PrimExpr offset_expr;
auto it = plan.offsets.find(info.var);
if (it != plan.offsets.end()) {
offset_expr =
make_const(offset_dtype, static_cast<int64_t>(it->second));
} else {
// Dynamic-sized buffers are appended after the constant arena.
cursor = AlignPrimExpr(cursor, info.alignment);
PrimExpr size_expr = CastToOffset(info.size_expr);
offset_expr = cursor;
cursor = offset_expr + size_expr;
}
} else {
// if its symbolic allocation, just arbitrarily choose one entry to fit in
// because we don't know its actual size
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
StorageEntry *e = *it;
sym_free_list_.erase(it);
return e;
buffer_byte_offsets_[info.var] = offset_expr;
PrimExpr buf_end = offset_expr + CastToOffset(info.size_expr);
total_size = max(total_size, buf_end);
}
merged_alloc_size_ = buf_infos.empty()
? make_const(offset_dtype, 0)
: AlignPrimExpr(total_size, align_bytes_);
bool overlap_detected = false;
if (verbose_) {
LOG(DEBUG) << "Memory Allocation Plan for "
<< (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
LOG(DEBUG) << " Total Merged Size (aligned): " << merged_alloc_size_;
for (const BufInfo &info : buf_infos) {
const PrimExpr &offset = buffer_byte_offsets_.at(info.var);
LOG(DEBUG) << " Buffer: " << info.name << " start=" << info.start
<< " end=" << info.end << " alignment=" << info.alignment
<< " offset=" << offset << " size=" << info.size_expr;
}
// Sanity check for overlapping constant buffers.
for (size_t i = 0; i < buf_infos.size(); ++i) {
const BufInfo &a = buf_infos[i];
auto a_off_imm = buffer_byte_offsets_.at(a.var).as<IntImmNode>();
if (!a.const_size_bytes.has_value() || a_off_imm == nullptr)
continue;
int64_t a_off = a_off_imm->value;
int64_t a_end = a_off + a.const_size_bytes.value();
for (size_t j = i + 1; j < buf_infos.size(); ++j) {
const BufInfo &b = buf_infos[j];
auto b_off_imm = buffer_byte_offsets_.at(b.var).as<IntImmNode>();
if (!b.const_size_bytes.has_value() || b_off_imm == nullptr)
continue;
bool live_overlap = !(a.end <= b.start || b.end <= a.start);
if (!live_overlap)
continue;
int64_t b_off = b_off_imm->value;
int64_t b_end = b_off + b.const_size_bytes.value();
bool mem_overlap = !(a_end <= b_off || b_end <= a_off);
if (mem_overlap) {
overlap_detected = true;
LOG(WARNING) << "Buffer overlap detected between " << a.name
<< " and " << b.name << " (lifetime overlap with "
<< "offset ranges [" << a_off << ", " << a_end
<< ") and [" << b_off << ", " << b_end << ")).";
}
}
}
}
return NewAlloc(op, const_nbits);
}
/*!
* \brief add the storage entry to the buffer var into the free list.
* \param var the buffer var
*/
void Free(const VarNode *var) {
auto it = alloc_map_.find(var);
ICHECK(it != alloc_map_.end());
StorageEntry *e = it->second;
ICHECK_NE(e->allocs.size(), 0U);
// normal free.
if (e->const_nbits != 0) {
const_free_map_.insert({e->const_nbits, e});
} else {
sym_free_list_.push_back(e);
if (overlap_detected) {
LOG(WARNING) << "Detected overlapping constant buffers; falling back to "
<< "sequential allocation without reuse.";
buffer_byte_offsets_.clear();
// In the fallback path we simply lay buffers out sequentially.
PrimExpr new_cursor = make_const(offset_dtype, 0);
PrimExpr new_total = make_const(offset_dtype, 0);
for (const BufInfo &info : buf_infos) {
new_cursor = AlignPrimExpr(new_cursor, info.alignment);
PrimExpr size_expr = CastToOffset(info.size_expr);
buffer_byte_offsets_[info.var] = new_cursor;
PrimExpr buf_end = new_cursor + size_expr;
new_total = max(new_total, buf_end);
new_cursor = buf_end;
}
merged_alloc_size_ = buf_infos.empty()
? make_const(offset_dtype, 0)
: AlignPrimExpr(new_total, align_bytes_);
}
}
  // Whether to enable dynamic analysis.
bool is_dynamic_{true};
......@@ -1095,14 +1314,6 @@ private:
bool allocated_{false};
// Locations of free ops.
std::unordered_map<const Object *, EventEntry> event_map_;
  // constant-size free map.
  std::multimap<uint64_t, StorageEntry *> const_free_map_;
  // symbolic free list, for non-constant items.
  std::list<StorageEntry *> sym_free_list_;
  // The allocation assignment map.
  std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
  /*! \brief Allocator of all the StorageEntry objects. */
  support::Arena arena_;
  // The mapping of buffer byte alignments.
  std::unordered_map<const VarNode *, int> shmem_alignment_map_;
};
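The constant-size packing above delegates the actual offset assignment to a LinearScanPack helper that this hunk does not show. For orientation only, a self-contained greedy first-fit packer over liveness intervals could look like the sketch below; every name in it (DemoInterval, DemoPlan, PackIntervals, AlignUp) is illustrative and not part of the pass, and the real helper may use a different policy.

#include <algorithm>
#include <cstddef>
#include <unordered_map>
#include <vector>

// One constant-sized buffer with its liveness range [start, end).
struct DemoInterval {
  int start;
  int end;
  std::size_t size_bytes;
  std::size_t alignment;  // assumed > 0
  int id;                 // stand-in for the buffer variable
};

struct DemoPlan {
  std::size_t arena_size = 0;
  std::unordered_map<int, std::size_t> offsets;  // id -> byte offset
};

static std::size_t AlignUp(std::size_t x, std::size_t a) {
  return (x + a - 1) / a * a;
}

// Greedy first-fit: process intervals in start order and bump each candidate
// offset past every already-placed buffer whose lifetime and bytes overlap.
DemoPlan PackIntervals(std::vector<DemoInterval> intervals) {
  std::sort(intervals.begin(), intervals.end(),
            [](const DemoInterval &a, const DemoInterval &b) {
              return a.start < b.start;
            });
  struct Placed {
    int end;
    std::size_t offset;
    std::size_t size;
  };
  std::vector<Placed> placed;
  DemoPlan plan;
  for (const DemoInterval &iv : intervals) {
    std::size_t offset = 0;
    bool moved = true;
    while (moved) {
      moved = false;
      offset = AlignUp(offset, iv.alignment);
      for (const Placed &p : placed) {
        bool live_overlap = iv.start < p.end;  // p.start <= iv.start by sort
        bool mem_overlap =
            offset < p.offset + p.size && p.offset < offset + iv.size_bytes;
        if (live_overlap && mem_overlap) {
          offset = AlignUp(p.offset + p.size, iv.alignment);
          moved = true;
        }
      }
    }
    plan.offsets[iv.id] = offset;
    placed.push_back({iv.end, offset, iv.size_bytes});
    plan.arena_size = std::max(plan.arena_size, offset + iv.size_bytes);
  }
  return plan;
}

Buffers whose lifetimes never overlap can land on the same offset, which is exactly the reuse a merged shared-memory arena is after; dynamically sized buffers are then appended after plan.arena_size, mirroring how the pass appends them after the constant arena.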
......@@ -1150,11 +1361,11 @@ Pass MergeSharedMemoryAllocations(bool enable_aggressive_merge = false,
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MergeSharedMemoryAllocations",
MergeSharedMemoryAllocations);
});
}
} // namespace transform
} // namespace tl
......
......@@ -57,7 +57,7 @@ public:
// Check reads from global
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ GetRef<Stmt>(op));
/*body*/ tvm::ffi::GetRef<Stmt>(op));
auto access = GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto reads = access[0];
Role role = Role::kProducer;
......@@ -253,7 +253,8 @@ private:
}
static Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
......@@ -493,10 +494,10 @@ tvm::transform::Pass MultiVersionBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.MultiVersionBuffer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.MultiVersionBuffer", MultiVersionBuffer);
});
}
} // namespace tl
} // namespace tvm
......@@ -59,10 +59,10 @@ tvm::transform::Pass PersistThreadblock() {
return CreatePrimFuncPass(pass_func, 0, "tl.PersistThreadblock", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PersistThreadblock", PersistThreadblock);
});
}
} // namespace tl
} // namespace tvm
......@@ -103,7 +103,7 @@ private:
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(var));
ICHECK(it != buffer_data_to_buffer_.end());
return (*it).second;
};
......@@ -210,7 +210,7 @@ private:
if (const auto *load = op->args[0].as<BufferLoadNode>()) {
buffer_region = BufferRegion::FullRegion(load->buffer);
} else if (const auto *var_node = op->args[0].as<VarNode>()) {
Var data_var = GetRef<Var>(var_node);
Var data_var = tvm::ffi::GetRef<Var>(var_node);
auto it = buffer_data_to_buffer_.find(data_var);
if (it != buffer_data_to_buffer_.end()) {
buffer_region = BufferRegion::FullRegion((*it).second);
......@@ -223,7 +223,7 @@ private:
} else if (op->op.same_as(builtin::tvm_access_ptr())) {
const VarNode *buffer_var = op->args[1].as<VarNode>();
ICHECK(buffer_var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(buffer_var));
auto it = buffer_data_to_buffer_.find(tvm::ffi::GetRef<Var>(buffer_var));
if (it != buffer_data_to_buffer_.end()) {
const Buffer &buffer = (*it).second;
const BufferRegion buffer_region = BufferRegion::FullRegion(buffer);
......@@ -402,7 +402,7 @@ private:
if (TargetHasAsyncCopy(target_) && use_async_copy_)
annotations.Set(tir::attr::software_pipeline_async_stages,
Array<Integer>{0});
auto for_node = GetRef<For>(loop);
auto for_node = tvm::ffi::GetRef<For>(loop);
for_node.CopyOnWrite()->annotations = annotations;
return for_node;
}
......@@ -728,10 +728,10 @@ tvm::transform::Pass PipelinePlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.PipelinePlanning", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.PipelinePlanning", PipelinePlanning);
});
}
} // namespace tl
} // namespace tvm
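MultiVersionBuffer() and PipelinePlanning() are both exposed as ordinary PrimFunc passes, so they compose through TVM's standard pass infrastructure. The sketch below is a usage illustration only: the helper name and the pass order are assumptions, and it presumes the TileLang headers declaring these passes are included.

#include <tvm/ir/module.h>
#include <tvm/ir/transform.h>

// Illustrative only: run two of the tl passes registered above on a module.
// TileLang's real lowering pipeline may order them differently and interleave
// other passes in between.
tvm::IRModule RunSomeTlPasses(tvm::IRModule mod) {
  tvm::transform::Sequential seq(
      {tvm::tl::PipelinePlanning(), tvm::tl::MultiVersionBuffer()});
  return seq(mod);
}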
......@@ -23,6 +23,7 @@ namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
using namespace arith;
struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
......@@ -62,8 +63,8 @@ struct SimplifyConfigNode : public AttrsNodeReflAdapter<SimplifyConfigNode> {
"branch",
refl::DefaultValue(false));
}
static constexpr const char *_type_key = "tl.transform.SimplifyConfig";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(SimplifyConfigNode, BaseAttrsNode);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.transform.SimplifyConfig",
SimplifyConfigNode, BaseAttrsNode);
RewriteSimplifier::Extension GetEnabledExtensions() const {
RewriteSimplifier::Extension flags = RewriteSimplifier::kNone;
......@@ -209,12 +210,11 @@ CollectVarsUsedInBufferDefinition(const Stmt &stmt) {
class SimplifyConfig : public Attrs {
public:
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SimplifyConfig, Attrs,
SimplifyConfigNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(SimplifyConfig, Attrs,
SimplifyConfigNode);
};
TVM_FFI_STATIC_INIT_BLOCK({ SimplifyConfigNode::RegisterReflection(); });
TVM_FFI_STATIC_INIT_BLOCK() { SimplifyConfigNode::RegisterReflection(); }
TVM_REGISTER_NODE_TYPE(SimplifyConfigNode);
TVM_REGISTER_PASS_CONFIG_OPTION("tl.Simplify", SimplifyConfig);
class StmtSimplifier : public IRMutatorWithAnalyzer {
......@@ -391,7 +391,7 @@ private:
if (can_inline && !used_in_buffer_def) {
return body;
} else if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
auto n = this->CopyOnWrite(op);
n->value = std::move(value);
......@@ -522,10 +522,10 @@ tvm::transform::Pass Simplify(bool simplify_arguments = true) {
return CreatePrimFuncPass(pass_func, 0, "tl.Simplify", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.Simplify", Simplify);
});
}
} // namespace tl
} // namespace tvm
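Because the attrs object is registered under the key "tl.Simplify", the pass body can read it back from the active PassContext. The sketch below shows the conventional lookup pattern (as it would appear inside namespace tvm::tl, where SimplifyConfig is visible); whether this pass reads its config in exactly this way is not visible in this hunk.

// Fetch the optional "tl.Simplify" config from the current PassContext and
// fall back to default-constructed attrs when the user did not provide one.
SimplifyConfig GetSimplifyConfigOrDefault() {
  auto ctx = tvm::transform::PassContext::Current();
  return ctx->GetConfig<SimplifyConfig>("tl.Simplify")
      .value_or(AttrsWithDefaultValues<SimplifyConfig>());
}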