Unverified Commit e5399527 authored by Lei Wang, committed by GitHub

[TIR] Revert some changes of Pass `LowerIntrin` (#1035)



* keep >> instead of /

* rethink replicate

* lint fix

* handle const int buffers

* rep fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent 5767475a
@@ -5,6 +5,7 @@
#include "parallel.h"
#include <algorithm>
#include <tvm/tir/op.h>
#include "../layout/utils.h"
@@ -413,22 +414,24 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update)
bool has_pure_buffer_store = false;
std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers;
// Buffers whose scope sits above fragment level: global, shared,
// shared.dyn. These are used to analyze the replication case.
PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i]
// = a[i])
if (const auto *load = store->value.as<BufferLoadNode>()) {
has_pure_buffer_store = true;
auto buffer = store->buffer;
if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
buffer.scope() == "global") {
store_shared_global_buffers.emplace_back(buffer);
} else if (buffer.scope() == "local.fragment") {
store_fragment_buffers.emplace_back(buffer);
}
}
});
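// Note: only a direct copy such as b[i] = a[i] counts as a "pure" store
// here; a compound update like b[i] = b[i] + a[i] carries an Add node as
// its value rather than a bare BufferLoad, so it does not set
// has_pure_buffer_store.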
if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop doesn't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
// loop_layout_ = loop_layout_->DeReplicate();
}
if (!loop_layout_.defined()) {
@@ -477,8 +480,64 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n';
}
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
// Lambda that guards replicated accesses:
// - When a loop layout replicates a fragment buffer (rep > 1), each thread
// observes the same fragment elements. Blindly storing to shared/global
// memory in that case would add the same value multiple times.
// - We therefore restrict the store so that only the replica with rep == 0
// performs the update (e.g. global[i] += fragment[i] only fires once).
// Trigger conditions for this guard:
// 1) There are cross-thread stores targeting shared/global memory (no
// fragment stores in this branch; atomic_add and similar remain TODO).
// 2) The loop layout replicate extent is greater than 1, inferred from the
// thread bounds captured in the layout.
[this, &store_shared_global_buffers, &store_fragment_buffers,
&has_cross_thread_access, &const_index_fragment_buffer, &T]() {
if (is_one(loop_layout_->ReplicateExtent()))
return;
if (!has_cross_thread_access)
return;
if (!store_fragment_buffers.empty()) {
// Iterate replicated fragment stores: when the fragment index is a
// constant (e.g. fragment[0]), every thread touches the same slot, so
// the rep == 0 predicate is unnecessary. Example:
//   for i in T.Parallel(...):
//     shared[i] = ...
//     fragment[0] = ...
bool replicate_is_from_dynamic_index_fragment = false;
for (const auto &fragment : store_fragment_buffers) {
if (!T.layout_map.count(fragment)) {
continue;
}
auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
if (is_one(fragment_layout->ReplicateExtent()))
continue;
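// If the fragment's replicate extent matches the loop's, the replication
// is already explained by the fragment layout itself, so no rep == 0
// guard is required for it.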
if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(),
loop_layout_->ReplicateExtent()))
continue;
if (std::find(const_index_fragment_buffer.begin(),
const_index_fragment_buffer.end(),
fragment) == const_index_fragment_buffer.end()) {
replicate_is_from_dynamic_index_fragment = true;
}
}
if (!replicate_is_from_dynamic_index_fragment)
return;
ICHECK(store_shared_global_buffers.empty())
<< "Invalid layout: cannot have both fragment and shared store "
"buffers in replicated loop layout.";
return;
} else {
// Otherwise the store targets global or shared memory, or happens
// through T.call_extern / T.call_intrin ...
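// Applying the inverse layout to the physical (thread, index)
// coordinates recovers the loop iterators plus a trailing replica
// index; Forward(fwd).back() below extracts that replica coordinate so
// the store can be predicated on rep == 0.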
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
@@ -487,6 +546,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
}();
} else {
return {};
}
......
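Why the rep == 0 predicate matters: with replicate extent R, every replica of the loop observes identical fragment values, so an unguarded cross-thread accumulation would apply the same update R times. A minimal standalone illustration (a plain C++ simulation with made-up values, not part of this diff):

#include <cassert>
#include <vector>

int main() {
  const int R = 2, N = 4;  // replicate extent, number of elements
  std::vector<int> fragment = {1, 2, 3, 4};
  std::vector<int> unguarded(N, 0), guarded(N, 0);
  // Every replica of the parallel loop executes the same store.
  for (int rep = 0; rep < R; ++rep) {
    for (int i = 0; i < N; ++i) {
      unguarded[i] += fragment[i];              // fires R times: wrong
      if (rep == 0) guarded[i] += fragment[i];  // AddPredicate(EQ(rep, 0))
    }
  }
  for (int i = 0; i < N; ++i) {
    assert(guarded[i] == fragment[i]);        // stored exactly once
    assert(unguarded[i] == R * fragment[i]);  // duplicated R times
  }
  return 0;
}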
@@ -160,9 +160,8 @@ public:
// == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
-  // Skip analyzer simplification so we preserve straightforward div
-  // expressions.
-  PrimExpr offset_numerator = op->a + op->b * ceildiv;
+  PrimExpr offset_numerator =
+      analyzer_->Simplify(op->a + op->b * ceildiv);
return truncdiv(offset_numerator, op->b) - ceildiv;
}
......
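For reference, the hunk above lowers floordiv on a possibly-negative numerator through the identity floordiv(a, b) == truncdiv(a + b*c, b) - c, where c is derived from the numerator's lower bound so that a + b*c is non-negative. A standalone check of that identity (plain C++; min stands in for const_int_bound->min_value, and C's / matches truncdiv):

#include <cassert>
#include <cstdio>

int main() {
  const int b = 4, min = -100;        // divisor > 0, lower bound on a
  const int c = ((b - 1) - min) / b;  // truncdiv((b - 1) - min, b)
  for (int a = min; a <= 100; ++a) {
    // Reference floordiv: round toward negative infinity.
    int fdiv = (a % b == 0 || a >= 0) ? a / b : a / b - 1;
    // Lowered form: a + b*c >= 0 by construction, so truncation is safe.
    assert(a + b * c >= 0);
    assert((a + b * c) / b - c == fdiv);
  }
  std::printf("floordiv identity holds on [%d, 100] for b = %d\n", min, b);
  return 0;
}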