"docs/vscode:/vscode.git/clone" did not exist on "87b9db644b9034bf316811918722f5e09c676b1f"
Unverified commit f7ba45d8 authored by Lei Wang, committed by GitHub

[Bugfix] Implement classic arena algorithm for shmem merge and WAW conflict detection (#1146)

* atomic_fix

* atomic_fix

* mem fix

* lint fix

* add some comments

* fix

* fix

* lint fix

* handle async copy

* lint fix
parent c70b2697
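
To make the new rule concrete, here is a minimal standalone sketch of the conflict logic the diff below introduces (illustrative only; the types and names Access and NeedsSync are made up and are not part of the patch): a write now conflicts with an earlier overlapping write to the same buffer (WAW), unless both writes come from an async-copy context such as a TMA load.

#include <cstdint>
#include <iostream>
#include <string>

enum class AccessType { kRead, kWrite };

struct Access {
  std::string buffer;          // buffer being touched
  int64_t min_byte, max_byte;  // inclusive touched byte range
  AccessType type;
  bool is_async_copy;          // e.g. produced inside a TMA load
};

// Returns true if a barrier is needed between prev and curr.
bool NeedsSync(const Access &prev, const Access &curr) {
  if (prev.buffer != curr.buffer) return false;  // different buffers
  bool overlap =
      !(prev.max_byte < curr.min_byte || curr.max_byte < prev.min_byte);
  if (!overlap) return false;  // provably disjoint byte ranges
  if (prev.type == AccessType::kRead && curr.type == AccessType::kRead)
    return false;  // read-after-read never needs a sync
  if (prev.type == AccessType::kWrite && curr.type == AccessType::kWrite &&
      prev.is_async_copy && curr.is_async_copy)
    return false;  // async WAW: TMA writes need no barrier among themselves
  return true;     // RAW, WAR, or ordinary WAW on overlapping bytes
}

int main() {
  Access w1{"A_shared", 0, 255, AccessType::kWrite, false};
  Access w2{"A_shared", 128, 383, AccessType::kWrite, false};
  std::cout << NeedsSync(w1, w2) << "\n";  // 1: plain overlapping WAW
  w1.is_async_copy = w2.is_async_copy = true;
  std::cout << NeedsSync(w1, w2) << "\n";  // 0: both writes are async copies
  return 0;
}

The real pass applies this rule through FindConflict over AccessEntry records, as the hunks below show.
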
@@ -29,6 +29,7 @@
#include <string>
#include <utility>
#include "../op/builtin.h"
#include "tir/transforms/ir_utils.h"
namespace tvm {
@@ -301,6 +302,24 @@ void TileLangStorageAccessVisitor::VisitStmt_(const WhileNode *op) {
}
void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
// Mark async TMA load context so that tvm_access_ptr within the call
// can be tagged accordingly.
auto is_tma_load = [&]() {
if (auto opt = op->op.as<Op>()) {
const Op &call_op = opt.value();
return call_op.same_as(tl::tma_load()) ||
call_op.same_as(tl::tma_load_im2col());
}
return false;
}();
if (is_tma_load) {
tma_depth_++;
for (const auto &a : op->args) {
this->VisitExpr(a);
}
tma_depth_--;
return;
}
if (op->op.same_as(builtin::address_of())) {
ICHECK_EQ(op->args.size(), 1U);
if (auto load = op->args[0].as<BufferLoadNode>()) {
@@ -395,10 +414,12 @@ void TileLangStorageAccessVisitor::VisitExpr_(const CallNode *op) {
e.scope = scope;
if (flag->value & 1) {
e.type = kRead;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
if (flag->value & 2) {
e.type = kWrite;
e.is_async_copy = (tma_depth_ > 0);
curr_stmt_.access.emplace_back(e);
}
}
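
The two hunks above cooperate: tma_depth_ marks the span of the expression tree covered by a tma_load call, and every tvm_access_ptr recorded while that counter is positive gets is_async_copy set. Below is a simplified, self-contained sketch of the same tagging pattern; all types here are stand-ins, not TVM classes.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Tiny stand-in expression node: an op name plus nested arguments.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> args;
};

struct TaggedAccess {
  std::string buffer;
  bool is_async_copy;
};

class Visitor {
 public:
  void Visit(const std::shared_ptr<Node> &n) {
    if (n->op == "tma_load") {
      ++tma_depth_;  // entering an async-copy context
      for (const auto &a : n->args) Visit(a);
      --tma_depth_;
      return;
    }
    if (n->op == "access_ptr") {
      // Tag the access if we are anywhere inside a tma_load call.
      accesses_.push_back({n->args.empty() ? std::string("?") : n->args[0]->op,
                           tma_depth_ > 0});
      return;
    }
    for (const auto &a : n->args) Visit(a);
  }
  const std::vector<TaggedAccess> &accesses() const { return accesses_; }

 private:
  int tma_depth_{0};
  std::vector<TaggedAccess> accesses_;
};

int main() {
  auto buf = std::make_shared<Node>(Node{"A_shared", {}});
  auto ptr = std::make_shared<Node>(Node{"access_ptr", {buf}});
  auto tma = std::make_shared<Node>(Node{"tma_load", {ptr}});
  Visitor v;
  v.Visit(tma);
  std::cout << v.accesses()[0].buffer << " async="
            << v.accesses()[0].is_async_copy << "\n";  // A_shared async=1
  return 0;
}
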
@@ -83,6 +83,10 @@ public:
bool double_buffer_write = false;
/*! \brief Whether the access is pointer access */
bool is_pointer_access = false;
/*! \brief Whether this access originates from an async-copy context
 * (e.g., inside a TMA load); multiple such writes do not require
 * barriers between one another. */
bool is_async_copy = false;
};
/*! \brief Access pattern about a single statement */
@@ -159,6 +163,8 @@ private:
bool allow_append_{false};
// Whether we are in device environment
bool in_device_env_{false};
// Nesting depth of tma_load/tma_load_im2col calls
int tma_depth_{0};
// Whether we are inside condition.
int condition_counter_{0};
// The current double buffer write scope.
@@ -86,6 +86,7 @@ protected:
// check if sync before statement is needed.
bool sync_before_stmt = (syncs_inserted_.count(s.stmt) != 0);
// Apply the syncs added already.
if (sync_before_stmt) {
reads.clear();
writes.clear();
@@ -98,7 +99,8 @@
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, false)) {
if (FindConflict(reads, acc, false) ||
FindConflict(writes, acc, false)) {
sync_before_stmt = true;
break;
}
@@ -123,27 +125,51 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
}
}
if (loop != nullptr) {
// Check if the loop body contains any reads in the same sync scope.
// If there are reads, we conservatively keep the sync within the loop
// body to preserve per-iteration ordering when needed. If there are no
// reads (e.g., only writes to shared.dyn), we can safely hoist the sync
// to before the loop to avoid redundant barriers.
bool has_read_in_scope = false;
for (const StmtEntry &s : seq) {
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead && acc.scope == sync_scope_) {
has_read_in_scope = true;
break;
}
}
if (has_read_in_scope)
break;
}
// If there is a loop-carried dependency, insert a single sync
// before the loop rather than hoisting a sync into the loop body.
// This reduces redundant per-iteration synchronizations for cases
// where each iteration touches disjoint regions (e.g., stmatrix
// writes to shared.dyn) and only a global ordering before/after the
// loop is required.
for (size_t i = 0; i < seq.size(); ++i) {
const StmtEntry &s = seq[i];
if (syncs_inserted_.count(s.stmt) != 0)
break;
if (reads.empty() && writes.empty())
break;
bool sync_before_stmt = false;
bool need_loop_sync = false;
for (const AccessEntry &acc : s.access) {
if (acc.type == kRead) {
if (FindConflict(writes, acc, true)) {
sync_before_stmt = true;
need_loop_sync = true;
break;
}
} else if (acc.type == kWrite) {
if (FindConflict(reads, acc, true)) {
sync_before_stmt = true;
if (FindConflict(reads, acc, true) ||
FindConflict(writes, acc, true)) {
need_loop_sync = true;
break;
}
} else if (acc.type == kSync) {
@@ -151,8 +177,17 @@ protected:
writes.clear();
}
}
if (sync_before_stmt) {
insert_syncs(s.stmt);
if (need_loop_sync) {
if (!has_read_in_scope) {
// Mark the loop itself to receive a sync before it, instead of
// inserting inside the loop body. This ensures a single sync is
// emitted outside the loop and avoids per-iteration overhead.
insert_syncs(loop);
} else {
// Fall back to inserting before the first conflicting statement
// inside the loop to maintain correctness when reads are present.
insert_syncs(s.stmt);
}
break;
}
}
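
The loop handling above boils down to a small decision: if a loop-carried conflict exists but the body performs no reads in the sync scope, one barrier before the loop suffices; any read keeps the barrier inside the body. A hedged standalone sketch of that decision follows (Access, PlanLoopSync and the boolean flag are invented for illustration and only approximate FindConflict with loop_carry = true).

#include <iostream>
#include <vector>

enum class AccessType { kRead, kWrite };
// conflicts_across_iterations stands in for a loop-carried FindConflict hit.
struct Access { AccessType type; bool conflicts_across_iterations; };
enum class SyncPlacement { kNone, kBeforeLoop, kInsideLoop };

SyncPlacement PlanLoopSync(const std::vector<Access> &body) {
  bool has_read = false, loop_carried = false;
  for (const Access &a : body) {
    if (a.type == AccessType::kRead) has_read = true;
    if (a.conflicts_across_iterations) loop_carried = true;
  }
  if (!loop_carried) return SyncPlacement::kNone;
  // Only writes in the sync scope: one barrier hoisted before the loop is
  // enough. Any read keeps the barrier inside to preserve per-iteration order.
  return has_read ? SyncPlacement::kInsideLoop : SyncPlacement::kBeforeLoop;
}

int main() {
  std::vector<Access> writes_only = {{AccessType::kWrite, true}};
  std::vector<Access> with_read = {{AccessType::kWrite, true},
                                   {AccessType::kRead, false}};
  std::cout << (PlanLoopSync(writes_only) == SyncPlacement::kBeforeLoop) << "\n";  // 1
  std::cout << (PlanLoopSync(with_read) == SyncPlacement::kInsideLoop) << "\n";    // 1
  return 0;
}
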
@@ -217,6 +252,14 @@ private:
bool FindConflict(const AccessEntry &prev, const AccessEntry &curr,
bool loop_carry) {
// Special case: ignore conflicts between async-copy writes (e.g., TMA
// loads into shared memory). Multiple async writes do not require
// interspersed barriers among themselves. We still respect conflicts with
// reads to ensure visibility before consumption.
if (prev.type == kWrite && curr.type == kWrite && prev.is_async_copy &&
curr.is_async_copy) {
return false;
}
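// Illustrative (hypothetical) scenario: two back-to-back TMA loads filling
// different tiles of the same shared buffer both record kWrite entries with
// is_async_copy set, so no barrier is emitted between them; a later ordinary
// read of that buffer still forces one.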
// Access to different buffers does not conflict.
if (!prev.buffer.same_as(curr.buffer)) {
return false;
@@ -241,10 +284,15 @@ private:
return true;
}
if (prev.is_pointer_access || curr.is_pointer_access) {
// If either access is a pointer access, conservatively assume a
// conflict. For example, address_of(A[0, 0]) may refer to an unknown
// memory region, so we cannot safely determine if it overlaps with
// previous accesses.
// For accesses created via tvm_access_ptr we may still be able to prove
// disjointness using their byte ranges. If both sides expose a touched
// interval and we can show they don't overlap, skip the conflict.
if (prev.is_pointer_access && curr.is_pointer_access &&
PointerAccessIsDisjoint(prev, curr)) {
return false;
}
// Otherwise fall back to the conservative answer: treat them as
// overlapping.
return true;
}
@@ -327,7 +375,7 @@ private:
}
}
if (!(has_same_index)) {
if (!has_same_index) {
break;
}
}
@@ -350,6 +398,26 @@ private:
return range_is_overlap;
}
bool PointerAccessIsDisjoint(const AccessEntry &lhs, const AccessEntry &rhs) {
if (lhs.touched.size() != 1 || rhs.touched.size() != 1) {
return false;
}
PrimExpr lhs_min = analyzer_.Simplify(lhs.touched[0].min());
PrimExpr lhs_max = analyzer_.Simplify(lhs.touched[0].max());
PrimExpr rhs_min = analyzer_.Simplify(rhs.touched[0].min());
PrimExpr rhs_max = analyzer_.Simplify(rhs.touched[0].max());
if (analyzer_.CanProve(lhs_max < rhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
if (analyzer_.CanProve(rhs_max < lhs_min,
arith::ProofStrength::kSymbolicBound)) {
return true;
}
return false;
}
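// Worked example (assumed values, for illustration only): if one
// tvm_access_ptr touches bytes [0, 255] of a shared buffer and another
// touches [256, 511], both Simplify() calls yield constant bounds and
// CanProve(255 < 256, arith::ProofStrength::kSymbolicBound) holds, so the
// two pointer accesses are reported as disjoint and no sync is forced.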
void VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tvm::tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);