Unverified Commit e5399527 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[TIR] Revert some changes of Pass `LowerIntrin` (#1035)



* keep >> instead of /

* rethink replicate

* lint fix

* handle const int buffers

* rep fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent 5767475a
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "parallel.h" #include "parallel.h"
#include <algorithm>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include "../layout/utils.h" #include "../layout/utils.h"
...@@ -413,22 +414,24 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -413,22 +414,24 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// check if loop body contains a "pure" buffer store (i.e., direct // check if loop body contains a "pure" buffer store (i.e., direct
// assignment, not compound update) // assignment, not compound update)
bool has_pure_buffer_store = false; std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers;
// Buffers that scope is above fragments.
// global, shared, shared.dyn
// which can be used to analysis replicate case
PostOrderVisit(root_, [&](const ObjectRef &obj) { PostOrderVisit(root_, [&](const ObjectRef &obj) {
if (const auto *store = obj.as<BufferStoreNode>()) { if (const auto *store = obj.as<BufferStoreNode>()) {
// Check if the value is a direct load from another buffer (i.e., b[i] auto buffer = store->buffer;
// = a[i]) if (buffer.scope() == "shared" || buffer.scope() == "shared.dyn" ||
if (const auto *load = store->value.as<BufferLoadNode>()) { buffer.scope() == "global") {
has_pure_buffer_store = true; store_shared_global_buffers.emplace_back(buffer);
} else if (buffer.scope() == "local.fragment") {
store_fragment_buffers.emplace_back(buffer);
} }
} }
}); });
if (read_source_buffer.defined() && allow_layout_propgate) { if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer); loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop don't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
// loop_layout_ = loop_layout_->DeReplicate();
} }
if (!loop_layout_.defined()) { if (!loop_layout_.defined()) {
...@@ -477,16 +480,73 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -477,16 +480,73 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = " DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n'; << loop_layout_->DebugOutput() << '\n';
} }
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) { // Lambda that guards replicated accesses:
auto inv = loop_layout_->Inverse(); // - When a loop layout replicates a fragment buffer (rep > 1), each thread
Array<PrimExpr> fwd; // observes the same fragment elements. Blindly storing to shared/global
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) // memory in that case would add the same value multiple times.
fwd.push_back(0); // - We therefore restrict the store so that only the replica with rep == 0
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min); // performs the update (e.g. global[i] += fragment[i] only fires once).
auto rep = inv->Forward(fwd).back(); // Trigger conditions for this guard:
AddPredicate(EQ(rep, 0)); // 1) There are cross-thread stores targeting shared/global memory (no
} // fragment stores in this branch; atomic_add and similar remain TODO).
// 2) The loop layout replicate extent is greater than 1, inferred from the
// thread bounds captured in the layout.
[this, &store_shared_global_buffers, &store_fragment_buffers,
&has_cross_thread_access, &const_index_fragment_buffer, &T]() {
if (is_one(loop_layout_->ReplicateExtent()))
return;
if (!has_cross_thread_access)
return;
if (!store_fragment_buffers.empty()) {
// Iterate replicated fragment stores: when the fragment index is a
// constant (e.g. fragment[0]), every thread touches the same slot, so
// the rep == 0 predicate is unnecessary. Example: for i in
// T.Parallel(...):
// shared[i] = ...
// fragment[0] = ...
bool replicate_is_from_dynamic_index_fragment = false;
for (const auto &fragment : store_fragment_buffers) {
if (!T.layout_map.count(fragment)) {
continue;
}
auto fragment_layout = T.layout_map[fragment].as<Fragment>().value();
if (is_one(fragment_layout->ReplicateExtent()))
continue;
if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(),
loop_layout_->ReplicateExtent()))
continue;
if (std::find(const_index_fragment_buffer.begin(),
const_index_fragment_buffer.end(),
fragment) == const_index_fragment_buffer.end()) {
replicate_is_from_dynamic_index_fragment = true;
}
}
if (!replicate_is_from_dynamic_index_fragment)
return;
ICHECK(store_shared_global_buffers.empty())
<< "Invalid layout: cannot have both fragment and shared store "
"buffers "
"in replicated loop layout.";
return;
} else {
// Now, store is global or shared
// or T.call_extern or T.call_intrin ...
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
}();
} else { } else {
return {}; return {};
} }
......
...@@ -160,9 +160,8 @@ public: ...@@ -160,9 +160,8 @@ public:
// == truncdiv(a + b*c, b) - c // == truncdiv(a + b*c, b) - c
IntImm min(op->a->dtype.element_of(), const_int_bound->min_value); IntImm min(op->a->dtype.element_of(), const_int_bound->min_value);
PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b); PrimExpr ceildiv = truncdiv((op->b - 1) - min, op->b);
// Skip analyzer simplification so we preserve straightforward div PrimExpr offset_numerator =
// expressions. analyzer_->Simplify(op->a + op->b * ceildiv);
PrimExpr offset_numerator = op->a + op->b * ceildiv;
return truncdiv(offset_numerator, op->b) - ceildiv; return truncdiv(offset_numerator, op->b) - ceildiv;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment