Unverified Commit bbbf4207 authored by guchaoyang, committed by GitHub

Merge branch 'main' into dcu

parents 8f4628e0 5eb30a4f
......@@ -249,7 +249,6 @@ private:
new_args.push_back(dst_node);
new_args.push_back(value_node);
}
new_args.push_back(memory_order);
Call new_call =
......
......@@ -10,6 +10,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../support/ffi_aliases.h"
namespace tvm {
namespace tir {
......@@ -66,7 +68,8 @@ public:
}
if (mem_reuse_max > 0) {
std::string tag_str = cluster_tag; // Convert to std::string
std::string tag_str =
static_cast<std::string>(cluster_tag); // Convert to std::string
if (tag_str.rfind("blockIdx", 0) == 0) {
// starts with "blockIdx"
tag_str = "clusterIdx" + tag_str.substr(strlen("blockIdx"));
......@@ -74,7 +77,7 @@ public:
// Unexpected format — maybe just prefix
tag_str = "clusterIdx" + tag_str;
}
cluster_tag = tvm::ffi::String(tag_str); // Convert back
cluster_tag = String(tag_str); // Convert back
return WithAttr(f, cluster_tag, Integer(cluster_size_));
} else {
return f;
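The rewrite above maps a thread-binding tag such as "blockIdx.x" to "clusterIdx.x", falling back to a plain prefix for unexpected formats. A minimal standalone sketch of that string logic (plain std::string; RewriteClusterTag is an illustrative name, not part of the pass, which operates on tvm::ffi::String):

#include <cassert>
#include <cstring>
#include <string>

// Rewrites a thread-binding tag of the form "blockIdx.*" to "clusterIdx.*";
// any other tag is simply prefixed, matching the fallback branch above.
std::string RewriteClusterTag(const std::string &tag) {
  if (tag.rfind("blockIdx", 0) == 0) {  // starts with "blockIdx"
    return "clusterIdx" + tag.substr(std::strlen("blockIdx"));
  }
  return "clusterIdx" + tag;  // unexpected format: just prefix
}

int main() {
  assert(RewriteClusterTag("blockIdx.x") == "clusterIdx.x");
  assert(RewriteClusterTag("foo") == "clusterIdxfoo");
  return 0;
}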
......@@ -122,10 +125,10 @@ tvm::transform::Pass ClusterPlanning() {
return CreatePrimFuncPass(pass_func, 0, "tl.ClusterPlanning", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ClusterPlanning", ClusterPlanning);
});
}
} // namespace transform
} // namespace tir
......
......@@ -41,7 +41,7 @@ public:
return StmtMutator::VisitStmt_(op);
// Collect loop variables and ranges
auto for_node = GetRef<For>(op);
auto for_node = tvm::ffi::GetRef<For>(op);
Array<Var> loop_vars;
Array<PrimExpr> loop_extents;
Stmt body = op->body;
......@@ -81,7 +81,7 @@ public:
// post order visit the index
PostOrderVisit(index, [&](const ObjectRef &obj) {
if (const VarNode *v = obj.as<VarNode>()) {
used_vars.insert(GetRef<Var>(v));
used_vars.insert(tvm::ffi::GetRef<Var>(v));
}
});
if (used_vars.empty()) {
......
......@@ -211,7 +211,7 @@ public:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
......@@ -265,7 +265,7 @@ public:
PrimExpr VisitExpr_(const NotNode *op) final {
PrimExpr a = this->VisitExpr(op->a);
if (a.same_as(op->a)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return !(a);
}
......@@ -306,10 +306,10 @@ public:
PrimExpr value = this->VisitExpr(op->value);
if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Broadcast(op->value, op->lanes);
}
......@@ -321,7 +321,7 @@ public:
PrimExpr f = this->VisitExpr(op->false_value);
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
......@@ -339,7 +339,7 @@ public:
PrimExpr VisitExpr_(const CastNode *op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
if (value.dtype().is_scalable_vector()) {
return Cast(op->dtype.with_scalable_vscale_factor(
......@@ -352,20 +352,20 @@ public:
}
PrimExpr VisitExpr_(const FloatImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const StringImmNode *op) final {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
// Variable
PrimExpr VisitExpr_(const VarNode *op) final {
Var var = GetRef<Var>(op);
Var var = tvm::ffi::GetRef<Var>(op);
if (var.same_as(var_)) {
return ramp_;
......@@ -382,13 +382,13 @@ public:
PrimExpr cond = this->VisitExpr(op->args[0]);
if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr t = this->VisitExpr(op->args[1]);
PrimExpr f = this->VisitExpr(op->args[2]);
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int t_lanes = t.dtype().get_lanes_or_vscale_factor();
int f_lanes = f.dtype().get_lanes_or_vscale_factor();
......@@ -410,7 +410,7 @@ public:
ICHECK(op->op.same_as(builtin::reinterpret()));
PrimExpr value = this->VisitExpr(op->args[0]);
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int lanes = value.dtype().get_lanes_or_vscale_factor();
if (value.dtype().is_scalable_vector()) {
......@@ -455,12 +455,12 @@ public:
auto new_arg = this->VisitExpr(arg);
if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
new_args.push_back(new_arg);
}
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype, op->op, new_args);
}
......@@ -469,7 +469,7 @@ public:
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Call(op->dtype.with_lanes(lane), op->op, new_args);
}
......@@ -477,7 +477,7 @@ public:
}
// BufferLoad
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
auto load = GetRef<BufferLoad>(op);
auto load = tvm::ffi::GetRef<BufferLoad>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -514,7 +514,7 @@ public:
let_binding_[op->var] = op->var;
PrimExpr body = this->VisitExpr(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
return Let(op->var, value, body);
}
......@@ -522,7 +522,7 @@ public:
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = GetRef<BufferStore>(op);
auto store = tvm::ffi::GetRef<BufferStore>(op);
auto fmutate = [this](const PrimExpr &index) {
return this->VisitExpr(index);
......@@ -585,11 +585,11 @@ public:
ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
if (extent.same_as(op->extent) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return For(op->loop_var, op->min, extent, op->kind, body,
op->thread_binding, op->annotations);
......@@ -600,7 +600,7 @@ public:
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Optional<Stmt> else_case = std::nullopt;
......@@ -609,7 +609,7 @@ public:
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return IfThenElse(condition, then_case, else_case);
}
......@@ -634,7 +634,7 @@ public:
let_binding_[op->var] = op->var;
Stmt body = this->VisitStmt(op->body);
if (value.same_as(op->value) && body.same_as(op->body)) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
} else {
return LetStmt(op->var, value, body);
}
......@@ -647,7 +647,7 @@ public:
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
// Mutate the extents
......@@ -657,7 +657,7 @@ public:
if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of "
<< op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
return Scalarize(tvm::ffi::GetRef<Stmt>(op));
}
extents.push_back(new_ext);
}
......@@ -738,7 +738,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......@@ -754,7 +754,7 @@ private:
PrimExpr a = this->VisitExpr(op->a);
PrimExpr b = this->VisitExpr(op->b);
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
} else {
int a_lanes = a.dtype().get_lanes_or_vscale_factor();
int b_lanes = b.dtype().get_lanes_or_vscale_factor();
......
......@@ -38,7 +38,7 @@ protected:
if (is_enabled_ && op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(_index_bitwidth_), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
......@@ -88,23 +88,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
......@@ -183,11 +183,11 @@ tvm::transform::Pass ConfigIndexBitwidth() {
return CreatePrimFuncPass(pass_func, 0, "tl.ConfigIndexBitwidth", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.ConfigIndexBitwidth",
ConfigIndexBitwidth);
});
}
} // namespace tl
} // namespace tvm
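As a hedged aside on why these visitors widen sub-64-bit integer indices to Int(64): flattened offsets of large buffers can overflow 32-bit arithmetic even when every individual extent fits comfortably. A small standalone illustration (plain C++; the sizes are arbitrary, not taken from the pass):

#include <cassert>
#include <cstdint>

int main() {
  int32_t rows = 50000, cols = 50000;  // ~2.5e9 elements in total
  // Widening one operand to 64 bits, as the visitors above do by casting to
  // Int(64), keeps the offset computation exact.
  int64_t offset64 = static_cast<int64_t>(rows) * cols;
  assert(offset64 == 2500000000LL);  // exceeds INT32_MAX (2147483647)
  // int32_t offset32 = rows * cols;  // would overflow in 32-bit arithmetic
  return 0;
}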
......@@ -35,9 +35,7 @@ public:
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == "thread_extent") {
const VarNode *var = nullptr;
if (op->node->IsInstance<VarNode>()) {
var = op->node.as<VarNode>();
if (const auto *var = op->node.as<VarNode>()) {
if (var->name_hint == "threadIdx.x") {
thread_extent_ = op;
}
......@@ -82,7 +80,7 @@ public:
}
Stmt VisitStmt_(const ForNode *op) final {
PostOrderVisit(GetRef<For>(op), [&](const ObjectRef &node) {
PostOrderVisit(tvm::ffi::GetRef<For>(op), [&](const ObjectRef &node) {
if (const auto *call = node.as<CallNode>()) {
if (call->op.same_as(create_list_of_mbarrier()) ||
call->op.same_as(mbarrier_wait_parity()) ||
......@@ -116,11 +114,11 @@ tvm::transform::Pass EliminateStorageSyncForMBarrier() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.EliminateStorageSyncForMBarrier",
EliminateStorageSyncForMBarrier);
});
}
} // namespace transform
} // namespace tl
......
......@@ -75,23 +75,23 @@ private:
PrimExpr VisitExpr_(const VarNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), GetRef<Var>(op));
return cast(DataType::Int(64), tvm::ffi::GetRef<Var>(op));
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const IntImmNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return IntImm(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
PrimExpr VisitExpr_(const CastNode *op) final {
if (op->dtype.is_int() && op->dtype.bits() < 64) {
return cast(DataType::Int(64), op->value);
}
return GetRef<PrimExpr>(op);
return tvm::ffi::GetRef<PrimExpr>(op);
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
......@@ -115,7 +115,7 @@ private:
<< "All MatchBufferRegion should be removed in "
"tir.transform.LowerMatchBuffer.";
Block block = GetRef<Block>(op);
Block block = tvm::ffi::GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply(
......@@ -385,10 +385,10 @@ tvm::transform::Pass FlattenBuffer() {
return CreatePrimFuncPass(pass_func, 0, "tl.FlattenBuffer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.FlattenBuffer", FlattenBuffer);
});
}
} // namespace tl
} // namespace tvm
......@@ -89,10 +89,10 @@ Pass LetInline() {
return CreatePrimFuncPass(pass_func, 0, "tl.LetInline", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LetInline", LetInline);
});
}
} // namespace tl
} // namespace tvm
......@@ -33,7 +33,7 @@ private:
auto then_case = VisitStmt(op->then_case);
Optional<Stmt> else_case = op->else_case;
if (else_case.defined()) {
return GetRef<Stmt>(op);
return tvm::ffi::GetRef<Stmt>(op);
}
ICHECK(then_case.defined()) << "then_case must be defined";
ICHECK(!else_case.defined()) << "else_case must be undefined";
......@@ -81,10 +81,10 @@ tvm::transform::Pass IfStmtBinding() {
return CreatePrimFuncPass(pass_func, 0, "tl.IfStmtBinding", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.IfStmtBinding", IfStmtBinding);
});
}
} // namespace tl
} // namespace tvm
......@@ -156,9 +156,9 @@ tvm::transform::Pass InjectAssumes() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectAssumes", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectAssumes", InjectAssumes);
});
}
} // namespace tvm::tl
......@@ -108,7 +108,8 @@ bool IsKnownGeneric(const CallNode *call) {
return false;
}
return call->op.same_as(ptx_ldmatrix()) || call->op.same_as(ptx_stmatrix()) ||
call->op.same_as(initialize_descriptor());
call->op.same_as(initialize_wgmma_descriptor()) ||
call->op.same_as(initialize_tcgen05_descriptor());
}
ProxyKind ProxyFromAttrValue(const ObjectRef &value) {
......@@ -319,10 +320,10 @@ tvm::transform::Pass InjectFenceProxy() {
{});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectFenceProxy", InjectFenceProxy);
});
}
} // namespace tl
} // namespace tvm
......@@ -37,9 +37,14 @@
namespace tvm {
namespace tl {
using namespace tir;
using namespace ffi;
namespace software_pipeline {
struct LetWrapper {
Var var;
PrimExpr value;
};
/*!
* \brief Create a block and infer the access region with the given body.
*
......@@ -233,10 +238,12 @@ class PipelineRewriter : public StmtExprMutator {
public:
PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
const Array<Buffer> &pipeline_allocs,
const For &pipeline_loop, const PipelineInfo &pipeline_info)
const For &pipeline_loop, const PipelineInfo &pipeline_info,
const std::vector<LetWrapper> &loop_var_let_wrappers)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_info_(pipeline_info) {}
pipeline_info_(pipeline_info),
loop_var_let_wrappers_(loop_var_let_wrappers) {}
Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the
......@@ -459,7 +466,8 @@ private:
* \return The resized buffer.
*/
Buffer RewriteAllocBuffer(const Buffer &buffer, int num_versions) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*(buffer.get()));
ObjectPtr<BufferNode> new_buffer =
tvm::ffi::make_object<BufferNode>(*(buffer.get()));
new_buffer->shape.insert(new_buffer->shape.begin(), PrimExpr(num_versions));
if (!new_buffer->strides.empty()) {
ICHECK(new_buffer->strides.size() + 1 == new_buffer->shape.size());
......@@ -676,6 +684,20 @@ private:
new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
// If there were Let-wrappers outside the original pipeline body that
// depended on the pipeline loop var, push them into each rewritten
// block with the correct per-block substitution.
if (!loop_var_let_wrappers_.empty()) {
BlockNode *n = new_block.CopyOnWrite();
Stmt inner = n->body;
for (const auto &lw : loop_var_let_wrappers_) {
PrimExpr substituted = Substitute(
lw.value, {{pipeline_loop_->loop_var, normalized_access_index}});
inner = LetStmt(lw.var, substituted, inner);
}
n->body = inner;
}
if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
......@@ -737,6 +759,7 @@ private:
Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states;
std::vector<LetWrapper> loop_var_let_wrappers_;
};
/*!
......@@ -864,8 +887,9 @@ private:
const SeqStmtNode *pipeline_body_seq = nullptr;
std::vector<std::function<Stmt(Stmt)>> rewrap_fns;
std::vector<LetWrapper> loop_var_let_wrappers;
auto append_attr_wrapper = [&rewrap_fns](const AttrStmtNode *attr) {
ObjectRef node = attr->node;
Any node = attr->node;
String attr_key = attr->attr_key;
PrimExpr value = attr->value;
Span span = attr->span;
......@@ -896,6 +920,16 @@ private:
continue;
}
if (const auto *let_stmt = current.as<LetStmtNode>()) {
// If this Let value uses the pipeline loop var, record it and push
// inside each rewritten block later so the loop var can be
// substituted with the correct per-iteration index. Otherwise, keep
// it as a normal wrapper.
bool uses_loop_var = UsesVar(
let_stmt->value,
[v = op->loop_var.get()](const VarNode *vn) { return vn == v; });
if (uses_loop_var) {
loop_var_let_wrappers.push_back({let_stmt->var, let_stmt->value});
} else {
Var var = let_stmt->var;
PrimExpr value = let_stmt->value;
Span span = let_stmt->span;
......@@ -904,6 +938,7 @@ private:
span](Stmt body) -> Stmt {
return LetStmt(var, value, body, span);
});
}
current = let_stmt->body;
continue;
}
......@@ -981,7 +1016,8 @@ private:
// Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
GetRef<For>(op), pipeline_info)
tvm::ffi::GetRef<For>(op), pipeline_info,
loop_var_let_wrappers)
.BuildPipeline();
auto apply_wrappers = [&](Stmt stmt) {
for (auto it = rewrap_fns.rbegin(); it != rewrap_fns.rend(); ++it) {
......@@ -1072,11 +1108,11 @@ tir::transform::Pass InjectSoftwarePipeline() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectSoftwarePipeline", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectSoftwarePipeline",
InjectSoftwarePipeline);
});
}
} // namespace tl
} // namespace tvm
......@@ -232,10 +232,10 @@ tvm::transform::Pass InjectPTXAsyncCopy() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectPTXAsyncCopy", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectPTXAsyncCopy", InjectPTXAsyncCopy);
});
}
} // namespace tl
} // namespace tvm
......@@ -204,9 +204,9 @@ private:
void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
pending_tma_ops_.push_back(tvm::ffi::GetRef<Call>(call));
} else if (call->op.same_as(builtin::ptx_arrive_barrier())) {
PrimExpr barrier_id = call->args[0];
for (const auto &tma_call : pending_tma_ops_) {
......@@ -295,13 +295,15 @@ public:
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(mbarrier_expect_tx())) {
PrimExpr e =
tma_op_to_barrier_id_[GetRef<Call>(op)].as<CallNode>()->args[0];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (tma_op_to_barrier_id_.count(call_ref)) {
PrimExpr e = tma_op_to_barrier_id_[call_ref].as<CallNode>()->args[0];
auto int_set = arith::EvalSet(e, var_int_set_);
expect_.push_back(if_depth_ == 1);
sequence.push_back(0);
int_sets_.push_back(int_set);
expect_tx_count_ += 1;
}
} else if (op->op.same_as(builtin::ptx_arrive_barrier())) {
sequence.push_back(1);
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
......@@ -336,32 +338,61 @@ public:
class BarrierCreationRewriter : public StmtExprMutator {
public:
BarrierCreationRewriter(std::vector<int> restore_barrier_ids,
PrimExpr producer_thread_extent)
PrimExpr producer_thread_extent,
int ensure_min_count = 0,
PrimExpr default_barrier_thread_count = 1)
: restore_barrier_ids_(std::move(restore_barrier_ids)),
producer_thread_extent_(std::move(producer_thread_extent)) {}
producer_thread_extent_(std::move(producer_thread_extent)),
ensure_min_count_(ensure_min_count),
default_barrier_thread_count_(std::move(default_barrier_thread_count)) {
}
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(create_list_of_mbarrier())) {
std::vector<bool> tmp_(op->args.size(), false);
Array<PrimExpr> new_args;
size_t cur_n = op->args.size();
size_t need_n =
std::max<size_t>(cur_n, static_cast<size_t>(ensure_min_count_));
// Mark barriers to restore across the full needed length, not just the
// original length, so newly appended entries can be restored as well.
std::vector<bool> replace(need_n, false);
for (auto &id : restore_barrier_ids_) {
tmp_[id] = true;
if (id >= 0 && static_cast<size_t>(id) < replace.size()) {
replace[id] = true;
}
}
for (size_t i{0}; i < op->args.size(); ++i) {
if (tmp_[i]) {
Array<PrimExpr> new_args;
new_args.reserve(need_n);
// Preserve/override existing entries
for (size_t i{0}; i < cur_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(op->args[i]);
}
}
// Append additional barriers if required
for (size_t i = cur_n; i < need_n; ++i) {
if (replace[i]) {
new_args.push_back(producer_thread_extent_);
} else {
new_args.push_back(default_barrier_thread_count_);
}
}
return Call(op->dtype, op->op, new_args);
} else {
return StmtExprMutator::VisitExpr_(op);
}
}
private:
std::vector<int> restore_barrier_ids_;
PrimExpr producer_thread_extent_;
int ensure_min_count_{0};
PrimExpr default_barrier_thread_count_{1};
};
// we trust mbarrier_wait_parity to be correct
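A simplified, non-TVM model of the create_list_of_mbarrier rewrite performed by BarrierCreationRewriter above: existing entries are kept unless their id is listed for restoration, restored ids get the producer thread extent, and the list is extended up to the required minimum count with a default arrive count. Plain ints stand in for the PrimExpr arrive counts, and the function name is illustrative:

#include <algorithm>
#include <cassert>
#include <vector>

std::vector<int> RewriteBarrierList(std::vector<int> args,
                                    const std::vector<int> &restore_ids,
                                    int producer_thread_extent,
                                    size_t ensure_min_count,
                                    int default_count = 1) {
  size_t need_n = std::max(args.size(), ensure_min_count);
  // Mark barriers to restore across the full needed length.
  std::vector<bool> replace(need_n, false);
  for (int id : restore_ids) {
    if (id >= 0 && static_cast<size_t>(id) < replace.size()) replace[id] = true;
  }
  std::vector<int> out;
  out.reserve(need_n);
  // Preserve or override existing entries.
  for (size_t i = 0; i < args.size(); ++i)
    out.push_back(replace[i] ? producer_thread_extent : args[i]);
  // Append additional barriers if required.
  for (size_t i = args.size(); i < need_n; ++i)
    out.push_back(replace[i] ? producer_thread_extent : default_count);
  return out;
}

int main() {
  // Two existing barriers with arrive count 128, barrier 1 restored to the
  // producer extent (32), and at least three barriers required overall.
  auto out = RewriteBarrierList({128, 128}, {1}, 32, 3);
  assert((out == std::vector<int>{128, 32, 1}));
  return 0;
}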
......@@ -398,15 +429,38 @@ public:
collector.barrier_id_to_range(),
has_create_list_of_mbarrier);
f.CopyOnWrite()->body = rewriter(f->body);
// Compute the minimum number of barriers actually referenced in the body
// after TMA barrier rewrites (e.g., get_mbarrier(0) inserted for TMA).
struct GetMbarrierMaxIdxCollector : public StmtExprVisitor {
int max_idx{-1};
void VisitExpr_(const CallNode *op) final {
if (op->op.same_as(get_mbarrier())) {
if (op->args.size() == 1) {
if (const auto *imm = op->args[0].as<IntImmNode>()) {
max_idx = std::max(max_idx, static_cast<int>(imm->value));
}
}
}
StmtExprVisitor::VisitExpr_(op);
}
};
GetMbarrierMaxIdxCollector max_idx_collector;
max_idx_collector(f->body);
int ensure_min_count = max_idx_collector.max_idx + 1; // 0-based -> count
// For simple TMA-only producers, default barrier arrive count should be 1
// (only the elected leader performs the TMA arrive/expect).
auto barrier_creation_rewriter = BarrierCreationRewriter(
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_);
rewriter.restore_barrier_ids_, rewriter.producer_thread_extent_,
ensure_min_count, Integer(1));
f.CopyOnWrite()->body = barrier_creation_rewriter(f->body);
return f;
}
private:
Stmt VisitStmt_(const BlockNode *op) {
auto block = GetRef<Block>(op);
auto block = tvm::ffi::GetRef<Block>(op);
if (!has_create_list_of_mbarrier_ && !barrier_id_to_range_.empty() &&
op->name_hint == MainBlockName) {
ICHECK(false) << "Please declare create_list_of_mbarrier.";
......@@ -452,10 +506,27 @@ private:
PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
// For 1D TMA loads, promote a raw integer barrier id to get_mbarrier(id)
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and the mapping is empty.
// so codegen can emit mbarrier[index]. This handles degenerate
// producer-only kernels where no arrive() is seen and the mapping is empty.
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load && op->args.size() >= 3) {
if (const auto *imm = op->args[2].as<IntImmNode>()) {
Array<PrimExpr> new_args = op->args;
new_args.Set(2, Call(DataType::Handle(), get_mbarrier(),
{IntImm(DataType::Int(32),
static_cast<int>(imm->value))}));
return Call(op->dtype, op->op, new_args);
}
}
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
......@@ -468,9 +539,11 @@ private:
}
return Call(op->dtype, op->op, new_args);
} else if (op->op.same_as(mbarrier_expect_tx())) {
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "mbarrier_expect_tx must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto call_ref = tvm::ffi::GetRef<Call>(op);
if (!tma_op_to_barrier_id_.count(call_ref)) {
return IRMutatorWithAnalyzer::VisitExpr_(op);
}
auto barrier_id = tma_op_to_barrier_id_[call_ref];
auto new_args = op->args;
new_args.Set(0, barrier_id);
if (!has_warp_specialization_)
......@@ -522,10 +595,10 @@ tvm::transform::Pass InjectTmaBarrier() {
return CreatePrimFuncPass(pass_func, 0, "tl.InjectTmaBarrier", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.InjectTmaBarrier", InjectTmaBarrier);
});
}
} // namespace tl
} // namespace tvm
......@@ -11,6 +11,7 @@
#include <tvm/tir/transform.h>
#include <tvm/tir/utils.h>
#include <algorithm>
#include <queue>
#include "../layout/utils.h"
......@@ -105,20 +106,60 @@ public:
"required for layout inference.";
// Run InferLayout
DLOG(INFO) << "[RunInferStep] working on " << cur_infer_id << '\n';
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
DLOG(INFO) << " consider update " << buffer << " as "
<< layout->DebugOutput() << '\n';
// Basic validity checks
ICHECK(buffer.defined()) << "InferLayout returned an undefined buffer.";
ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
// Helper: propagate inferred layout to alias buffers (same data Var)
auto propagate_alias = [&](const Buffer &src_buffer,
const Layout &src_layout) {
if (!buffer_data_to_buffers_.count(src_buffer->data))
return;
const auto &siblings = buffer_data_to_buffers_[src_buffer->data];
for (const auto &sib : siblings) {
if (sib.same_as(src_buffer))
continue;
bool shapes_equal =
src_layout->InputShape().size() == sib->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < src_layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(src_layout->InputShape()[i],
sib->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout target_layout =
shapes_equal ? src_layout
: src_layout->Reshape(sib->shape, &analyzer_);
if (layout_map.count(sib)) {
ICHECK(target_layout->IsEqual(layout_map[sib].get()))
<< "Get different layout for alias buffer " << sib
<< " (data-shared with " << src_buffer
<< ")\n current: " << target_layout->DebugOutput()
<< "\n previous: " << layout_map[sib]->DebugOutput();
} else {
layout_map.Set(sib, target_layout);
if (update_queue && use_list_.count(sib)) {
for (int idx : use_list_[sib]) {
if (!in_queue[idx] && idx != cur_infer_id) {
in_queue[idx] = true;
q.push(idx);
}
}
}
}
}
};
if (layout_map.count(buffer)) {
// If new layout contains the old one, update map
if (buffer.scope() == "local.fragment" &&
......@@ -153,8 +194,8 @@ public:
if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
inner_analyzer)) {
layout_map.Set(buffer, layout);
DLOG(INFO) << " layout broadcast from "
<< src_layout->DebugOutput() << ", accepted" << '\n';
// Propagate to alias buffers as well
propagate_alias(buffer, layout);
continue;
}
}
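The propagate_alias helper above, and the per-Var finalization in step 4 below, share one idea: buffers that alias the same storage Var should end up with consistent layouts, reshaping when their shapes differ. A toy model of that grouping, with string tags standing in for Layout objects and without the reshape and equality checks of the real pass (all names are illustrative):

#include <cassert>
#include <map>
#include <string>
#include <vector>

int main() {
  // Buffers grouped by the storage variable they share.
  std::map<std::string, std::vector<std::string>> data_to_buffers = {
      {"A_data", {"A", "A_reshaped"}}};
  // Only one buffer in the group has an inferred layout so far.
  std::map<std::string, std::string> layout_map = {{"A", "swizzled_2d"}};

  for (const auto &[var, buffers] : data_to_buffers) {
    (void)var;
    // Find a representative buffer that already has a layout.
    std::string rep_layout;
    for (const auto &buf : buffers)
      if (layout_map.count(buf)) { rep_layout = layout_map[buf]; break; }
    if (rep_layout.empty()) continue;
    // Propagate to siblings without a layout (the real pass reshapes the
    // layout when the sibling's shape differs and checks for conflicts).
    for (const auto &buf : buffers)
      if (!layout_map.count(buf)) layout_map[buf] = rep_layout;
  }
  assert(layout_map["A_reshaped"] == "swizzled_2d");
  return 0;
}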
......@@ -163,10 +204,13 @@ public:
<< "Get different layout for " << buffer
<< "\n current layout: " << layout->DebugOutput()
<< "\n previous layout: " << layout_map[buffer]->DebugOutput();
// Ensure aliases are consistent too
propagate_alias(buffer, layout);
} else {
// Otherwise, update map
layout_map.Set(buffer, layout);
DLOG(INFO) << " new layout accepted" << '\n';
// Propagate to alias buffers (may enqueue their users)
propagate_alias(buffer, layout);
if (!update_queue)
continue;
......@@ -272,6 +316,46 @@ public:
// step 3: relax constraints to free and re-run
InferInFreeMode(layout_map, strict_layout_map);
// step 4: finalize alias layouts by Var
// For each storage var, if any buffer in the group has a layout,
// propagate (reshape if needed) to the rest to ensure completeness.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
// Find a representative with existing layout
Optional<Buffer> rep;
Optional<Layout> rep_layout;
for (const auto &buf : buffers) {
if (layout_map.count(buf)) {
rep = buf;
rep_layout = layout_map[buf];
break;
}
}
if (!rep_layout.defined())
continue;
for (const auto &buf : buffers) {
if (!layout_map.count(buf)) {
bool shapes_equal =
rep_layout.value()->InputShape().size() == buf->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < rep_layout.value()->InputShape().size();
++i) {
if (!analyzer_.CanProveEqual(rep_layout.value()->InputShape()[i],
buf->shape[i])) {
shapes_equal = false;
break;
}
}
}
Layout reshaped =
shapes_equal
? rep_layout.value()
: rep_layout.value()->Reshape(buf->shape, &analyzer_);
layout_map.Set(buf, reshaped);
}
}
}
// Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment") {
......@@ -314,7 +398,13 @@ public:
void Collect(const PrimFunc &f) {
for (const auto &[_, buffer] : f->buffer_map) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
}
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined())
......@@ -324,13 +414,25 @@ public:
}
private:
Map<Var, Buffer> GetBufferMap() const {
Map<Var, Buffer> buffer_map;
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
// Use the first buffer for each var
// TODO(lei): phase out buffer_map in the future.
if (!buffers.empty()) {
buffer_map.Set(var, buffers[0]);
}
}
return buffer_map;
}
void VisitExpr_(const CallNode *op) final {
IRVisitorWithAnalyzer::VisitExpr_(op);
// Do not analyze call nodes that target a global function.
if (op->op.as<GlobalVarNode>())
return;
auto p = ParseOperator(GetRef<Call>(op), buffer_data_to_buffer_);
auto p = ParseOperator(tvm::ffi::GetRef<Call>(op), GetBufferMap());
if (p.defined()) {
for (const auto &arg : op->args) {
if (auto buffer = getBufferFromAccessPtr(arg)) {
......@@ -381,7 +483,7 @@ private:
}
// Add the tile operator to infer_list_
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(p));
}
}
......@@ -394,12 +496,18 @@ private:
if (call->op.same_as(builtin::tvm_access_ptr())) {
auto var_opt = call->args[1].as<Var>();
if (!var_opt.has_value()) {
DLOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
LOG(WARNING) << "[getBufferFromAccessPtr] args[1] is not a Var, type: "
<< call->args[1]->GetTypeKey();
return std::nullopt;
}
const auto &var = var_opt.value();
return buffer_data_to_buffer_[var];
if (buffer_data_to_buffers_.count(var)) {
const auto &buffers = buffer_data_to_buffers_[var];
if (!buffers.empty()) {
return buffers[0]; // Return the first buffer
}
}
return std::nullopt;
} else if (call->op.same_as(RegionOp::Get())) {
return call->args[0].as<BufferLoadNode>()->buffer;
}
......@@ -416,11 +524,11 @@ private:
void VisitStmt_(const ForNode *op) final {
if (op->kind == ForKind::kParallel) {
auto infer = ParallelOp(GetRef<For>(op));
auto infer = ParallelOp(tvm::ffi::GetRef<For>(op));
for (const auto &[buffer, _] : infer->GetIndiceMap()) {
addToUseList(buffer);
}
infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
infer_list_stmt_.push_back(tvm::ffi::GetRef<ObjectRef>(op));
infer_list_.push_back(std::move(infer));
thread_var_vec_.push_back(thread_var_);
if (thread_var_.defined() &&
......@@ -442,21 +550,55 @@ private:
void VisitStmt_(const BlockNode *op) final {
for (auto buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
if (buffer_data_to_buffers_.count(buffer->data)) {
auto buffers = buffer_data_to_buffers_[buffer->data];
buffers.push_back(buffer);
buffer_data_to_buffers_.Set(buffer->data, buffers);
} else {
buffer_data_to_buffers_.Set(buffer->data, {buffer});
}
}
// First, visit the block body to collect all buffers from
// BufferLoad/BufferStore
IRVisitorWithAnalyzer::VisitStmt_(op);
// After visiting, apply layouts to all collected buffers
if (op->annotations.count(attr::kLayoutMap)) {
// Check if the layout map is Map<Var, Layout>
auto map =
op->annotations.Get(attr::kLayoutMap)->as<Map<Var, Layout>>().value();
for (const auto &[var, layout] : map) {
ICHECK(buffer_data_to_buffer_.count(var))
ICHECK(buffer_data_to_buffers_.count(var))
<< "buffer " << var << " is not found in the block";
auto buffer = buffer_data_to_buffer_[var];
ICHECK(StructuralEqual()(layout->InputShape(), buffer->shape));
const auto &buffers = buffer_data_to_buffers_[var];
ICHECK(!buffers.empty()) << "buffer list for " << var << " is empty";
// Apply layout to all buffers associated with this var
for (const auto &buffer : buffers) {
// Reshape the layout to match the buffer's shape
// Check if shapes are structurally equal
bool shapes_equal =
layout->InputShape().size() == buffer->shape.size();
if (shapes_equal) {
for (size_t i = 0; i < layout->InputShape().size(); ++i) {
if (!analyzer_.CanProveEqual(layout->InputShape()[i],
buffer->shape[i])) {
shapes_equal = false;
break;
}
}
}
if (shapes_equal) {
annotated_layout_map_.Set(buffer, layout);
} else {
auto reshaped_layout = layout->Reshape(buffer->shape, &analyzer_);
annotated_layout_map_.Set(buffer, reshaped_layout);
}
}
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
void VisitStmt_(const AttrStmtNode *op) final {
......@@ -470,7 +612,67 @@ private:
IRVisitorWithAnalyzer::VisitStmt_(op);
}
Map<Var, Buffer> buffer_data_to_buffer_;
void VisitExpr_(const BufferLoadNode *op) final {
// Collect buffer from BufferLoad
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferLoad: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferLoad: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
void VisitStmt_(const BufferStoreNode *op) final {
// Collect buffer from BufferStore
if (op->buffer.defined() && op->buffer->data.defined()) {
if (buffer_data_to_buffers_.count(op->buffer->data)) {
// Check if this buffer is already in the list
auto buffers = buffer_data_to_buffers_[op->buffer->data];
bool found = false;
for (const auto &buf : buffers) {
if (buf.same_as(op->buffer)) {
found = true;
break;
}
}
if (!found) {
buffers.push_back(op->buffer);
buffer_data_to_buffers_.Set(op->buffer->data, buffers);
DLOG(INFO) << "[LayoutInference] BufferStore: added buffer "
<< op->buffer << " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
} else {
buffer_data_to_buffers_.Set(op->buffer->data, {op->buffer});
DLOG(INFO) << "[LayoutInference] BufferStore: new buffer " << op->buffer
<< " buffer.get() = " << op->buffer.get()
<< " data = " << op->buffer->data.get();
}
}
IRVisitorWithAnalyzer::VisitStmt_(op);
}
Map<Var, Array<Buffer>> buffer_data_to_buffers_;
std::vector<ObjectRef> infer_list_stmt_;
std::vector<TileOperator> infer_list_;
std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
......@@ -513,12 +715,33 @@ private:
if (infer_indices.empty())
continue;
// Union all infer_list_ indices that share the same buffer
// Union all infer_list_ indices that share the same Buffer object
int first_idx = infer_indices[0];
for (size_t i = 1; i < infer_indices.size(); i++) {
uf.Union(first_idx, infer_indices[i]);
}
}
// Additionally, union across buffers that share the same underlying
// buffer->data (Var). This handles cases like reshape where multiple
// Buffer objects alias the same storage.
for (const auto &[var, buffers] : buffer_data_to_buffers_) {
std::vector<int> merged;
for (const auto &buf : buffers) {
auto it = use_list_.find(buf);
if (it != use_list_.end()) {
const auto &vec = it->second;
merged.insert(merged.end(), vec.begin(), vec.end());
}
}
if (merged.size() > 1) {
std::sort(merged.begin(), merged.end());
merged.erase(std::unique(merged.begin(), merged.end()), merged.end());
int first = merged[0];
for (size_t i = 1; i < merged.size(); ++i) {
uf.Union(first, merged[i]);
}
}
}
std::unordered_map<int, std::vector<int>> components;
for (int i = 0; i < infer_list_.size(); i++) {
int root = uf.Find(i);
......@@ -597,7 +820,9 @@ private:
}
}
// Update the best plan if this one uses fewer registers (ties are broken
// by the smaller infer-root index)
if (reg_num < min_reg_num) {
if (reg_num < min_reg_num ||
(reg_num == min_reg_num &&
attempt_infer_root < min_reg_num_infer_root)) {
best_infer_list =
BackupInferList(); // Use backup to avoid moving out infer_list_
best_layout_map = tmp_layout_map;
......@@ -711,8 +936,8 @@ private:
.value();
For for_node = Downcast<For>(IRMutatorWithAnalyzer::VisitStmt_(op));
if (result_.for_map.count(GetRef<For>(op))) {
auto root = GetRef<For>(op);
if (result_.for_map.count(tvm::ffi::GetRef<For>(op))) {
auto root = tvm::ffi::GetRef<For>(op);
// This check is a workaround to support T.Parallel for local buffers.
// For example:
// for i in T.Parallel(1024):
......@@ -787,7 +1012,18 @@ private:
}
});
if (has_non_local && !has_reducer) {
// If a cast operation exists, vectorization may still be required
bool has_cast_operations = false;
PostOrderVisit(for_node->body, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if this is a non-reducer store with Cast operation
if (store->value.as<CastNode>()) {
has_cast_operations = true;
}
}
});
if ((has_non_local || has_cast_operations) && !has_reducer) {
for_node = VectorizeLoop(for_node);
}
......@@ -831,10 +1067,10 @@ tvm::transform::Pass LayoutInference() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutInference", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutInference", LayoutInference);
});
}
} // namespace tl
} // namespace tvm
......@@ -14,6 +14,7 @@
#include "../layout/layout.h"
#include "../op/fill.h"
#include "../op/finalize_reducer.h"
#include "../op/region.h"
#include "arith/ir_mutator_with_analyzer.h"
#include "layout_reducer.h"
......@@ -275,17 +276,34 @@ private:
auto op = op_ref.CopyOnWrite();
if (op->op.same_as(Fill::Get())) {
ICHECK(!op->args.empty());
if (auto arg0_call = op->args[0].as<Call>();
arg0_call &&
arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
if (auto arg0_call = op->args[0].as<Call>()) {
// Case 1: tl.region(...) — extract buffer var from its first arg
if (arg0_call.value()->op.same_as(RegionOp::Get())) {
ICHECK(!arg0_call.value()->args.empty());
if (auto bl = arg0_call.value()->args[0].as<BufferLoadNode>()) {
Var var = bl->buffer->data;
if (reducer_info_map_.count(var)) {
ICHECK(inside_reducer_range_.count(var) == 0)
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var,
reducer_info_map_.Get(var).value());
}
}
}
// Case 2: builtin.tvm_access_ptr(...) — existing path
else if (arg0_call.value()->op.same_as(builtin::tvm_access_ptr())) {
ICHECK(arg0_call.value()->args.size() > 1);
if (auto var = arg0_call.value()->args[1].as<Var>();
var && reducer_info_map_.count(var.value())) {
ICHECK(inside_reducer_range_.count(var.value()) == 0)
<< "T.fill on reducer must be enclosed with a T.finalize_reducer "
<< "T.fill on reducer must be enclosed with a "
"T.finalize_reducer "
"before next.";
inside_reducer_range_.Set(var.value(),
reducer_info_map_.Get(var.value()).value());
inside_reducer_range_.Set(
var.value(), reducer_info_map_.Get(var.value()).value());
}
}
}
} else if (op->op.same_as(FinalizeReducerOp::Get())) {
......@@ -362,10 +380,10 @@ tvm::transform::Pass LayoutReducer() {
return CreatePrimFuncPass(pass_func, 0, "tl.LayoutReducer", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LayoutReducer", LayoutReducer);
});
}
} // namespace tl
} // namespace tvm
......@@ -66,17 +66,17 @@ struct ReducerInfoNode : Object {
ReducerInfoNode() = default;
ReducerInfoNode(const String &op_str, const String &rep_str);
static constexpr const char *_type_key = "tl.ReducerInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ReducerInfoNode, Object);
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("tl.ReducerInfo", ReducerInfoNode, Object);
};
struct ReducerInfo : ObjectRef {
public:
TVM_DLL ReducerInfo(const String &op_str, const String &rep_str) {
data_ = make_object<ReducerInfoNode>(op_str, rep_str);
data_ = tvm::ffi::make_object<ReducerInfoNode>(op_str, rep_str);
}
TVM_DEFINE_OBJECT_REF_METHODS(ReducerInfo, ObjectRef, ReducerInfoNode);
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(ReducerInfo, ObjectRef,
ReducerInfoNode);
};
namespace attr {
......
/*!
* \file legalize_negative_index.cc
* \brief Legalize negative indices in buffer load expressions.
*/
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <unordered_map>
#include <vector>
#include "arith/ir_mutator_with_analyzer.h"
#include "arith/ir_visitor_with_analyzer.h"
namespace tvm {
namespace tl {
using namespace tir;
using arith::IRVisitorWithAnalyzer;
enum class IndexSignState { kNonNegative, kNegative, kUnknown };
class NegativeIndexAnalyzer : public IRVisitorWithAnalyzer {
public:
explicit NegativeIndexAnalyzer(
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result)
: result_(result) {}
void VisitExpr_(const BufferLoadNode *op) final {
auto load = tvm::ffi::GetRef<BufferLoad>(op);
std::vector<IndexSignState> states;
states.reserve(op->indices.size());
bool needs_record = false;
for (size_t i = 0; i < op->indices.size(); ++i) {
PrimExpr simplified = analyzer_.Simplify(op->indices[i]);
// Handle scalar indices with the standard analyzer
if (simplified.dtype().lanes() == 1) {
if (analyzer_.CanProve(simplified >= 0)) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (analyzer_.CanProve(simplified < 0)) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING)
<< "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name << " (axis "
<< i << ").";
continue;
}
// Vector indices: try to reason about non-negativity/negativity
// Common patterns are Ramp(base, stride, lanes) and Broadcast(value,
// lanes).
IndexSignState vec_state = IndexSignState::kUnknown;
if (const auto *ramp = simplified.as<RampNode>()) {
// Compute a safe lower/upper bound for the vector lanes
// lower_bound = base_min + min(0, stride_min) * (lanes - 1)
// upper_bound = base_max + max(0, stride_max) * (lanes - 1)
auto base_bound = analyzer_.const_int_bound(ramp->base);
auto stride_bound = analyzer_.const_int_bound(ramp->stride);
int lanes = *as_const_int(ramp->lanes);
int64_t base_min = base_bound->min_value;
int64_t base_max = base_bound->max_value;
int64_t s_min = stride_bound->min_value;
int64_t s_max = stride_bound->max_value;
// Guarding against overflow is not strictly necessary here because the
// bounds may be +/-inf, represented by sentinel values.
int64_t lower = base_min;
if (s_min < 0)
lower += s_min * (lanes - 1);
int64_t upper = base_max;
if (s_max > 0)
upper += s_max * (lanes - 1);
if (lower >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (upper < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
} else if (const auto *bc = simplified.as<BroadcastNode>()) {
auto v = analyzer_.Simplify(bc->value);
if (analyzer_.CanProve(v >= 0)) {
vec_state = IndexSignState::kNonNegative;
} else if (analyzer_.CanProve(v < 0)) {
vec_state = IndexSignState::kNegative;
} else {
// Try const bound if proof unavailable
auto vb = analyzer_.const_int_bound(v);
if (vb->min_value >= 0) {
vec_state = IndexSignState::kNonNegative;
} else if (vb->max_value < 0) {
vec_state = IndexSignState::kNegative;
} else {
vec_state = IndexSignState::kUnknown;
}
}
}
if (vec_state == IndexSignState::kNonNegative) {
states.push_back(IndexSignState::kNonNegative);
continue;
}
if (vec_state == IndexSignState::kNegative) {
states.push_back(IndexSignState::kNegative);
needs_record = true;
continue;
}
states.push_back(IndexSignState::kUnknown);
needs_record = true;
DLOG(WARNING) << "LegalizeNegativeIndex: cannot prove non-negative index "
<< simplified << " for buffer " << load->buffer->name
<< " (axis " << i << ").";
}
if (needs_record) {
(*result_)[op] = std::move(states);
}
IRVisitorWithAnalyzer::VisitExpr_(op);
}
private:
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
*result_;
};
class NegativeIndexRewriter : public arith::IRMutatorWithAnalyzer {
public:
static PrimFunc
Apply(PrimFunc func,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states) {
arith::Analyzer analyzer;
NegativeIndexRewriter rewriter(&analyzer, states);
if (!func->body.defined()) {
return func;
}
PrimFuncNode *func_node = func.CopyOnWrite();
func_node->body = rewriter.VisitStmt(func_node->body);
return func;
}
private:
NegativeIndexRewriter(
arith::Analyzer *analyzer,
const std::unordered_map<const BufferLoadNode *,
std::vector<IndexSignState>> &states)
: arith::IRMutatorWithAnalyzer(analyzer), states_(states) {}
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
BufferLoad load =
Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
auto it = states_.find(op);
if (it == states_.end()) {
return load;
}
auto indices = load->indices;
bool changed = false;
const auto &state_vector = it->second;
ICHECK_EQ(state_vector.size(), indices.size())
<< "State vector size mismatch for buffer load " << load->buffer->name;
for (size_t i = 0; i < indices.size(); ++i) {
if (state_vector[i] != IndexSignState::kNegative) {
continue;
}
PrimExpr extent = load->buffer->shape[i];
indices.Set(i, analyzer_->Simplify(extent + indices[i]));
changed = true;
}
if (!changed) {
return load;
}
return BufferLoad(load->buffer, indices);
}
const std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
&states_;
};
PrimFunc LegalizeNegativeIndex(PrimFunc func) {
if (!func->body.defined()) {
return func;
}
std::unordered_map<const BufferLoadNode *, std::vector<IndexSignState>>
states;
NegativeIndexAnalyzer analyzer(&states);
analyzer(func->body);
if (states.empty()) {
return func;
}
return NegativeIndexRewriter::Apply(std::move(func), states);
}
tvm::transform::Pass LegalizeNegativeIndexPass() {
using namespace tir::transform;
auto pass_func = [](PrimFunc f, const IRModule &, PassContext) {
return LegalizeNegativeIndex(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LegalizeNegativeIndex", {});
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeNegativeIndex",
LegalizeNegativeIndexPass);
}
} // namespace tl
} // namespace tvm
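The vector-index reasoning in NegativeIndexAnalyzer above can be sanity-checked with a small standalone computation of the Ramp lane bounds and of the final rewrite (extent + index). Plain C++, using the same lower/upper bound formulas as the analyzer; RampBounds is an illustrative helper, not part of the pass:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Bounds for a Ramp(base, stride, lanes) vector index, mirroring the
// analyzer: the most negative and most positive lane values.
void RampBounds(int64_t base_min, int64_t base_max, int64_t s_min,
                int64_t s_max, int lanes, int64_t *lower, int64_t *upper) {
  *lower = base_min + std::min<int64_t>(0, s_min) * (lanes - 1);
  *upper = base_max + std::max<int64_t>(0, s_max) * (lanes - 1);
}

int main() {
  int64_t lo, hi;
  // Ramp(-4, 1, 4) covers indices {-4, -3, -2, -1}: provably negative,
  // so the rewriter folds each lane into extent + index.
  RampBounds(/*base_min=*/-4, /*base_max=*/-4, /*s_min=*/1, /*s_max=*/1,
             /*lanes=*/4, &lo, &hi);
  assert(lo == -4 && hi == -1);

  // The legalization itself: a negative index i on an axis of extent N
  // becomes N + i, e.g. extent 128 and index -1 maps to 127.
  int64_t extent = 128, index = -1;
  assert(extent + index == 127);
  return 0;
}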
......@@ -38,7 +38,7 @@ private:
StmtVisitor::VisitStmt(op->body);
if (!has_child_for_) {
leaf_for_nodes.push_back(GetRef<For>(op));
leaf_for_nodes.push_back(tvm::ffi::GetRef<For>(op));
}
parent_has_child_for_ = parent_has_child_for;
......@@ -378,11 +378,11 @@ tvm::transform::Pass LegalizeSafeMemoryAccess() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeSafeMemoryAccess",
LegalizeSafeMemoryAccess);
});
}
} // namespace tl
} // namespace tvm
......@@ -89,11 +89,11 @@ tvm::transform::Pass LegalizeVectorizedLoop() {
}
// Register the pass globally so it can be used in the compilation pipeline
TVM_FFI_STATIC_INIT_BLOCK({
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LegalizeVectorizedLoop",
LegalizeVectorizedLoop);
});
}
} // namespace tl
} // namespace tvm