Unverified Commit eed320f5 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Bugfix] Recover code for flexible parallel (#1032)



* recover flex parallel process

* lint fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent 1e8f0b18
...@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -307,8 +307,10 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// (const index frag_a interacts with non-const index frag_b) // (const index frag_a interacts with non-const index frag_b)
// - No propagation needed: shared_a[i] = frag_a[0] // - No propagation needed: shared_a[i] = frag_a[0]
// (const index frag_a with non-fragment buffer) // (const index frag_a with non-fragment buffer)
bool allow_layout_propgate = bool allow_layout_propgate =
fragment_buffers.size() > const_index_fragment_buffer.size(); const_index_fragment_buffer.empty() ||
(fragment_buffers.size() > const_index_fragment_buffer.size());
// Step 1: try to infer loop's partition from a source fragment // Step 1: try to infer loop's partition from a source fragment
Buffer source_buffer, read_source_buffer; Buffer source_buffer, read_source_buffer;
...@@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -361,7 +363,15 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
PrimExpr loop_var_to_thread = PrimExpr loop_var_to_thread =
src_layout->ForwardThread(indice_map_[buffer], rep); src_layout->ForwardThread(indice_map_[buffer], rep);
loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread); loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
if (auto opt_var = objref.as<Var>();
opt_var && inner_vars_.count(*opt_var)) {
std::ostringstream oss;
oss << "loop_var_to_thread = " << loop_var_to_thread
<< "contains inner var" << *opt_var;
throw LayoutConflictException(oss.str());
}
});
result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter) result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
->BindThreadRange(T.thread_bounds); ->BindThreadRange(T.thread_bounds);
} }
...@@ -379,12 +389,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -379,12 +389,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
if (source_buffer.defined() && allow_layout_propgate) { if (source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer); loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
} else if (level == InferLevel::kFree) { } else if (level == InferLevel::kFree) {
if (read_source_buffer.defined() && allow_layout_propgate) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
// // Loop don't need to be replicated.
// if (!is_one(loop_layout_->ReplicateExtent()))
// loop_layout_ = loop_layout_->DeReplicate();
// For free layout inference // For free layout inference
// If replication exists and buffer has cross-thread shared memory access, // If replication exists and buffer has cross-thread shared memory access,
// add predicate // add predicate
...@@ -420,16 +424,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -420,16 +424,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
} }
}); });
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access && if (read_source_buffer.defined() && allow_layout_propgate) {
!has_pure_buffer_store) { loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
auto inv = loop_layout_->Inverse(); // // Loop don't need to be replicated.
Array<PrimExpr> fwd; // if (!is_one(loop_layout_->ReplicateExtent()))
for (size_t i = 0; i < loop_layout_->OutputDim(); i++) // loop_layout_ = loop_layout_->DeReplicate();
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
} }
if (!loop_layout_.defined()) { if (!loop_layout_.defined()) {
...@@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, ...@@ -478,6 +477,16 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = " DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = "
<< loop_layout_->DebugOutput() << '\n'; << loop_layout_->DebugOutput() << '\n';
} }
if (!is_one(loop_layout_->ReplicateExtent()) && has_cross_thread_access &&
!has_pure_buffer_store) {
auto inv = loop_layout_->Inverse();
Array<PrimExpr> fwd;
for (size_t i = 0; i < loop_layout_->OutputDim(); i++)
fwd.push_back(0);
fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min);
auto rep = inv->Forward(fwd).back();
AddPredicate(EQ(rep, 0));
}
} else { } else {
return {}; return {};
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment