Unverified Commit 7045f1d6 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Refactor] Simplify logic in the `CompleteBufferFragment` (#1226)



* fix

* Fix logging level in LayoutNode::InverseWithLevel method from WARNING to DLOG for symbolic layout fallback.

* lint fix

---------
Co-authored-by: default avatarZhiwen Mo <zm125@ic.ac.uk>
parent 67cc8611
...@@ -250,7 +250,7 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const { ...@@ -250,7 +250,7 @@ std::pair<Layout, arith::IterMapLevel> LayoutNode::InverseWithLevel() const {
if (!is_static_shape) { if (!is_static_shape) {
// Runtime guards keep dynamic tails safe, so we allow NoCheck here and // Runtime guards keep dynamic tails safe, so we allow NoCheck here and
// warn. // warn.
LOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to " DLOG(WARNING) << "Layout::Inverse on symbolic layout, falling back to "
"NoCheck; symbolic dims: " "NoCheck; symbolic dims: "
<< symbolic_dims; << symbolic_dims;
} }
......
...@@ -649,37 +649,8 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { ...@@ -649,37 +649,8 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_)); DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
auto bijective_indice = indice_map_[buffer]; auto bijective_indice = indice_map_[buffer];
bijective_indice.push_back(rep_b); bijective_indice.push_back(rep_b);
Layout layout_before_inv = Layout(loop_vars_, bijective_indice); Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
// Pre-check cardinality to guard non-bijective combinations after adding
// rep_b.
PrimExpr in_prod = 1;
for (const auto &iv : loop_vars_)
in_prod *= iv->dom->extent;
PrimExpr out_prod = 1;
for (const auto &d : layout_before_inv->OutputShape())
out_prod *= d;
if (!analyzer_.CanProveEqual(in_prod, out_prod)) {
DLOG(WARNING) << " Non-bijective mapping after appending rep_b; falling "
"back to no-rep inversion.";
Layout ind_inv_fallback =
Layout(loop_vars_, indice_map_[buffer])->Inverse();
PrimExpr indice_rep_extent = 1;
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent;
Array<PrimExpr> fwd2;
for (size_t i = 0; i < buffer->shape.size(); i++) {
fwd2.push_back(InputPlaceholder(i));
}
PrimExpr thd_b = loop_layout_->ForwardThread(
ind_inv_fallback->Forward(fwd2), std::nullopt);
return Fragment(buffer->shape, {}, thd_b, dest_buffer_rep_extent,
std::nullopt)
->CondenseReplicateVar();
}
Layout ind_inv = layout_before_inv->Inverse();
PrimExpr indice_rep_extent = PrimExpr indice_rep_extent =
ind_inv->InputShape().back(); // this is the size of rep_b ind_inv->InputShape().back(); // this is the size of rep_b
PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment