Commit 18889821 authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Enhancement] Add strict layout map for improved buffer layout inference (#594)

- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process.
- Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments.
- Refactored the layout inference steps to copy the layouts produced by the strict-level pass into the new `strict_layout_map` before the common and free passes run, cleanly separating layout handling by inference level.
parent a6b52c52
...@@ -225,6 +225,7 @@ public: ...@@ -225,6 +225,7 @@ public:
// Copy the annotated layout map to local variable // Copy the annotated layout map to local variable
Map<Buffer, Layout> layout_map = annotated_layout_map_; Map<Buffer, Layout> layout_map = annotated_layout_map_;
Map<Buffer, Layout> strict_layout_map;
int num_infer = infer_list_.size(); int num_infer = infer_list_.size();
// Prepare BFS queue for iterative inference // Prepare BFS queue for iterative inference
...@@ -242,6 +243,7 @@ public: ...@@ -242,6 +243,7 @@ public:
} }
q.push(i); q.push(i);
} }
auto run_infer_step = [&](int cur_infer_id, InferLevel level, auto run_infer_step = [&](int cur_infer_id, InferLevel level,
bool update_queue) { bool update_queue) {
// Range check for cur_infer_id // Range check for cur_infer_id
...@@ -287,7 +289,8 @@ public: ...@@ -287,7 +289,8 @@ public:
if (layout_map.count(buffer)) { if (layout_map.count(buffer)) {
// If replicate size of this buffer is greater than the old one // If replicate size of this buffer is greater than the old one
if (buffer.scope() == "local.fragment" && if (buffer.scope() == "local.fragment" &&
level != InferLevel::kStrict) { level != InferLevel::kStrict &&
!strict_layout_map.count(buffer)) {
const FragmentNode *dst_layout = layout.as<Fragment>().get(); const FragmentNode *dst_layout = layout.as<Fragment>().get();
const FragmentNode *src_layout = const FragmentNode *src_layout =
layout_map[buffer].as<Fragment>().get(); layout_map[buffer].as<Fragment>().get();
...@@ -355,6 +358,10 @@ public: ...@@ -355,6 +358,10 @@ public:
run_infer_step(i, InferLevel::kStrict, false); run_infer_step(i, InferLevel::kStrict, false);
} }
for (const auto &[buffer, layout] : layout_map) {
strict_layout_map.Set(buffer, layout);
}
// step 2: infer common layout with BFS // step 2: infer common layout with BFS
finish_infer_queue(); finish_infer_queue();
...@@ -363,7 +370,6 @@ public: ...@@ -363,7 +370,6 @@ public:
run_infer_step(i, InferLevel::kFree, true); run_infer_step(i, InferLevel::kFree, true);
finish_infer_queue(); finish_infer_queue();
} }
// Check that all local.fragment buffers have inferred layouts // Check that all local.fragment buffers have inferred layouts
for (const auto &[buffer, _] : use_list_) { for (const auto &[buffer, _] : use_list_) {
if (buffer.scope() == "local.fragment") { if (buffer.scope() == "local.fragment") {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment