Unverified Commit f7ba45d8 authored by Lei Wang, committed by GitHub

[Bugfix] Implement classic arena algorithm for shmem merge and WAW conflict detection (#1146)

* atomic_fix

* atomic_fix

* mem fix

* lint fix

* add some comments

* fix

* fix

* lint fix

* handle async copy

* lint fix
parent c70b2697
......@@ -31,6 +31,12 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <functional>
#include <limits>
#include <optional>
#include <queue>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>
......@@ -38,7 +44,6 @@
#include "../op/builtin.h"
#include "../target/utils.h"
#include "runtime/thread_storage_scope.h"
#include "support/arena.h"
#include "tir/transforms/ir_utils.h"
#include "tvm/tir/function.h"
......@@ -141,6 +146,8 @@ public:
void VisitStmt_(const AllocateNode *op) final {
size_t level = scope_.size();
const VarNode *buf = op->buffer_var.get();
// Record the allocation site and depth so liveness can reason about the
// original scope.
alloc_info_[buf].alloc = op;
alloc_info_[buf].level = level;
StmtExprVisitor::VisitStmt_(op);
......@@ -194,9 +201,12 @@ public:
const VarNode *buf = op->buffer->data.get();
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
// Allow buffer access at the same level or deeper scope
// Changed from < to <= to handle cases where buffer is accessed
// in expressions at the same scope level where it's allocated
// Earlier we required `alloc_level < scope_.size()`, assuming every load
// would occur strictly inside a nested scope. In practice the lowering
// pipeline may materialise reads in the very same frame that owns the
// allocation (e.g. when the buffer value is passed directly to a call),
// which used to trigger the CHECK. Treat same-level accesses as valid so
// the merged allocator can reason about their lifetime correctly.
ICHECK_LE(it->second.level, scope_.size())
<< "Load memory in places other than store.";
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
......@@ -204,7 +214,10 @@ public:
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
// When accessing at the same level, use that level
// When the access happens in the same scope frame as the allocation
// we attribute it to that frame instead of the outer parent. This
// keeps the liveness window tight while still accounting for nested
// scopes that legitimately touch the buffer deeper in the tree.
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
......@@ -216,14 +229,17 @@ public:
// A direct reference to the variable counts as a read.
auto it = alloc_info_.find(buf);
if (it != alloc_info_.end() && it->second.alloc) {
// Allow buffer access at the same level or deeper scope
// Same rationale as the BufferLoad path above: direct references can be
// emitted at the allocation level after flattening, so accept them and
// record the touch for liveness planning.
ICHECK_LE(it->second.level, scope_.size());
if (IsAppropriateSharedMemory(GetRef<Var>(buf))) {
auto enable_aggressive_merge = enable_aggressive_merge_;
if (enable_aggressive_merge) {
scope_[scope_.size() - 1].touched.push_back(buf);
} else {
// When accessing at the same level, use that level
// Attribute same-level uses to the allocation frame, mirroring the
// BufferLoad handling to keep reuse decisions consistent.
size_t access_level = std::min(it->second.level, scope_.size() - 1);
scope_[access_level].touched.push_back(buf);
}
......@@ -245,6 +261,8 @@ public:
scope_.pop_back();
int64_t end_index = static_cast<int64_t>(linear_seq_.size());
ICHECK_GT(end_index, begin_index);
// The paired entries serve as scope sentinels once we flatten the
// control-flow tree.
e.scope_pair_offset = begin_index - end_index;
linear_seq_.push_back(e);
// record the pointer to end index.
......@@ -338,7 +356,11 @@ public:
private:
void VisitExpr_(const CallNode *op) {
if (op->op.same_as(tl::tl_gemm()) || op->op.same_as(tl::tl_gemm_sp()) ||
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store())) {
op->op.same_as(tl::tma_load()) || op->op.same_as(tl::tma_store()) ||
op->op.same_as(tl::ptx_wgmma_ss()) ||
op->op.same_as(tl::ptx_wgmma_rs())) {
// These intrinsics introduce stricter SMEM alignment requirements; mark
// the subtree.
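// (E.g. Hopper wgmma/TMA hardware swizzling expects the dynamic shared
// memory base address to be aligned to 1024 bytes instead of the default
// 16 bytes; the packing step below picks up this requirement via
// shmem_alignment_map_.)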
under_alignment_scope_ = true;
StmtExprVisitor::VisitExpr_(op);
under_alignment_scope_ = false;
......@@ -394,6 +416,8 @@ public:
enable_aggressive_merge, verbose);
finder(stmt);
shmem_alignment_map_ = SharedMemoryAlignmentPlanner::Plan(stmt);
// First compute liveness over the flattened schedule, then feed it into the
// arena packer.
this->LivenessAnalysis(finder.linear_seq_, finder.stmt_attrs_);
this->PlanMemory(finder.linear_seq_, finder.stmt_attrs_);
}
......@@ -403,65 +427,6 @@ private:
if (op->attr_key == tir::attr::thread_extent && !allocated_) {
// Allocate one dynamic shared memory allocation at the beginning of
// thread scope
int max_layer_num = 0;
std::vector<const StorageEntry *> all_entry;
for (const auto &e : const_free_map_) {
all_entry.push_back(e.second);
}
for (const StorageEntry *e : sym_free_list_) {
all_entry.push_back(e);
}
// Sort the storage entries in descending order of their total allocation
// size (in bits). This ensures that larger allocations are placed first,
// which can help minimize fragmentation and improve memory packing
// efficiency when merging shared memory buffers.
std::sort(all_entry.begin(), all_entry.end(),
[](const StorageEntry *a, const StorageEntry *b) {
return a->const_nbits > b->const_nbits;
});
for (const StorageEntry *e : all_entry) {
max_layer_num =
std::max(max_layer_num, static_cast<int>(e->allocs.size()));
}
// calculate align for each layer of each storage entry.
std::vector<int> align(max_layer_num, 0);
for (const StorageEntry *e : all_entry) {
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
for (const VarNode *buffer : e->allocs[i]) {
const AllocateNode *alloc = shmem_allocs_[buffer];
align[i] =
std::max(align[i], alloc->dtype.bytes() * alloc->dtype.lanes());
align[i] = std::max(align[i], align_bytes_);
}
}
}
for (const StorageEntry *e : all_entry) {
PrimExpr max_inner_offset = 0;
for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
PrimExpr inner_offset = 0;
for (const VarNode *buffer : e->allocs[i]) {
const AllocateNode *alloc = shmem_allocs_[buffer];
auto alignment = align[i];
// Modern NVIDIA architectures that perform hardware swizzling (e.g.
// Hopper wgmma/TMA) require the dynamic shared memory address to be
// aligned to 1024 bytes. For other devices, we align to 16 bytes.
if (shmem_alignment_map_.find(buffer) !=
shmem_alignment_map_.end()) {
alignment = std::max(align[i], shmem_alignment_map_[buffer]);
}
PrimExpr start_offset = merged_alloc_size_ + inner_offset;
PrimExpr aligned_offset =
indexdiv(start_offset + alignment - 1, alignment) * alignment;
buffer_byte_offsets_[buffer] = aligned_offset;
inner_offset =
aligned_offset - merged_alloc_size_ +
alloc->extents[0] * alloc->dtype.bytes() * alloc->dtype.lanes();
}
max_inner_offset = max(max_inner_offset, inner_offset);
}
merged_alloc_size_ += max_inner_offset;
}
if (verbose_) {
......@@ -626,18 +591,199 @@ private:
using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
using StmtAttr = SharedMemLinearAccessPatternFinder::StmtAttr;
struct StorageEntry {
// The constant size of the buffer in bits, only used if it is constant
uint64_t const_nbits{0};
// Allocs that shares this entry.
// The inner vector means a "layer"
// For example, if we need to allocate C in the memory of A and B:
// | A: 4096 bytes | B: 4096 bytes |
// | C: 8192 bytes |
// Then the allocs = {{A, B}, {C}}
std::vector<std::vector<const VarNode *>> allocs;
// Metadata about a single shared-memory allocation prior to merging. This
// is used to build lifetimes, alignment requirements, and final offsets.
struct BufInfo {
const VarNode *var{nullptr};
std::string name;
PrimExpr size_expr;
std::optional<int64_t> const_size_bytes; // in bytes if compile-time known.
int alignment{0}; // required byte alignment.
int start{0}; // first statement index touching the buf.
int end{0}; // one-past-last statement index.
DataType size_dtype{DataType::Int(32)};
};
// Interval describing the liveness window of a (constant-sized) allocation.
struct Interval {
int start{0};
int end{0};
size_t size_bytes{0};
int alignment{0};
const VarNode *var{nullptr};
};
// Result of a linear-scan arena packing. Offsets contain the byte offset for
// each constant-sized buffer, arena_size is the total constant footprint.
struct ArenaPlan {
size_t arena_size{0};
std::unordered_map<const VarNode *, size_t> offsets;
};
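// Overall flow of the planner: every shared allocation becomes a BufInfo
// with a liveness window; constant-sized BufInfos are converted to
// Intervals and packed by LinearScanPack into an ArenaPlan, while
// dynamic-sized buffers are laid out sequentially after the constant arena.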
static size_t AlignUpSize(size_t value, size_t alignment) {
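// Example: AlignUpSize(100, 16) == 112, AlignUpSize(96, 16) == 96, and an
// alignment of 0 returns the value unchanged.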
if (alignment == 0) {
return value;
}
size_t remainder = value % alignment;
if (remainder == 0) {
return value;
}
return value + (alignment - remainder);
}
struct FreeBlock {
size_t offset{0};
size_t size{0};
};
class FreeList {
public:
std::optional<size_t> Allocate(size_t need, size_t alignment) {
// Best-fit search: pick the slot that wastes the least space after
// alignment.
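// Example: with free blocks {offset=0,size=256} and {offset=512,size=64},
// Allocate(64, /*alignment=*/128) picks the second block (zero waste),
// returns 512, and leaves {0,256} on the free list.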
int best = -1;
size_t best_waste = std::numeric_limits<size_t>::max();
for (int i = 0, n = static_cast<int>(blocks_.size()); i < n; ++i) {
size_t aligned = AlignUpSize(blocks_[i].offset, alignment);
size_t head = aligned - blocks_[i].offset;
if (head <= blocks_[i].size && (blocks_[i].size - head) >= need) {
size_t waste = blocks_[i].size - head - need;
if (waste < best_waste) {
best_waste = waste;
best = i;
}
}
}
if (best < 0) {
return std::nullopt;
}
FreeBlock blk = blocks_[best];
size_t aligned = AlignUpSize(blk.offset, alignment);
size_t head = aligned - blk.offset;
size_t tail = blk.size - head - need;
blocks_.erase(blocks_.begin() + best);
if (head) {
blocks_.push_back({blk.offset, head});
}
if (tail) {
blocks_.push_back({aligned + need, tail});
}
Normalize();
return aligned;
}
void Free(size_t offset, size_t size) {
if (size == 0)
return;
blocks_.push_back({offset, size});
Normalize();
}
private:
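// Coalesce the free list: sort blocks by offset and merge any blocks that
// touch or overlap, e.g. {0,64} and {64,32} become a single {0,96} block.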
void Normalize() {
if (blocks_.empty())
return;
std::sort(blocks_.begin(), blocks_.end(),
[](const FreeBlock &a, const FreeBlock &b) {
return a.offset < b.offset;
});
std::vector<FreeBlock> merged;
merged.reserve(blocks_.size());
for (const FreeBlock &blk : blocks_) {
if (merged.empty()) {
merged.push_back(blk);
continue;
}
FreeBlock &last = merged.back();
size_t last_end = last.offset + last.size;
if (blk.offset <= last_end) {
size_t blk_end = blk.offset + blk.size;
if (blk_end > last_end) {
last.size = blk_end - last.offset;
}
} else {
merged.push_back(blk);
}
}
blocks_ = std::move(merged);
}
std::vector<FreeBlock> blocks_;
};
struct ActiveInterval {
int end{0};
size_t offset{0};
size_t size{0};
const VarNode *var{nullptr};
bool operator>(const ActiveInterval &other) const {
return end > other.end;
}
};
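// Example: intervals A=[0,4) of 8192 bytes, B=[1,3) of 4096 bytes and
// C=[4,6) of 8192 bytes, all 16-byte aligned. A is placed at offset 0 and B
// bumps the arena to offset 8192; by the time C starts both have retired,
// so C reuses offset 0 and the final arena size stays at 12288 bytes.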
static ArenaPlan LinearScanPack(std::vector<Interval> intervals) {
// Process intervals in program order so lifetimes correspond to the
// linearised CFG.
std::sort(intervals.begin(), intervals.end(),
[](const Interval &lhs, const Interval &rhs) {
if (lhs.start != rhs.start) {
return lhs.start < rhs.start;
}
if (lhs.size_bytes != rhs.size_bytes) {
return lhs.size_bytes > rhs.size_bytes;
}
return lhs.var < rhs.var;
});
std::priority_queue<ActiveInterval, std::vector<ActiveInterval>,
std::greater<ActiveInterval>>
active;
FreeList freelist;
size_t arena_top = 0;
std::unordered_map<const VarNode *, size_t> offsets;
// Expire intervals that end before or at program counter `pc`.
auto retire = [&](int pc) {
while (!active.empty() && active.top().end <= pc) {
const ActiveInterval top = active.top();
active.pop();
freelist.Free(top.offset, top.size);
}
};
for (const Interval &interval : intervals) {
retire(interval.start);
size_t offset = 0;
// Try to recycle previously freed memory first; fall back to bumping the
// arena.
if (auto slot =
freelist.Allocate(interval.size_bytes, interval.alignment)) {
offset = slot.value();
} else {
offset = AlignUpSize(arena_top, interval.alignment);
arena_top = offset + interval.size_bytes;
}
active.push(ActiveInterval{interval.end, offset, interval.size_bytes,
interval.var});
offsets[interval.var] = offset;
}
return ArenaPlan{arena_top, std::move(offsets)};
}
PrimExpr AlignPrimExpr(const PrimExpr &value, int alignment) const {
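// Symbolic counterpart of AlignUpSize, e.g. AlignPrimExpr(n, 16) yields
// indexdiv(n + 15, 16) * 16.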
if (alignment <= 1) {
return value;
}
DataType dtype = value.dtype();
ICHECK(dtype.is_int() || dtype.is_uint())
<< "Expected integer dtype for alignment, but got " << dtype;
PrimExpr align_expr = make_const(dtype, alignment);
PrimExpr adjust = make_const(dtype, alignment - 1);
return indexdiv(value + adjust, align_expr) * align_expr;
}
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
......@@ -905,173 +1051,228 @@ private:
void
PlanMemory(const std::vector<StmtEntry> &seq,
const std::unordered_map<const Object *, StmtAttr> &stmt_attrs) {
std::unordered_set<const VarNode *> inplace_flag;
buffer_byte_offsets_.clear();
(void)stmt_attrs;
if (shmem_allocs_.empty()) {
merged_alloc_size_ = make_const(DataType::Int(64), 0);
return;
}
// Discover the first and last touch for every allocation.
std::unordered_map<const VarNode *, int> start_index;
std::unordered_map<const VarNode *, int> end_index;
for (size_t i = 0; i < seq.size(); ++i) {
auto it = event_map_.find(seq[i].stmt);
// scope_pair_offset <= 0 means it is either
// - leaf stmt(offset = 0)
// - end of scope(offset < 0)
// In both cases, we need to handle the kill event correctly
auto is_leaf_alloc = [&](const VarNode *var) {
return seq[i].scope_pair_offset == 0 &&
std::find(it->second.gen.begin(), it->second.gen.end(), var) !=
it->second.gen.end();
};
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
if (it == event_map_.end())
continue;
for (const VarNode *var : it->second.gen) {
start_index.emplace(var, static_cast<int>(i));
}
for (const VarNode *var : it->second.kill) {
if (!is_leaf_alloc(var))
this->Free(var);
end_index[var] = std::max(end_index[var], static_cast<int>(i) + 1);
}
}
// scope_pair_offset >= 0 means it is either
// - leaf stmt(offset = 0)
// - beginning of scope(offset > 0)
// In both cases, we need to handle the gen event correctly
if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
for (const VarNode *var : it->second.gen) {
ICHECK(shmem_allocs_.count(var));
const AllocateNode *alloc = shmem_allocs_[var];
StorageEntry *dst_entry = FindAlloc(alloc);
alloc_map_[var] = dst_entry;
const int seq_len = static_cast<int>(seq.size());
for (const auto &kv : start_index) {
if (!end_index.count(kv.first)) {
end_index[kv.first] = seq_len;
}
}
if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
for (const VarNode *var : it->second.kill) {
if (is_leaf_alloc(var))
this->Free(var);
std::vector<BufInfo> buf_infos;
buf_infos.reserve(shmem_allocs_.size());
// Build a BufInfo for all allocations that participate in liveness.
for (const auto &kv : shmem_allocs_) {
const VarNode *var = kv.first;
auto start_it = start_index.find(var);
if (start_it == start_index.end()) {
continue;
}
BufInfo info;
info.var = var;
info.name = var->name_hint;
info.start = start_it->second;
info.end = std::max(end_index[var], info.start + 1);
info.alignment = align_bytes_;
auto align_it = shmem_alignment_map_.find(var);
if (align_it != shmem_alignment_map_.end()) {
info.alignment = std::max(info.alignment, align_it->second);
}
const AllocateNode *alloc = kv.second;
int64_t bytes_per_elem =
static_cast<int64_t>(alloc->dtype.bytes() * alloc->dtype.lanes());
DataType size_dtype = DataType::Int(32);
if (!alloc->extents.empty()) {
size_dtype = alloc->extents[0].dtype();
}
if (!size_dtype.is_int() && !size_dtype.is_uint()) {
size_dtype = DataType::Int(32);
}
/*!
* \brief Allocate new storage entry.
* \param op the allocate node
* \param const_nbits the size of the allocation in bits
* \return the new storage entry
*/
StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
ICHECK(op != nullptr);
// Reuse not successful, allocate a new buffer.
StorageEntry *entry = arena_.make<StorageEntry>();
entry->allocs.push_back({op->buffer_var.get()});
entry->const_nbits = const_nbits;
return entry;
PrimExpr size_expr = make_const(size_dtype, bytes_per_elem);
for (const PrimExpr &extent : alloc->extents) {
PrimExpr e = extent;
if (e.dtype() != size_dtype) {
e = cast(size_dtype, e);
}
/*!
* @brief Locate or create a storage entry from free lists to satisfy an
* AllocateNode.
*
* Finds a reusable StorageEntry for the given AllocateNode (constant or
* symbolic size) using two-tiered strategies:
* - For constant-size allocations (>0): prefer a free entry that is >=
* required size; if none, coalesce smaller free constant-size entries until
* the sum meets the request and return a new StorageEntry representing the
* merged space. Very small constant allocations (<= 32 bits) are not reused
* and will allocate a fresh entry.
* - For symbolic-size (unknown at compile time): pick and remove an arbitrary
* entry from the symbolic free list.
*
* If no suitable free entry is found, a fresh StorageEntry is created via
* NewAlloc.
*
* @param op Pointer to the AllocateNode to satisfy. Must be non-null.
* @return StorageEntry* A storage entry that will hold the allocation (may be
* newly created).
*/
StorageEntry *FindAlloc(const AllocateNode *op) {
ICHECK(op != nullptr);
// Skip planning for local variables; the compiler can do a better job
// with register allocation.
const uint64_t match_range = 16;
uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
uint64_t const_nbits =
static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
// disable reuse of small arrays, they will be lowered to registers in LLVM
// This rules only apply if we are using non special memory
if (const_nbits > 0 && const_nbits <= 32) {
return NewAlloc(op, const_nbits);
}
if (const_nbits != 0) {
// constant allocation.
auto begin = const_free_map_.lower_bound(0);
auto mid = const_free_map_.lower_bound(const_nbits);
auto end = const_free_map_.upper_bound(const_nbits * match_range);
// Start looking at the buffer that is bigger than the required size
// first. If we find one, directly allocate the buffer in its location and
// remove its entry in the free list
for (auto it = mid; it != end; ++it) {
StorageEntry *e = it->second;
e->const_nbits = std::max(const_nbits, e->const_nbits);
const_free_map_.erase(it);
it->second->allocs.push_back({op->buffer_var.get()});
return e;
}
// Then start looking at smaller buffers.
// Keep collecting buffers until the sum of their sizes covers the requested
// size, and finally erase all of these entries from the free list.
std::vector<std::multimap<uint64_t, StorageEntry *>::iterator> delete_it;
// the alloc list for the new entry
std::vector<std::vector<const VarNode *>> reuse_allocs;
uint64_t mem_ct = 0;
for (auto it = mid; it != begin;) {
--it;
delete_it.push_back(it);
mem_ct += it->second->const_nbits;
int n = it->second->allocs.size();
if (n > static_cast<int>(reuse_allocs.size())) {
reuse_allocs.resize(n, {});
}
for (int i = 0; i < n; i++) {
for (const VarNode *alloc : it->second->allocs[i]) {
reuse_allocs[i].push_back(alloc);
}
}
if (mem_ct >= const_nbits) {
break;
size_expr = size_expr * e;
}
info.size_dtype = size_dtype;
info.size_expr = size_expr;
int64_t const_extent = alloc->ConstantAllocationSize();
if (const_extent >= 0) {
info.const_size_bytes = const_extent * bytes_per_elem;
}
reuse_allocs.push_back({op->buffer_var.get()});
if (mem_ct != 0) {
StorageEntry *e = arena_.make<StorageEntry>();
e->const_nbits = std::max(const_nbits, mem_ct);
e->allocs = reuse_allocs;
for (auto it : delete_it) {
const_free_map_.erase(it);
buf_infos.push_back(std::move(info));
}
return e;
// Stable order so the later passes have deterministic behaviour.
std::sort(buf_infos.begin(), buf_infos.end(),
[](const BufInfo &a, const BufInfo &b) {
if (a.start != b.start)
return a.start < b.start;
if (a.end != b.end)
return a.end < b.end;
return a.name < b.name;
});
std::vector<Interval> intervals;
intervals.reserve(buf_infos.size());
for (const BufInfo &info : buf_infos) {
if (!info.const_size_bytes.has_value())
continue;
// Only constant-sized buffers participate in the arena packing because
// dynamic sizes must be placed sequentially later.
Interval interval;
interval.start = info.start;
interval.end = info.end;
interval.size_bytes = static_cast<size_t>(
std::max<int64_t>(0, info.const_size_bytes.value()));
interval.alignment = info.alignment;
interval.var = info.var;
intervals.push_back(interval);
}
} else {
// If it's a symbolic allocation, arbitrarily choose one entry to fit it in
// because we don't know its actual size.
for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) {
StorageEntry *e = *it;
sym_free_list_.erase(it);
return e;
ArenaPlan plan = LinearScanPack(std::move(intervals));
size_t arena_size_const = plan.arena_size;
if (verbose_) {
LOG(DEBUG) << "ArenaPlan (constant buffers): arena_size="
<< arena_size_const;
for (const auto &kv : plan.offsets) {
const VarNode *var = kv.first;
LOG(DEBUG) << " " << var->name_hint << " -> offset=" << kv.second;
}
}
return NewAlloc(op, const_nbits);
// Cursor tracks the running byte offset within the merged arena.
DataType offset_dtype =
buf_infos.empty() ? DataType::Int(32) : buf_infos.front().size_dtype;
PrimExpr total_size = make_const(offset_dtype, 0);
PrimExpr cursor = AlignPrimExpr(
make_const(offset_dtype, static_cast<int64_t>(arena_size_const)),
align_bytes_);
auto CastToOffset = [&](PrimExpr expr) -> PrimExpr {
if (expr.dtype() == offset_dtype) {
return expr;
}
return cast(offset_dtype, expr);
};
/*!
* \brief add the storage entry to the buffer var into the free list.
* \param var the buffer var
*/
void Free(const VarNode *var) {
auto it = alloc_map_.find(var);
ICHECK(it != alloc_map_.end());
StorageEntry *e = it->second;
ICHECK_NE(e->allocs.size(), 0U);
// normal free.
if (e->const_nbits != 0) {
const_free_map_.insert({e->const_nbits, e});
for (const BufInfo &info : buf_infos) {
PrimExpr offset_expr;
auto it = plan.offsets.find(info.var);
if (it != plan.offsets.end()) {
offset_expr =
make_const(offset_dtype, static_cast<int64_t>(it->second));
} else {
sym_free_list_.push_back(e);
// Dynamic-sized buffers are appended after the constant arena.
cursor = AlignPrimExpr(cursor, info.alignment);
PrimExpr size_expr = CastToOffset(info.size_expr);
offset_expr = cursor;
cursor = offset_expr + size_expr;
}
buffer_byte_offsets_[info.var] = offset_expr;
PrimExpr buf_end = offset_expr + CastToOffset(info.size_expr);
total_size = max(total_size, buf_end);
}
merged_alloc_size_ = buf_infos.empty()
? make_const(offset_dtype, 0)
: AlignPrimExpr(total_size, align_bytes_);
bool overlap_detected = false;
if (verbose_) {
LOG(DEBUG) << "Memory Allocation Plan for "
<< (is_dynamic_ ? "Dynamic" : "Static") << " Shared Memory:";
LOG(DEBUG) << " Total Merged Size (aligned): " << merged_alloc_size_;
for (const BufInfo &info : buf_infos) {
const PrimExpr &offset = buffer_byte_offsets_.at(info.var);
LOG(DEBUG) << " Buffer: " << info.name << " start=" << info.start
<< " end=" << info.end << " alignment=" << info.alignment
<< " offset=" << offset << " size=" << info.size_expr;
}
// Sanity check for overlapping constant buffers.
for (size_t i = 0; i < buf_infos.size(); ++i) {
const BufInfo &a = buf_infos[i];
auto a_off_imm = buffer_byte_offsets_.at(a.var).as<IntImmNode>();
if (!a.const_size_bytes.has_value() || a_off_imm == nullptr)
continue;
int64_t a_off = a_off_imm->value;
int64_t a_end = a_off + a.const_size_bytes.value();
for (size_t j = i + 1; j < buf_infos.size(); ++j) {
const BufInfo &b = buf_infos[j];
auto b_off_imm = buffer_byte_offsets_.at(b.var).as<IntImmNode>();
if (!b.const_size_bytes.has_value() || b_off_imm == nullptr)
continue;
bool live_overlap = !(a.end <= b.start || b.end <= a.start);
if (!live_overlap)
continue;
int64_t b_off = b_off_imm->value;
int64_t b_end = b_off + b.const_size_bytes.value();
bool mem_overlap = !(a_end <= b_off || b_end <= a_off);
if (mem_overlap) {
overlap_detected = true;
LOG(WARNING) << "Buffer overlap detected between " << a.name
<< " and " << b.name << " (lifetime overlap with "
<< "offset ranges [" << a_off << ", " << a_end
<< ") and [" << b_off << ", " << b_end << ")).";
}
}
}
}
if (overlap_detected) {
LOG(WARNING) << "Detected overlapping constant buffers; falling back to "
<< "sequential allocation without reuse.";
buffer_byte_offsets_.clear();
// In the fallback path we simply lay buffers out sequentially.
PrimExpr new_cursor = make_const(offset_dtype, 0);
PrimExpr new_total = make_const(offset_dtype, 0);
for (const BufInfo &info : buf_infos) {
new_cursor = AlignPrimExpr(new_cursor, info.alignment);
PrimExpr size_expr = CastToOffset(info.size_expr);
buffer_byte_offsets_[info.var] = new_cursor;
PrimExpr buf_end = new_cursor + size_expr;
new_total = max(new_total, buf_end);
new_cursor = buf_end;
}
merged_alloc_size_ = buf_infos.empty()
? make_const(offset_dtype, 0)
: AlignPrimExpr(new_total, align_bytes_);
}
}
// Whether to enable dynamic analysis.
bool is_dynamic_{true};
......@@ -1095,14 +1296,6 @@ private:
bool allocated_{false};
// Locations of free ops.
std::unordered_map<const Object *, EventEntry> event_map_;
// constant size free map.
std::multimap<uint64_t, StorageEntry *> const_free_map_;
// symbolic free list, for non constant items.
std::list<StorageEntry *> sym_free_list_;
// The allocation assign map
std::unordered_map<const VarNode *, StorageEntry *> alloc_map_;
/*! \brief allocator of all the StorageEntry*/
support::Arena arena_;
// The mapping of buffer bytes alignment
std::unordered_map<const VarNode *, int> shmem_alignment_map_;
};
......
......@@ -29,6 +29,7 @@
#include <string>
#include <utility>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
......@@ -301,6 +302,24 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
}
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
// Mark async TMA load context so that tvm_access_ptr within the call
// can be tagged accordingly.
auto is_tma_load = [&]() {
if (auto opt = op->op.as<Op>()) {
const Op &call_op = opt.value();
return call_op.same_as(tl::tma_load()) ||
call_op.same_as(tl::tma_load_im2col());
}
return false;
}();
if (is_tma_load) {
tma_depth_++;
for (const auto &a : op->args) {
this->VisitExpr(a);
}
tma_depth_--;
return;
}
if (op->op.same_as(builtin::address_of())) {
ICHECK_EQ(op->args.size(), 1U);
if (auto load = op->args[0].as<BufferLoadNode>()) {
......@@ -395,10 +414,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
e.scope = scope;
if (flag->value & 1) {
e.type = kRead;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
if (flag->value & 2) {
e.type = kWrite;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
}
......
......@@ -83,6 +83,10 @@ public:
bool double_buffer_write = false;
/*! \brief Whether the access is pointer access */
bool is_pointer_access = false;
/*! \brief Whether this access originates from an async copy context
 * (e.g., inside a TMA load); multiple such async-copy writes should not
 * force barriers between one another. */
bool is_async_copy = false;
};
/*! \brief Access pattern about a single statement */
......@@ -159,6 +163,8 @@ private:
bool allow_append_{false};
// Whether we are in device environment
bool in_device_env_{false};
// Nesting depth of tma_load/tma_load_im2col calls
int tma_depth_{0};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
......
......@@ -86,6 +86,7 @@ protected:
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
......@@ -98,7 +99,8 @@ protected:
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
if (FindConflict(reads, acc, false) ||
FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
......@@ -123,27 +125,51 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
}
}
if (loop != nullptr) {
// Check if the loop body contains any reads in the same sync scope.
// If there are reads, we conservatively keep the sync within the loop
// body to preserve per-iteration ordering when needed. If there are no
// reads (e.g., only writes to shared.dyn), we can safely hoist the sync
// to before the loop to avoid redundant barriers.
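// Example: a loop whose body only issues stmatrix writes into shared.dyn
// gets a single sync hoisted before the loop, whereas a loop that also
// reads shared.dyn keeps the sync inside the body for correctness.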
bool has_read_in_scope = false;
for (const StmtEntry &s : seq) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead && acc.scope == sync_scope_) {
has_read_in_scope = true;
break;
}
}
if (has_read_in_scope)
break;
}
// If there is a loop-carried dependency, insert a single sync before the
// loop rather than a per-iteration sync inside the loop body. This reduces
// redundant synchronizations for cases where each iteration touches
// disjoint regions (e.g., stmatrix writes to shared.dyn) and only a global
// ordering before/after the loop is required.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0)
break;
if (reads.empty() && writes.empty())
break;
bool sync_before_stmt = false;
bool need_loop_sync = false;
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
need_loop_sync = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
if (FindConflict(reads, acc, true) ||
FindConflict(writes, acc, true)) {
need_loop_sync = true;
break;
}
} else if (acc.type == kSync) {
......@@ -151,8 +177,17 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
if (need_loop_sync) {
if (!has_read_in_scope) {
// Mark the loop itself to receive a sync before it, instead of
// inserting inside the loop body. This ensures a single sync is
// emitted outside the loop and avoids per-iteration overhead.
insert_syncs(loop);
} else {
// Fall back to inserting before the first conflicting statement
// inside the loop to maintain correctness when reads are present.
insert_syncs(s.stmt);
}
break;
}
}
......@@ -217,6 +252,14 @@ private:
bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
bool loop_carry) {
// Special case: ignore conflicts between async-copy writes (e.g., TMA
// loads into shared memory). Multiple async writes do not require
// interspersed barriers among themselves. We still respect conflicts with
// reads to ensure visibility before consumption.
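// Example: two back-to-back tma_load calls writing into the same shared
// buffer are both tagged is_async_copy, so no barrier is inserted between
// them; a later read of that buffer still triggers a sync as usual.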
if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
curr.is_async_copy) {
return false;
}
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
......@@ -241,10 +284,15 @@ private:
return true;
}
if (prev.is_pointer_access || curr.is_pointer_access) {
// If either access is a pointer access, conservatively assume a
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
// memory region, so we cannot safely determine if it overlaps with
// previous accesses.
// For accesses created via tvm_access_ptr we may still be able to prove
// disjointness using their byte ranges. If both sides expose a touched
// interval and we can show they don't overlap, skip the conflict.
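// Example: two pointer accesses over the same buffer whose touched ranges
// are [0, 256) and [256, 512) are provably disjoint, so no barrier is
// forced between them.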
if (prev.is_pointer_access && curr.is_pointer_access &&
PointerAccessIsDisjoint(prev, curr)) {
return false;
}
// Otherwise fall back to the conservative answer: treat them as
// overlapping.
return true;
}
......@@ -327,7 +375,7 @@ private:
}
}
if (!(has_same_index)) {
if (!has_same_index) {
break;
}
}
......@@ -350,6 +398,26 @@ private:
return range_is_overlap;
}
bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
return false;
}
PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
if (analyzer_.CanProve(lhs_max < rhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
if (analyzer_.CanProve(rhs_max < lhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
return false;
}
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
......