Unverified Commit f7ba45d8 authored by Lei Wang, committed by GitHub

[Bugfix] Implement classic arena algorithm for shmem merge and WAW conflict detection (#1146)

* atomic_fix

* atomic_fix

* mem fix

* lint fix

* add some comments

* fix

* fix

* lint fix

* handle async copy

* lint fix
parent c70b2697
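
For orientation before the diff: the "classic arena algorithm" named in the title is, at its core, linear-scan packing of buffer liveness intervals backed by a free list of retired blocks. The sketch below is a simplified, standalone illustration of that idea only; the `Interval` struct, `AlignUp` helper, and the fit heuristic are stand-ins and are not the code introduced by this commit (which additionally handles alignment maps, dynamic sizes, and symbolic `PrimExpr` offsets).

```cpp
// Minimal sketch: pack liveness intervals into one arena, reusing bytes of
// buffers whose lifetime has ended. No block splitting or coalescing, to
// keep the example short.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <map>
#include <utility>
#include <vector>

struct Interval {
  int start, end;      // live over statement indices [start, end)
  size_t size, align;  // requested bytes and required alignment
  int id;              // buffer identifier
};

static size_t AlignUp(size_t v, size_t a) { return a ? (v + a - 1) / a * a : v; }

int main() {
  std::vector<Interval> ivs = {
      {0, 4, 4096, 16, 0}, {1, 3, 2048, 16, 1}, {4, 8, 4096, 16, 2}};
  std::sort(ivs.begin(), ivs.end(),
            [](const Interval &x, const Interval &y) { return x.start < y.start; });
  std::multimap<int, std::pair<size_t, size_t>> active;  // end -> {offset, size}
  std::vector<std::pair<size_t, size_t>> freelist;       // {offset, size}
  size_t top = 0;  // arena high-water mark
  for (const Interval &iv : ivs) {
    // Retire intervals that ended at or before this program point.
    for (auto it = active.begin(); it != active.end() && it->first <= iv.start;) {
      freelist.push_back(it->second);
      it = active.erase(it);
    }
    // Pick the smallest free block that fits, otherwise bump the arena top.
    auto best = freelist.end();
    for (auto it = freelist.begin(); it != freelist.end(); ++it)
      if (AlignUp(it->first, iv.align) + iv.size <= it->first + it->second &&
          (best == freelist.end() || it->second < best->second))
        best = it;
    size_t offset;
    if (best != freelist.end()) {
      offset = AlignUp(best->first, iv.align);
      freelist.erase(best);
    } else {
      offset = AlignUp(top, iv.align);
      top = offset + iv.size;
    }
    active.emplace(iv.end, std::make_pair(offset, iv.size));
    std::printf("buffer %d -> offset %zu\n", iv.id, offset);
  }
  std::printf("arena size = %zu bytes\n", top);
  return 0;
}
```

With these inputs the third buffer lands back at offset 0 because the first two are already retired, which is exactly the reuse the merged shared-memory planner is after.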
@@ -31,6 +31,12 @@
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
+#include <algorithm>
+#include <functional>
+#include <limits>
+#include <optional>
+#include <queue>
+#include <sstream>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
@@ -38,7 +44,6 @@
 #include "../op/builtin.h"
 #include "../target/utils.h"
 #include "runtime/thread_storage_scope.h"
-#include "support/arena.h"
 #include "tir/transforms/ir_utils.h"
 #include "tvm/tir/function.h"
@@ -141,6 +146,8 @@ public:
   void VisitStmt_(const AllocateNode *op) final {
     size_t level = scope_.size();
     const VarNode *buf = op->buffer_var.get();
+    // Record the allocation site and depth so liveness can reason about the
+    // original scope.
     alloc_info_[buf].alloc = op;
     alloc_info_[buf].level = level;
     StmtExprVisitor::VisitStmt_(op);
@@ -194,9 +201,12 @@ public:
     const VarNode *buf = op->buffer->data.get();
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
-      // Allow buffer access at the same level or deeper scope
-      // Changed from < to <= to handle cases where buffer is accessed
-      // in expressions at the same scope level where it's allocated
+      // Earlier we required `alloc_level < scope_.size()`, assuming every load
+      // would occur strictly inside a nested scope. In practice the lowering
+      // pipeline may materialise reads in the very same frame that owns the
+      // allocation (e.g. when the buffer value is passed directly to a call),
+      // which used to trigger the CHECK. Treat same-level accesses as valid so
+      // the merged allocator can reason about their lifetime correctly.
       ICHECK_LE(it->second.level, scope_.size())
           << "Load memory in places other than store.";
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
@@ -204,7 +214,10 @@ public:
         if (enable_aggressive_merge) {
           scope_[scope_.size() - 1].touched.push_back(buf);
         } else {
-          // When accessing at the same level, use that level
+          // When the access happens in the same scope frame as the allocation
+          // we attribute it to that frame instead of the outer parent. This
+          // keeps the liveness window tight while still accounting for nested
+          // scopes that legitimately touch the buffer deeper in the tree.
           size_t access_level = std::min(it->second.level, scope_.size() - 1);
           scope_[access_level].touched.push_back(buf);
         }
@@ -216,14 +229,17 @@ public:
     // Directly reference to the variable count as a read.
     auto it = alloc_info_.find(buf);
     if (it != alloc_info_.end() && it->second.alloc) {
-      // Allow buffer access at the same level or deeper scope
+      // Same rationale as the BufferLoad path above: direct references can be
+      // emitted at the allocation level after flattening, so accept them and
+      // record the touch for liveness planning.
       ICHECK_LE(it->second.level, scope_.size());
       if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
         auto enable_aggressive_merge = enable_aggressive_merge_;
         if (enable_aggressive_merge) {
           scope_[scope_.size() - 1].touched.push_back(buf);
         } else {
-          // When accessing at the same level, use that level
+          // Attribute same-level uses to the allocation frame, mirroring the
+          // BufferLoad handling to keep reuse decisions consistent.
           size_t access_level = std::min(it->second.level, scope_.size() - 1);
           scope_[access_level].touched.push_back(buf);
         }
@@ -245,6 +261,8 @@ public:
     scope_.pop_back();
     int64_t end_index = static_cast<int64_t>(linear_seq_.size());
     ICHECK_GT(end_index, begin_index);
+    // The paired entries serve as scope sentinels once we flatten the
+    // control-flow tree.
     e.scope_pair_offset = begin_index - end_index;
     linear_seq_.push_back(e);
     // record the pointer to end index.
@@ -338,7 +356,11 @@ public:
 private:
   void VisitExpr_(const CallNode *op) {
     if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) ||
-        op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store())) {
+        op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) ||
+        op->op.same_as(tl::ptx_wgmma_ss()) ||
+        op->op.same_as(tl::ptx_wgmma_rs())) {
+      // These intrinsics introduce stricter SMEM alignment requirements; mark
+      // the subtree.
       under_alignment_scope_ = true;
       StmtExprVisitor::VisitExpr_(op);
       under_alignment_scope_ = false;
@@ -394,6 +416,8 @@ public:
                            enable_aggressive_merge, verbose);
     finder(stmt);
     shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt);
+    // First compute liveness over the flattened schedule, then feed it into the
+    // arena packer.
     this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
     this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
   }
@@ -403,65 +427,6 @@ private:
     if (op->attr_key == tir::attr::thread_extent && !allocated_) {
       // Allocate one dynamic shared memory allocation at the beginning of
       // thread scope
-      int max_layer_num = 0;
-      std::vector<const StorageEntry *> all_entry;
-      for (const auto &e : const_free_map_) {
-        all_entry.push_back(e.second);
-      }
-      for (const StorageEntry *e : sym_free_list_) {
-        all_entry.push_back(e);
-      }
-      // Sort the storage entries in descending order of their total allocation
-      // size (in bits). This ensures that larger allocations are placed first,
-      // which can help minimize fragmentation and improve memory packing
-      // efficiency when merging shared memory buffers.
-      std::sort(all_entry.begin(), all_entry.end(),
-                [](const StorageEntry *a, const StorageEntry *b) {
-                  return a->const_nbits > b->const_nbits;
-                });
-      for (const StorageEntry *e : all_entry) {
-        max_layer_num =
-            std::max(max_layer_num, static_cast<int>(e->allocs.size()));
-      }
-      // calculate align for each layer of each storage entry.
-      std::vector<int> align(max_layer_num, 0);
-      for (const StorageEntry *e : all_entry) {
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          for (const VarNode *buffer : e->allocs[i]) {
-            const AllocateNode *alloc = shmem_allocs_[buffer];
-            align[i] =
-                std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes());
-            align[i] = std::max(align[i], align_bytes_);
-          }
-        }
-      }
-      for (const StorageEntry *e : all_entry) {
-        PrimExpr max_inner_offset = 0;
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          PrimExpr inner_offset = 0;
-          for (const VarNode *buffer : e->allocs[i]) {
-            const AllocateNode *alloc = shmem_allocs_[buffer];
-            auto alignment = align[i];
-            // Modern nvidia architecture performs hardware swizzling (hopper
-            // wgmma/tma for example) requires dynamic shared memory address to
-            // be aligned to 1024 bytes For other devices, we align to 16 bytes
-            if (shmem_alignment_map_.find(buffer) !=
-                shmem_alignment_map_.end()) {
-              alignment = std::max(align[i], shmem_alignment_map_[buffer]);
-            }
-            PrimExpr start_offset = merged_alloc_size_ + inner_offset;
-            PrimExpr aligned_offset =
-                indexdiv(start_offset + alignment - 1, alignment) * alignment;
-            buffer_byte_offsets_[buffer] = aligned_offset;
-            inner_offset =
-                aligned_offset - merged_alloc_size_ +
-                alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes();
-          }
-          max_inner_offset = max(max_inner_offset, inner_offset);
-        }
-        merged_alloc_size_ += max_inner_offset;
-      }
       if (verbose_) {
@@ -626,18 +591,199 @@ private:
   using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
   using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;
-  struct StorageEntry {
-    // The constant size of the buffer in bits, only used if it is constant
-    uint64_t const_nbits{0};
-    // Allocs that shares this entry.
-    // The inner vector means a "layer"
-    // For example, it we need to allocate C in the memory of A and B:
-    // | A: 4096 bytes | B: 4096 bytes |
-    // | C: 8192 bytes |
-    // Then the allocs = {{A, B}, {C}}
-    std::vector<std::vector<const VarNode *>> allocs;
-  };
+  // Metadata about a single shared-memory allocation prior to merging. This
+  // is used to build lifetimes, alignment requirements, and final offsets.
+  struct BufInfo {
+    const VarNode *var{nullptr};
+    std::string name;
+    PrimExpr size_expr;
+    std::optional<int64_t> const_size_bytes; // in bytes if compile-time known.
+    int alignment{0}; // required byte alignment.
+    int start{0};     // first statement index touching the buf.
+    int end{0};       // one-past-last statement index.
+    DataType size_dtype{DataType::Int(32)};
+  };
+  // Interval describing the liveness window of a (constant-sized) allocation.
+  struct Interval {
+    int start{0};
+    int end{0};
+    size_t size_bytes{0};
+    int alignment{0};
+    const VarNode *var{nullptr};
+  };
+  // Result of a linear-scan arena packing. Offsets contain the byte offset for
+  // each constant-sized buffer, arena_size is the total constant footprint.
+  struct ArenaPlan {
+    size_t arena_size{0};
+    std::unordered_map<const VarNode *, size_t> offsets;
+  };
+  static size_t AlignUpSize(size_t value, size_t alignment) {
+    if (alignment == 0) {
+      return value;
+    }
+    size_t remainder = value % alignment;
+    if (remainder == 0) {
+      return value;
+    }
+    return value + (alignment - remainder);
+  }
+  struct FreeBlock {
+    size_t offset{0};
+    size_t size{0};
+  };
+  class FreeList {
+  public:
+    std::optional<size_t> Allocate(size_t need, size_t alignment) {
+      // Best-fit search: pick the slot that wastes the least space after
+      // alignment.
+      int best = -1;
+      size_t best_waste = std::numeric_limits<size_t>::max();
+      for (int i = 0, n = static_cast<int>(blocks_.size()); i < n; ++i) {
+        size_t aligned = AlignUpSize(blocks_[i].offset, alignment);
+        size_t head = aligned - blocks_[i].offset;
+        if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) {
+          size_t waste = blocks_[i].size - head - need;
+          if (waste < best_waste) {
+            best_waste = waste;
+            best = i;
+          }
+        }
+      }
+      if (best < 0) {
+        return std::nullopt;
+      }
+      FreeBlock blk = blocks_[best];
+      size_t aligned = AlignUpSize(blk.offset, alignment);
+      size_t head = aligned - blk.offset;
+      size_t tail = blk.size - head - need;
+      blocks_.erase(blocks_.begin() + best);
+      if (head) {
+        blocks_.push_back({blk.offset, head});
+      }
+      if (tail) {
+        blocks_.push_back({aligned + need, tail});
+      }
+      Normalize();
+      return aligned;
+    }
+    void Free(size_t offset, size_t size) {
+      if (size == 0)
+        return;
+      blocks_.push_back({offset, size});
+      Normalize();
+    }
+
+  private:
+    void Normalize() {
+      if (blocks_.empty())
+        return;
+      std::sort(blocks_.begin(), blocks_.end(),
+                [](const FreeBlock &a, const FreeBlock &b) {
+                  return a.offset < b.offset;
+                });
+      std::vector<FreeBlock> merged;
+      merged.reserve(blocks_.size());
+      for (const FreeBlock &blk : blocks_) {
+        if (merged.empty()) {
+          merged.push_back(blk);
+          continue;
+        }
+        FreeBlock &last = merged.back();
+        size_t last_end = last.offset + last.size;
+        if (blk.offset <= last_end) {
+          size_t blk_end = blk.offset + blk.size;
+          if (blk_end > last_end) {
+            last.size = blk_end - last.offset;
+          }
+        } else {
+          merged.push_back(blk);
+        }
+      }
+      blocks_ = std::move(merged);
+    }
+    std::vector<FreeBlock> blocks_;
+  };
+  struct ActiveInterval {
+    int end{0};
+    size_t offset{0};
+    size_t size{0};
+    const VarNode *var{nullptr};
+    bool operator>(const ActiveInterval &other) const {
+      return end > other.end;
+    }
+  };
+  static ArenaPlan LinearScanPack(std::vector<Interval> intervals) {
+    // Process intervals in program order so lifetimes correspond to the
+    // linearised CFG.
+    std::sort(intervals.begin(), intervals.end(),
+              [](const Interval &lhs, const Interval &rhs) {
+                if (lhs.start != rhs.start) {
+                  return lhs.start < rhs.start;
+                }
+                if (lhs.size_bytes != rhs.size_bytes) {
+                  return lhs.size_bytes > rhs.size_bytes;
+                }
+                return lhs.var < rhs.var;
+              });
+    std::priority_queue<ActiveInterval, std::vector<ActiveInterval>,
+                        std::greater<ActiveInterval>>
+        active;
+    FreeList freelist;
+    size_t arena_top = 0;
+    std::unordered_map<const VarNode *, size_t> offsets;
+    // Expire intervals that end before or at program counter `pc`.
+    auto retire = [&](int pc) {
+      while (!active.empty() && active.top().end <= pc) {
+        const ActiveInterval top = active.top();
+        active.pop();
+        freelist.Free(top.offset, top.size);
+      }
+    };
+    for (const Interval &interval : intervals) {
+      retire(interval.start);
+      size_t offset = 0;
+      // Try to recycle previously freed memory first; fall back to bumping the
+      // arena.
+      if (auto slot =
+              freelist.Allocate(interval.size_bytes, interval.alignment)) {
+        offset = slot.value();
+      } else {
+        offset = AlignUpSize(arena_top, interval.alignment);
+        arena_top = offset + interval.size_bytes;
+      }
+      active.push(ActiveInterval{interval.end, offset, interval.size_bytes,
                                  interval.var});
+      offsets[interval.var] = offset;
+    }
+    return ArenaPlan{arena_top, std::move(offsets)};
+  }
+  PrimExpr AlignPrimExpr(const PrimExpr &value, int alignment) const {
+    if (alignment <= 1) {
+      return value;
+    }
+    DataType dtype = value.dtype();
+    ICHECK(dtype.is_int() || dtype.is_uint())
+        << "Expected integer dtype for alignment, but got " << dtype;
+    PrimExpr align_expr = make_const(dtype, alignment);
+    PrimExpr adjust = make_const(dtype, alignment - 1);
+    return indexdiv(value + adjust, align_expr) * align_expr;
+  }
   // Event entry in liveness analysis
   struct EventEntry {
     // variables we generate
@@ -905,173 +1051,228 @@ private:
   void
   PlanMemory(const std::vector<StmtEntry> &seq,
              const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
-    std::unordered_set<const VarNode *> inplace_flag;
-    for (size_t i = 0; i < seq.size(); ++i) {
-      auto it = event_map_.find(seq[i].stmt);
-      // scope_pair_offset <= 0 means it is either
-      // - leaf stmt(offset = 0)
-      // - end of scope(offset < 0)
-      // In both cases, we need to handle the kill event correctly
-      auto is_leaf_alloc = [&](const VarNode *var) {
-        return seq[i].scope_pair_offset == 0 &&
-               std::find(it->second.gen.begin(), it->second.gen.end(), var) !=
-                   it->second.gen.end();
-      };
-      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
-        for (const VarNode *var : it->second.kill) {
-          if (!is_leaf_alloc(var))
-            this->Free(var);
-        }
-      }
-      // scope_pair_offset >= 0 means it is either
-      // - leaf stmt(offset = 0)
-      // - beginning of scope(offset < 0)
-      // In both cases, we need to handle the gen event correctly
-      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
-        for (const VarNode *var : it->second.gen) {
-          ICHECK(shmem_allocs_.count(var));
-          const AllocateNode *alloc = shmem_allocs_[var];
-          StorageEntry *dst_entry = FindAlloc(alloc);
-          alloc_map_[var] = dst_entry;
-        }
-      }
-      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
-        for (const VarNode *var : it->second.kill) {
-          if (is_leaf_alloc(var))
-            this->Free(var);
-        }
-      }
-    }
-  }
-  /*!
-   * \brief Allocate new storage entry.
-   * \param op the allocate node
-   * \param the size of the allocation in bits
-   * \return the new storage entry
-   */
-  StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
-    ICHECK(op != nullptr);
-    // Reuse not successful, allocate a new buffer.
-    StorageEntry *entry = arena_.make<StorageEntry>();
-    entry->allocs.push_back({op->buffer_var.get()});
-    entry->const_nbits = const_nbits;
-    return entry;
-  }
-  /*!
-   * @brief Locate or create a storage entry from free lists to satisfy an
-   * AllocateNode.
-   *
-   * Finds a reusable StorageEntry for the given AllocateNode (constant or
-   * symbolic size) using two-tiered strategies:
-   * - For constant-size allocations (>0): prefer a free entry that is >=
-   * required size; if none, coalesce smaller free constant-size entries until
-   * the sum meets the request and return a new StorageEntry representing the
-   * merged space. Very small constant allocations (<= 32 bits) are not reused
-   * and will allocate a fresh entry.
-   * - For symbolic-size (unknown at compile time): pick and remove an arbitrary
-   * entry from the symbolic free list.
-   *
-   * If no suitable free entry is found, a fresh StorageEntry is created via
-   * NewAlloc.
-   *
-   * @param op Pointer to the AllocateNode to satisfy. Must be non-null.
-   * @return StorageEntry* A storage entry that will hold the allocation (may be
-   * newly created).
-   */
-  StorageEntry *FindAlloc(const AllocateNode *op) {
-    ICHECK(op != nullptr);
-    // skip plan for local variable,
-    // compiler can do a better job with register allocation.
-    const uint64_t match_range = 16;
-    uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
-    uint64_t const_nbits =
-        static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
-    // disable reuse of small arrays, they will be lowered to registers in LLVM
-    // This rules only apply if we are using non special memory
-    if (const_nbits > 0 && const_nbits <= 32) {
-      return NewAlloc(op, const_nbits);
-    }
-    if (const_nbits != 0) {
-      // constant allocation.
-      auto begin = const_free_map_.lower_bound(0);
-      auto mid = const_free_map_.lower_bound(const_nbits);
-      auto end = const_free_map_.upper_bound(const_nbits * match_range);
-      // Start looking at the buffer that is bigger than the required size
-      // first. If we find one, directly allocate the buffer in its location and
-      // remove its entry in the free list
-      for (auto it = mid; it != end; ++it) {
-        StorageEntry *e = it->second;
-        e->const_nbits = std::max(const_nbits, e->const_nbits);
-        const_free_map_.erase(it);
-        it->second->allocs.push_back({op->buffer_var.get()});
-        return e;
-      }
-      // Then start looking at smaller buffers.
-      // Keep collecting the buffer until the sum of their size exceeds the
-      // buffer to allocate and finally free all these entry in the free list
-      std::vector<std::multimap<uint64_t, StorageEntry *>::iterator> delete_it;
-      // the alloc list for the new entry
-      std::vector<std::vector<const VarNode *>> reuse_allocs;
-      uint64_t mem_ct = 0;
-      for (auto it = mid; it != begin;) {
-        --it;
-        delete_it.push_back(it);
-        mem_ct += it->second->const_nbits;
-        int n = it->second->allocs.size();
-        if (n > static_cast<int>(reuse_allocs.size())) {
-          reuse_allocs.resize(n, {});
-        }
-        for (int i = 0; i < n; i++) {
-          for (const VarNode *alloc : it->second->allocs[i]) {
-            reuse_allocs[i].push_back(alloc);
-          }
-        }
-        if (mem_ct >= const_nbits) {
-          break;
-        }
-      }
-      reuse_allocs.push_back({op->buffer_var.get()});
-      if (mem_ct != 0) {
-        StorageEntry *e = arena_.make<StorageEntry>();
-        e->const_nbits = std::max(const_nbits, mem_ct);
-        e->allocs = reuse_allocs;
-        for (auto it : delete_it) {
-          const_free_map_.erase(it);
-        }
-        return e;
-      }
-    } else {
-      // if its symbolic allocation, just arbitrarily choose one entry to fit in
-      // because we don't know its actual size
-      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
-        StorageEntry *e = *it;
-        sym_free_list_.erase(it);
-        return e;
-      }
-    }
-    return NewAlloc(op, const_nbits);
-  }
-  /*!
-   * \brief add the storage entry to the buffer var into the free list.
-   * \param var the buffer var
-   */
-  void Free(const VarNode *var) {
-    auto it = alloc_map_.find(var);
-    ICHECK(it != alloc_map_.end());
-    StorageEntry *e = it->second;
-    ICHECK_NE(e->allocs.size(), 0U);
-    // normal free.
-    if (e->const_nbits != 0) {
-      const_free_map_.insert({e->const_nbits, e});
-    } else {
-      sym_free_list_.push_back(e);
-    }
-  }
+    buffer_byte_offsets_.clear();
+    (void)stmt_attrs;
+    if (shmem_allocs_.empty()) {
+      merged_alloc_size_ = make_const(DataType::Int(64), 0);
+      return;
+    }
+    // Discover the first and last touch for every allocation.
+    std::unordered_map<const VarNode *, int> start_index;
+    std::unordered_map<const VarNode *, int> end_index;
+    for (size_t i = 0; i < seq.size(); ++i) {
+      auto it = event_map_.find(seq[i].stmt);
+      if (it == event_map_.end())
+        continue;
+      for (const VarNode *var : it->second.gen) {
+        start_index.emplace(var, static_cast<int>(i));
+      }
+      for (const VarNode *var : it->second.kill) {
+        end_index[var] = std::max(end_index[var], static_cast<int>(i) + 1);
+      }
+    }
+    const int seq_len = static_cast<int>(seq.size());
+    for (const auto &kv : start_index) {
+      if (!end_index.count(kv.first)) {
+        end_index[kv.first] = seq_len;
+      }
+    }
+    std::vector<BufInfo> buf_infos;
+    buf_infos.reserve(shmem_allocs_.size());
+    // Build a BufInfo for all allocations that participate in liveness.
+    for (const auto &kv : shmem_allocs_) {
+      const VarNode *var = kv.first;
+      auto start_it = start_index.find(var);
+      if (start_it == start_index.end()) {
+        continue;
+      }
+      BufInfo info;
+      info.var = var;
+      info.name = var->name_hint;
+      info.start = start_it->second;
+      info.end = std::max(end_index[var], info.start + 1);
+      info.alignment = align_bytes_;
+      auto align_it = shmem_alignment_map_.find(var);
+      if (align_it != shmem_alignment_map_.end()) {
+        info.alignment = std::max(info.alignment, align_it->second);
+      }
+      const AllocateNode *alloc = kv.second;
+      int64_t bytes_per_elem =
+          static_cast<int64_t>(alloc->dtype.bytes() * alloc->dtype.lanes());
+      DataType size_dtype = DataType::Int(32);
+      if (!alloc->extents.empty()) {
+        size_dtype = alloc->extents[0].dtype();
+      }
+      if (!size_dtype.is_int() && !size_dtype.is_uint()) {
+        size_dtype = DataType::Int(32);
+      }
+      PrimExpr size_expr = make_const(size_dtype, bytes_per_elem);
+      for (const PrimExpr &extent : alloc->extents) {
+        PrimExpr e = extent;
+        if (e.dtype() != size_dtype) {
+          e = cast(size_dtype, e);
+        }
+        size_expr = size_expr * e;
+      }
+      info.size_dtype = size_dtype;
+      info.size_expr = size_expr;
+      int64_t const_extent = alloc->ConstantAllocationSize();
+      if (const_extent >= 0) {
+        info.const_size_bytes = const_extent * bytes_per_elem;
+      }
+      buf_infos.push_back(std::move(info));
+    }
+    // Stable order so the later passes have deterministic behaviour.
+    std::sort(buf_infos.begin(), buf_infos.end(),
+              [](const BufInfo &a, const BufInfo &b) {
+                if (a.start != b.start)
+                  return a.start < b.start;
+                if (a.end != b.end)
+                  return a.end < b.end;
+                return a.name < b.name;
+              });
+    std::vector<Interval> intervals;
+    intervals.reserve(buf_infos.size());
+    for (const BufInfo &info : buf_infos) {
+      if (!info.const_size_bytes.has_value())
+        continue;
+      // Only constant-sized buffers participate in the arena packing because
+      // dynamic sizes must be placed sequentially later.
+      Interval interval;
+      interval.start = info.start;
+      interval.end = info.end;
+      interval.size_bytes = static_cast<size_t>(
+          std::max<int64_t>(0, info.const_size_bytes.value()));
+      interval.alignment = info.alignment;
+      interval.var = info.var;
+      intervals.push_back(interval);
+    }
+    ArenaPlan plan = LinearScanPack(std::move(intervals));
+    size_t arena_size_const = plan.arena_size;
+    if (verbose_) {
+      LOG(DEBUG) << "ArenaPlan (constant buffers): arena_size="
+                 << arena_size_const;
+      for (const auto &kv : plan.offsets) {
+        const VarNode *var = kv.first;
+        LOG(DEBUG) << "  " << var->name_hint << " -> offset=" << kv.second;
+      }
+    }
+    // Cursor tracks the running byte offset within the merged arena.
+    DataType offset_dtype =
+        buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype;
+    PrimExpr total_size = make_const(offset_dtype, 0);
+    PrimExpr cursor = AlignPrimExpr(
+        make_const(offset_dtype, static_cast<int64_t>(arena_size_const)),
+        align_bytes_);
+    auto CastToOffset = [&](PrimExpr expr) -> PrimExpr {
+      if (expr.dtype() == offset_dtype) {
+        return expr;
+      }
+      return cast(offset_dtype, expr);
+    };
+    for (const BufInfo &info : buf_infos) {
+      PrimExpr offset_expr;
+      auto it = plan.offsets.find(info.var);
+      if (it != plan.offsets.end()) {
+        offset_expr =
+            make_const(offset_dtype, static_cast<int64_t>(it->second));
+      } else {
+        // Dynamic-sized buffers are appended after the constant arena.
+        cursor = AlignPrimExpr(cursor, info.alignment);
+        PrimExpr size_expr = CastToOffset(info.size_expr);
+        offset_expr = cursor;
+        cursor = offset_expr + size_expr;
+      }
+      buffer_byte_offsets_[info.var] = offset_expr;
+      PrimExpr buf_end = offset_expr + CastToOffset(info.size_expr);
+      total_size = max(total_size, buf_end);
+    }
+    merged_alloc_size_ = buf_infos.empty()
+                             ? make_const(offset_dtype, 0)
+                             : AlignPrimExpr(total_size, align_bytes_);
+    bool overlap_detected = false;
+    if (verbose_) {
+      LOG(DEBUG) << "Memory Allocation Plan for "
+                 << (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
+      LOG(DEBUG) << "  Total Merged Size (aligned): " << merged_alloc_size_;
+      for (const BufInfo &info : buf_infos) {
+        const PrimExpr &offset = buffer_byte_offsets_.at(info.var);
+        LOG(DEBUG) << "  Buffer: " << info.name << " start=" << info.start
+                   << " end=" << info.end << " alignment=" << info.alignment
+                   << " offset=" << offset << " size=" << info.size_expr;
+      }
+      // Sanity check for overlapping constant buffers.
+      for (size_t i = 0; i < buf_infos.size(); ++i) {
+        const BufInfo &a = buf_infos[i];
+        auto a_off_imm = buffer_byte_offsets_.at(a.var).as<IntImmNode>();
+        if (!a.const_size_bytes.has_value() || a_off_imm == nullptr)
+          continue;
+        int64_t a_off = a_off_imm->value;
+        int64_t a_end = a_off + a.const_size_bytes.value();
+        for (size_t j = i + 1; j < buf_infos.size(); ++j) {
+          const BufInfo &b = buf_infos[j];
+          auto b_off_imm = buffer_byte_offsets_.at(b.var).as<IntImmNode>();
+          if (!b.const_size_bytes.has_value() || b_off_imm == nullptr)
+            continue;
+          bool live_overlap = !(a.end <= b.start || b.end <= a.start);
+          if (!live_overlap)
+            continue;
+          int64_t b_off = b_off_imm->value;
+          int64_t b_end = b_off + b.const_size_bytes.value();
+          bool mem_overlap = !(a_end <= b_off || b_end <= a_off);
+          if (mem_overlap) {
+            overlap_detected = true;
+            LOG(WARNING) << "Buffer overlap detected between " << a.name
+                         << " and " << b.name << " (lifetime overlap with "
+                         << "offset ranges [" << a_off << ", " << a_end
+                         << ") and [" << b_off << ", " << b_end << ")).";
+          }
+        }
+      }
+    }
+    if (overlap_detected) {
+      LOG(WARNING) << "Detected overlapping constant buffers; falling back to "
+                   << "sequential allocation without reuse.";
+      buffer_byte_offsets_.clear();
+      // In the fallback path we simply lay buffers out sequentially.
+      PrimExpr new_cursor = make_const(offset_dtype, 0);
+      PrimExpr new_total = make_const(offset_dtype, 0);
+      for (const BufInfo &info : buf_infos) {
+        new_cursor = AlignPrimExpr(new_cursor, info.alignment);
+        PrimExpr size_expr = CastToOffset(info.size_expr);
+        buffer_byte_offsets_[info.var] = new_cursor;
+        PrimExpr buf_end = new_cursor + size_expr;
+        new_total = max(new_total, buf_end);
+        new_cursor = buf_end;
+      }
+      merged_alloc_size_ = buf_infos.empty()
+                               ? make_const(offset_dtype, 0)
+                               : AlignPrimExpr(new_total, align_bytes_);
+    }
+  }
   // Whether enable dynamic analysis.
   bool is_dynamic_{true};
@@ -1095,14 +1296,6 @@ private:
   bool allocated_{false};
   // Locations of free ops.
   std::unordered_map<const Object *, EventEntry> event_map_;
-  // constant size free map.
-  std::multimap<uint64_t, StorageEntry *> const_free_map_;
-  // symbolic free list, for non constant items.
-  std::list<StorageEntry *> sym_free_list_;
-  // The allocation assign map
-  std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
-  /*! \brief allocator of all the StorageEntry*/
-  support::Arena arena_;
   // The mapping of buffer bytes alignment
   std::unordered_map<const VarNode *, int> shmem_alignment_map_;
 };
......
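A quick aside on the alignment math above: `AlignUpSize` and `AlignPrimExpr` both round offsets up with the usual `((x + a - 1) / a) * a` identity. The numbers below are only illustrative; 1024 bytes is the alignment the Hopper TMA/WGMMA path requests through the alignment map, and 16 bytes is the ordinary default.

```cpp
#include <cassert>
#include <cstddef>

// Round x up to the next multiple of a (a > 0); the same identity that
// AlignUpSize applies to size_t and AlignPrimExpr expresses with indexdiv.
static size_t AlignUp(size_t x, size_t a) { return (x + a - 1) / a * a; }

int main() {
  assert(AlignUp(1040, 1024) == 2048);  // bumped to the next 1024-byte slot
  assert(AlignUp(1024, 1024) == 1024);  // already-aligned offsets are unchanged
  assert(AlignUp(36, 16) == 48);        // default 16-byte alignment
  return 0;
}
```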
@@ -29,6 +29,7 @@
 #include <string>
 #include <utility>
+#include "../op/builtin.h"
 #include "tir/transforms/ir_utils.h"
 namespace tvm {
@@ -301,6 +302,24 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
 }
 void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
+  // Mark async TMA load context so that tvm_access_ptr within the call
+  // can be tagged accordingly.
+  auto is_tma_load = [&]() {
+    if (auto opt = op->op.as<Op>()) {
+      const Op &call_op = opt.value();
+      return call_op.same_as(tl::tma_load()) ||
+             call_op.same_as(tl::tma_load_im2col());
+    }
+    return false;
+  }();
+  if (is_tma_load) {
+    tma_depth_++;
+    for (const auto &a : op->args) {
+      this->VisitExpr(a);
+    }
+    tma_depth_--;
+    return;
+  }
   if (op->op.same_as(builtin::address_of())) {
     ICHECK_EQ(op->args.size(), 1U);
     if (auto load = op->args[0].as<BufferLoadNode>()) {
@@ -395,10 +414,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
       e.scope = scope;
       if (flag->value & 1) {
         e.type = kRead;
+        e.is_async_copy = (tma_depth_ > 0);
        curr_stmt_.access.emplace_back(e);
       }
       if (flag->value & 2) {
         e.type = kWrite;
+        e.is_async_copy = (tma_depth_ > 0);
         curr_stmt_.access.emplace_back(e);
       }
     }
......
@@ -83,6 +83,10 @@ public:
     bool double_buffer_write = false;
     /*! \brief Whether the access is pointer access */
     bool is_pointer_access = false;
+    /*! \brief Whether this access originates from an async copy context
+     *  (e.g., inside a TMA load) and therefore multiple writes
+     *  among themselves should not force barriers between them. */
+    bool is_async_copy = false;
   };
   /*! \brief Access pattern about a single statement */
@@ -159,6 +163,8 @@ private:
   bool allow_append_{false};
   // Whether we are in device environment
   bool in_device_env_{false};
+  // Nesting depth of tma_load/tma_load_im2col calls
+  int tma_depth_{0};
   // Whether we are inside condition.
   int condition_counter_{0};
   // The current double buffer write scope.
......
@@ -86,6 +86,7 @@ protected:
       // check if sync before statement is needed.
       bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
      // Apply the syncs added already.
+
      if (sync_before_stmt) {
        reads.clear();
        writes.clear();
@@ -98,7 +99,8 @@
            break;
          }
        } else if (acc.type == kWrite) {
-          if (FindConflict(reads, acc, false)) {
+          if (FindConflict(reads, acc, false) ||
+              FindConflict(writes, acc, false)) {
            sync_before_stmt = true;
            break;
          }
@@ -123,27 +125,51 @@ protected:
          writes.clear();
        }
      }
      if (sync_before_stmt) {
        insert_syncs(s.stmt);
      }
    }
    if (loop != nullptr) {
+      // Check if the loop body contains any reads in the same sync scope.
+      // If there are reads, we conservatively keep the sync within the loop
+      // body to preserve per-iteration ordering when needed. If there are no
+      // reads (e.g., only writes to shared.dyn), we can safely hoist the sync
+      // to before the loop to avoid redundant barriers.
+      bool has_read_in_scope = false;
+      for (const StmtEntry &s : seq) {
+        for (const AccessEntry &acc : s.access) {
+          if (acc.type == kRead && acc.scope == sync_scope_) {
+            has_read_in_scope = true;
+            break;
+          }
+        }
+        if (has_read_in_scope)
+          break;
+      }
+      // If there is a loop-carried dependency, insert a single sync
+      // before the loop rather than hoisting a sync into the loop body.
+      // This reduces redundant per-iteration synchronizations for cases
+      // where each iteration touches disjoint regions (e.g., stmatrix
+      // writes to shared.dyn) and only a global ordering before/after the
+      // loop is required.
      for (size_t i = 0; i < seq.size(); ++i) {
        const StmtEntry &s = seq[i];
        if (syncs_inserted_.count(s.stmt) != 0)
          break;
        if (reads.empty() && writes.empty())
          break;
-        bool sync_before_stmt = false;
+        bool need_loop_sync = false;
        for (const AccessEntry &acc : s.access) {
          if (acc.type == kRead) {
            if (FindConflict(writes, acc, true)) {
-              sync_before_stmt = true;
+              need_loop_sync = true;
              break;
            }
          } else if (acc.type == kWrite) {
-            if (FindConflict(reads, acc, true)) {
-              sync_before_stmt = true;
+            if (FindConflict(reads, acc, true) ||
+                FindConflict(writes, acc, true)) {
+              need_loop_sync = true;
              break;
            }
          } else if (acc.type == kSync) {
@@ -151,8 +177,17 @@ protected:
            writes.clear();
          }
        }
-        if (sync_before_stmt) {
-          insert_syncs(s.stmt);
+        if (need_loop_sync) {
+          if (!has_read_in_scope) {
+            // Mark the loop itself to receive a sync before it, instead of
+            // inserting inside the loop body. This ensures a single sync is
+            // emitted outside the loop and avoids per-iteration overhead.
+            insert_syncs(loop);
+          } else {
+            // Fall back to inserting before the first conflicting statement
+            // inside the loop to maintain correctness when reads are present.
+            insert_syncs(s.stmt);
+          }
          break;
        }
      }
@@ -217,6 +252,14 @@ private:
  bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
                    bool loop_carry) {
+    // Special case: ignore conflicts between async-copy writes (e.g., TMA
+    // loads into shared memory). Multiple async writes do not require
+    // interspersed barriers among themselves. We still respect conflicts with
+    // reads to ensure visibility before consumption.
+    if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
+        curr.is_async_copy) {
+      return false;
+    }
    // Access to different buffers does not conflict.
    if (!prev.buffer.same_as(curr.buffer)) {
      return false;
@@ -241,10 +284,15 @@ private:
      return true;
    }
    if (prev.is_pointer_access || curr.is_pointer_access) {
-      // If either access is a pointer access, conservatively assume a
-      // conflict. For example, address_of(A[0, 0]) may refer to an unknown
-      // memory region, so we cannot safely determine if it overlaps with
-      // previous accesses.
+      // For accesses created via tvm_access_ptr we may still be able to prove
+      // disjointness using their byte ranges. If both sides expose a touched
+      // interval and we can show they don't overlap, skip the conflict.
+      if (prev.is_pointer_access && curr.is_pointer_access &&
+          PointerAccessIsDisjoint(prev, curr)) {
+        return false;
+      }
+      // Otherwise fall back to the conservative answer: treat them as
+      // overlapping.
      return true;
    }
@@ -327,7 +375,7 @@ private:
        }
      }
-      if (!(has_same_index)) {
+      if (!has_same_index) {
        break;
      }
    }
@@ -350,6 +398,26 @@ private:
    return range_is_overlap;
  }
+  bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
+    if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
+      return false;
+    }
+    PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
+    PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
+    PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
+    PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
+    if (analyzer_.CanProve(lhs_max < rhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    if (analyzer_.CanProve(rhs_max < lhs_min,
+                           arith::ProofStrength::kSymbolicBound)) {
+      return true;
+    }
+    return false;
+  }
  void VisitStmt_(const AttrStmtNode *op) final {
    if (op->attr_key == tvm::tir::attr::thread_extent) {
      IterVar iv = Downcast<IterVar>(op->node);
......
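Summing up the sync-planner half of the change: a write is now checked against earlier writes as well as earlier reads (WAW detection), except that two writes both tagged as async copies (issued inside a TMA load) never force a barrier between themselves. The snippet below is a hedged, self-contained sketch of that decision rule using simplified stand-ins for `AccessEntry`; the real `FindConflict` additionally compares buffers, thread indices, and the touched index ranges.

```cpp
#include <cstdio>

// Simplified stand-ins for the pass's AccessEntry and access kinds.
enum AccessType { kRead, kWrite };
struct Access {
  AccessType type;
  bool is_async_copy;  // set when the access was emitted inside a TMA load
};

// Returns true when a barrier is needed between prev and curr
// (same buffer and overlapping ranges are assumed here).
static bool NeedsSync(const Access &prev, const Access &curr) {
  // Two async-copy writes never require a barrier between themselves.
  if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
      curr.is_async_copy)
    return false;
  // RAW, WAR and (with this commit) WAW all require one.
  return prev.type == kWrite || curr.type == kWrite;
}

int main() {
  Access tma_a{kWrite, true}, tma_b{kWrite, true};
  Access plain_w{kWrite, false};
  Access plain_r{kRead, false};
  std::printf("%d %d %d\n",
              NeedsSync(tma_a, tma_b),      // 0: async WAW is exempt
              NeedsSync(plain_w, plain_w),  // 1: ordinary WAW now syncs
              NeedsSync(plain_w, plain_r)); // 1: RAW still syncs
  return 0;
}
```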