Unverified Commit eed320f5 authored by Lei Wang, committed by GitHub

[Bugfix] Recover code for flexible parallel (#1032)



* recover flex parallel process

* lint fix

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent 1e8f0b18
@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   // (const index frag_a interacts with non-const index frag_b)
   // - No propagation needed: shared_a[i] = frag_a[0]
   // (const index frag_a with non-fragment buffer)
   bool allow_layout_propgate =
-      fragment_buffers.size() > const_index_fragment_buffer.size();
+      const_index_fragment_buffer.empty() ||
+      (fragment_buffers.size() > const_index_fragment_buffer.size());
   // Step 1: try to infer loop's partition from a source fragment
   Buffer source_buffer, read_source_buffer;
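
Editor's note on the recovered condition, not part of the patch: the old predicate `fragment_buffers.size() > const_index_fragment_buffer.size()` is false when both lists are empty, so a loop that touches no fragment buffer at all (presumably plain shared/global copy loops) could never propagate a layout. The added `const_index_fragment_buffer.empty()` clause recovers that case. A minimal standalone sketch of the truth table, using plain size stand-ins for the two buffer lists:

```cpp
#include <cassert>
#include <cstddef>

// Stand-ins for fragment_buffers.size() and const_index_fragment_buffer.size().
bool old_rule(std::size_t frags, std::size_t const_idx) {
  return frags > const_idx;
}
bool new_rule(std::size_t frags, std::size_t const_idx) {
  return const_idx == 0 || (frags > const_idx);
}

int main() {
  // No fragment buffers at all: the old rule blocked propagation,
  // the new rule allows it. This is the recovered case.
  assert(!old_rule(0, 0) && new_rule(0, 0));
  // Some fragment buffer with a non-const index: both rules propagate.
  assert(old_rule(2, 1) && new_rule(2, 1));
  // Every fragment buffer const-indexed: both rules block propagation.
  assert(!old_rule(2, 2) && !new_rule(2, 2));
  return 0;
}
```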
@@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       PrimExpr loop_var_to_thread =
           src_layout->ForwardThread(indice_map_[buffer], rep);
       loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
+      PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
+        if (auto opt_var = objref.as<Var>();
+            opt_var && inner_vars_.count(*opt_var)) {
+          std::ostringstream oss;
+          oss << "loop_var_to_thread = " << loop_var_to_thread
+              << " contains inner var " << *opt_var;
+          throw LayoutConflictException(oss.str());
+        }
+      });
       result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                    ->BindThreadRange(T.thread_bounds);
     }
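
The added walk rejects a thread mapping that still references an inner loop variable, and it throws `LayoutConflictException` instead of asserting, presumably so the inference driver can catch the conflict and try another candidate. The same pattern as a standalone helper, a sketch assuming TVM's C++ API (`tvm::tir::PostOrderVisit`, `ObjectRef::as`); `ContainsInnerVar` and the `inner_vars` set are illustrative stand-ins for the pass's `inner_vars_`:

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_set>

// Returns true if any Var inside `expr` belongs to `inner_vars`.
bool ContainsInnerVar(
    const tvm::PrimExpr &expr,
    const std::unordered_set<const tvm::tir::VarNode *> &inner_vars) {
  bool found = false;
  tvm::tir::PostOrderVisit(expr, [&](const tvm::ObjectRef &objref) {
    if (const auto *var = objref.as<tvm::tir::VarNode>()) {
      // Mark the expression as conflicting; PostOrderVisit has no early exit.
      if (inner_vars.count(var)) found = true;
    }
  });
  return found;
}
```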
@@ -379,57 +389,46 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   if (source_buffer.defined() && allow_layout_propgate) {
     loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
   } else if (level == InferLevel::kFree) {
-    // For free layout inference
-    // If replication exists and buffer has cross-thread shared memory access,
-    // add predicate
-    bool has_cross_thread_access = false;
-    PostOrderVisit(root_, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
-        // check if scope is shared or global
-        if (store->buffer.scope() == "shared" ||
-            store->buffer.scope() == "shared.dyn" ||
-            store->buffer.scope() == "global") {
-          has_cross_thread_access = true;
-        }
-      } else if (const auto *load = obj.as<BufferLoadNode>()) {
-        // check if scope is shared or global
-        if (load->buffer.scope() == "shared" ||
-            load->buffer.scope() == "shared.dyn" ||
-            load->buffer.scope() == "global") {
-          has_cross_thread_access = true;
-        }
-      }
-    });
-    // check if loop body contains a "pure" buffer store (i.e., direct
-    // assignment, not compound update)
-    bool has_pure_buffer_store = false;
-    PostOrderVisit(root_, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
-        // Check if the value is a direct load from another buffer (i.e., b[i]
-        // = a[i])
-        if (const auto *load = store->value.as<BufferLoadNode>()) {
-          has_pure_buffer_store = true;
-        }
-      }
-    });
     if (read_source_buffer.defined() && allow_layout_propgate) {
       loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
       // // Loops don't need to be replicated.
       // if (!is_one(loop_layout_->ReplicateExtent()))
       //   loop_layout_ = loop_layout_->DeReplicate();
+      // For free layout inference
+      // If replication exists and buffer has cross-thread shared memory
+      // access, add predicate
+      bool has_cross_thread_access = false;
+      PostOrderVisit(root_, [&](const ObjectRef &obj) {
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          // check if scope is shared or global
+          if (store->buffer.scope() == "shared" ||
+              store->buffer.scope() == "shared.dyn" ||
+              store->buffer.scope() == "global") {
+            has_cross_thread_access = true;
+          }
+        } else if (const auto *load = obj.as<BufferLoadNode>()) {
+          // check if scope is shared or global
+          if (load->buffer.scope() == "shared" ||
+              load->buffer.scope() == "shared.dyn" ||
+              load->buffer.scope() == "global") {
+            has_cross_thread_access = true;
+          }
+        }
+      });
+      // check if loop body contains a "pure" buffer store (i.e., direct
+      // assignment, not compound update)
+      bool has_pure_buffer_store = false;
+      PostOrderVisit(root_, [&](const ObjectRef &obj) {
+        if (const auto *store = obj.as<BufferStoreNode>()) {
+          // Check if the value is a direct load from another buffer
+          // (i.e., b[i] = a[i])
+          if (const auto *load = store->value.as<BufferLoadNode>()) {
+            has_pure_buffer_store = true;
+          }
+        }
+      });
-      if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
-          !has_pure_buffer_store) {
-        auto inv = loop_layout_->Inverse();
-        Array<PrimExpr> fwd;
-        for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
-          fwd.push_back(0);
-        fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
-        auto rep = inv->Forward(fwd).back();
-        AddPredicate(EQ(rep, 0));
-      }
     }
     if (!loop_layout_.defined()) {
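
For reference, the `has_pure_buffer_store` walk keys on the node type of the stored value. A minimal sketch of the distinction, assuming TVM's TIR node types (`IsPureCopyStore` is an illustrative name, not from the patch):

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>

// b[i] = a[i];         -> value is a BufferLoad node: a "pure" copy store.
// b[i] = b[i] + a[i];  -> value is an Add node: a compound update, not pure.
bool IsPureCopyStore(const tvm::tir::BufferStore &store) {
  return store->value.as<tvm::tir::BufferLoadNode>() != nullptr;
}
```

Note the walk in the patch flags the whole loop if any one store is a direct copy. A pure copy writes the same value from every replicated thread, so duplicate stores are presumably benign, whereas a compound update would be applied once per replica and double-count.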
@@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
                  << loop_layout_->DebugOutput() << '\n';
     }
+    if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
+        !has_pure_buffer_store) {
+      auto inv = loop_layout_->Inverse();
+      Array<PrimExpr> fwd;
+      for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
+        fwd.push_back(0);
+      fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
+      auto rep = inv->Forward(fwd).back();
+      AddPredicate(EQ(rep, 0));
+    }
   } else {
     return {};
   }
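
The relocated predicate keeps only one replica of each iteration when the loop layout is replicated, the body touches shared or global memory, and the store is not a pure copy: it feeds zeros for every layout output plus the thread placeholder into the inverse layout, and the last component of the result is the replicate index. A toy numeric model of that arithmetic, with made-up extents (8 iterations spread over 32 threads, so `ReplicateExtent() == 4`):

```cpp
#include <cassert>

// Hypothetical layout: iteration i runs on the four threads t = 4*i + r,
// r in [0, 4). The inverse map recovers (i, rep) from the thread index,
// mirroring inv->Forward(fwd).back() returning the replicate index.
struct ToyInverseLayout {
  int iter(int thread) const { return thread / 4; }
  int rep(int thread) const { return thread % 4; }
};

int main() {
  ToyInverseLayout inv;
  int guarded = 0;
  for (int t = 0; t < 32; ++t) {
    // AddPredicate(EQ(rep, 0)): only the rep == 0 copy of each iteration
    // performs the cross-thread (shared/global) store.
    if (inv.rep(t) == 0) ++guarded;
  }
  assert(guarded == 8);  // exactly one thread commits each iteration
  return 0;
}
```

Moving this check after the `PlanLoopPartition` fallback also covers layouts produced there, not only those propagated from `read_source_buffer`, which appears to be the point of the relocation.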