Unverified Commit 407117e1 authored by Lei Wang, committed by GitHub

[Layout] Introduce a new layout inference mechanism (#699)



* Implement new free-stage layout inference.

* Fix bug

* Make replication upcasting and unnormalizable iterators safe.

* Better handling of updating with more replicas

* Remove unnecessary check.

* Fix compilation.

* Fix setup.py.

* Simplify development mode.

* Allow ParallelOp layout when there's already a compatible layout specified

* lint fix

* Add ProveFragmentContains function to validate thread access between small and large fragments

This function checks whether the threads accessing elements of a smaller fragment are a subset of those accessing a larger fragment, ensuring valid access during updates. The implementation derives thread indices, computes logical indices, and verifies the thread mappings (see the sketch after this list).

* Update dependencies in requirements files

* Remove 'thefuzz' from requirements-dev.txt
* Specify exact versions for 'torch' and add 'flash_attn' in requirements-test.txt

* Update CI workflow to use SHA256 hash for requirements file

* Update requirements and CI workflow for flash attention

* Removed specific version for 'torch' in requirements-test.txt
* Added installation of 'flash_attn==2.5.8' in CI workflow to ensure compatibility

* Refactor flash attention import handling in examples

* Removed availability checks for 'flash_attn' in multiple example scripts.
* Simplified import statements for 'flash_attn' to ensure consistent usage across examples.
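
For the ProveFragmentContains bullet above, a minimal sketch of the intended call pattern (illustrative only: `small` and `large` are assumed fragments with identical input shapes, and the pattern mirrors how the hunks below build symbolic indices with `InputPlaceholder` and bind their ranges before invoking the prover):

```cpp
// Sketch: decide whether an update from fragment `small` to fragment `large`
// is thread-safe, i.e. every thread holding an element of `small` also holds
// the corresponding element of `large`.
arith::Analyzer analyzer;
Array<PrimExpr> indices;
for (int i = 0; i < small->InputDim(); ++i) {
  auto x = InputPlaceholder(i);  // symbolic index for dimension i
  indices.push_back(x);
  analyzer.Bind(x, Range(0, small->InputShape()[i]));
}
if (!ProveFragmentContains(small, large, indices, indices, analyzer)) {
  // Thread sets are incompatible; callers reject the update, e.g. by
  // throwing LayoutConflictException.
}
```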

---------
Co-authored-by: Huanqi Cao <caohuanqi@deepseek.com>
parent 87aae294
@@ -17,6 +17,20 @@ namespace tl {
 
 using namespace tir;
 
+class LayoutConflictException : public std::exception {
+public:
+  const char *what() const noexcept override { return msg_.c_str(); }
+  LayoutConflictException(const std::string &msg) : msg_(msg) {}
+
+private:
+  std::string msg_;
+};
+
+bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
+                           Array<PrimExpr> small_frag_indices,
+                           Array<PrimExpr> large_frag_indices,
+                           arith::Analyzer &analyzer_);
+
 class ParallelOp;
 
 class ParallelLoopNestVisitor : public StmtExprVisitor {
@@ -36,6 +50,14 @@ public:
   ParallelOp(For root);
   LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
 
+  ParallelOp(const ParallelOp &other) : ParallelOp(other.root_) {
+    loop_layout_ = other.loop_layout_;
+    predicate_ = other.predicate_;
+  }
+
+  std::unique_ptr<Operator> Clone() const final {
+    return std::make_unique<ParallelOp>(*this);
+  }
+
   Fragment GetLoopLayout() const { return loop_layout_; }
   For GetRoot() const { return root_; }
   Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
......
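
`LayoutConflictException` above is deliberately recoverable: `InferInFreeMode` in `layout_inference.cc` (later in this diff) catches it to abandon one inference order and try another root. A self-contained sketch of that backtracking pattern, using a simplified stand-in for the real inference calls:

```cpp
#include <exception>
#include <iostream>
#include <string>

// Stand-in copy of tl::LayoutConflictException from the hunk above.
class LayoutConflictException : public std::exception {
public:
  const char *what() const noexcept override { return msg_.c_str(); }
  LayoutConflictException(const std::string &msg) : msg_(msg) {}

private:
  std::string msg_;
};

int main() {
  bool do_update = true;
  try {
    // Stand-in for RunInferStep/FinishInferQueue on one candidate root.
    throw LayoutConflictException("layout conflict for buffer B");
  } catch (const LayoutConflictException &e) {
    // This ordering failed; the real code rolls back and tries another root.
    do_update = false;
    std::cout << e.what() << "\n";
  }
  return do_update ? 0 : 1;
}
```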
@@ -12,6 +12,7 @@
 #include <tvm/tir/stmt_functor.h>
 
 #include "../layout/utils.h"
+#include "../op/parallel.h"
 #include "../transform/loop_partition.h"
 #include "tir/transforms/ir_utils.h"
 
@@ -287,7 +288,7 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
   if (level >= InferLevel::kStrict)
     return {};
   if (src.scope() == "local.fragment" && dst.scope() == "local.fragment" &&
-      T.layout_map.count(src) && !T.layout_map.count(dst)) {
+      T.layout_map.count(src)) {
     auto src_layout = T.layout_map[src].as<Fragment>().value();
 
     PrimExpr indice_rep_extent = src->shape[dim];
@@ -310,7 +311,46 @@ LayoutMap ReduceOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
         Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
             ->CondenseReplicateVar()
             ->BindThreadRange(T.thread_bounds);
-    return {{dst, dst_layout}};
+    if (!T.layout_map.count(dst))
+      return {{dst, dst_layout}};
+    else {
+      // Check if the computed layout is compatible with the existing one: the
+      // existing layout must strictly contain the computed layout.
+      auto orig_dst_layout =
+          T.layout_map.Get(dst).value().as<Fragment>().value();
+      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
+      Array<PrimExpr> indices;
+      indices.reserve(dst_layout->InputDim());
+      arith::Analyzer inner_analyzer;
+      for (int i = 0; i < dst_layout->InputDim(); ++i) {
+        auto x = InputPlaceholder(i);
+        indices.push_back(x);
+        // Should be literal - literal = 0; any analyzer will work.
+        ICHECK(is_zero(inner_analyzer.Simplify(
+            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
+        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
+      }
+      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
+      ICHECK(as_const_int(src_layout->ReplicateExtent()));
+      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
+      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
+      if (dst_rep < src_rep ||
+          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
+                                 inner_analyzer)) {
+        std::ostringstream oss;
+        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
+            << src << "\nLHS = " << src_layout->DebugOutput()
+            << "\nRHS = " << orig_dst_layout->DebugOutput()
+            << "\nYou may need to use shared memory to transform the layout";
+        throw LayoutConflictException(oss.str());
+      }
+      if (dst_rep > src_rep) {
+        return {{dst, dst_layout}};
+      }
+    }
   }
   return {};
 }
......
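
Restating the new dst-layout policy above as a small, self-contained decision function (a sketch only: `existing_contains_computed` stands in for the `ProveFragmentContains` proof, and the replication extents are plain integers here):

```cpp
enum class Action { AdoptComputed, KeepExisting, Conflict };

// Mirror of the control flow in ReduceOp::InferLayout above, with the
// fragment machinery reduced to two integers and one boolean.
Action DecideDstLayout(bool has_existing, long dst_rep, long src_rep,
                       bool existing_contains_computed) {
  if (!has_existing)
    return Action::AdoptComputed;  // no prior constraint on dst
  if (dst_rep < src_rep || !existing_contains_computed)
    return Action::Conflict;       // caller throws LayoutConflictException
  if (dst_rep > src_rep)
    return Action::AdoptComputed;  // needs more replication: widen the layout
  return Action::KeepExisting;     // the existing layout already suffices
}
```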
@@ -21,6 +21,10 @@ public:
   LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
   static const Op &Get();
 
+  std::unique_ptr<Operator> Clone() const final {
+    return std::make_unique<ReduceOp>(*this);
+  }
+
 private:
   tir::Buffer src, dst;
   int dim;
@@ -45,6 +49,10 @@ public:
   LayoutMap InferLayout(const LayoutInferArgs &T, InferLevel level) final;
   static const Op &Get();
 
+  std::unique_ptr<Operator> Clone() const final {
+    return std::make_unique<CumSumOp>(*this);
+  }
+
 private:
   tir::Buffer src, dst;
   int dim;
......
#ifndef TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
#define TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
#include <unordered_map>
#include <vector>
namespace tvm {
namespace tl {
template <typename T> class UnionFind {
public:
void MakeSet(const T &x) {
if (parent_.find(x) == parent_.end()) {
parent_[x] = x;
rank_[x] = 0;
}
}
T Find(const T &x) {
if (parent_[x] != x) {
parent_[x] = Find(parent_[x]); // Path compression
}
return parent_[x];
}
void Union(const T &x, const T &y) {
T x_root = Find(x);
T y_root = Find(y);
if (x_root == y_root)
return;
// Union by rank
if (rank_[x_root] < rank_[y_root]) {
parent_[x_root] = y_root;
} else if (rank_[x_root] > rank_[y_root]) {
parent_[y_root] = x_root;
} else {
parent_[y_root] = x_root;
rank_[x_root]++;
}
}
private:
std::unordered_map<T, T> parent_;
std::unordered_map<T, int> rank_;
};
} // namespace tl
} // namespace tvm
#endif // TVM_TL_TRANSFORM_COMMON_UNION_FIND_H_
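
As a self-contained usage sketch of this header (the operator indices are made up; `InferInFreeMode` below applies the same pattern to group operators that share a buffer into connected components):

```cpp
#include <cassert>

#include "common/union_find.h"

int main() {
  tvm::tl::UnionFind<int> uf;
  for (int i = 0; i < 4; i++)
    uf.MakeSet(i);
  // Suppose ops 0 and 2 access the same buffer, as do ops 1 and 3.
  uf.Union(0, 2);
  uf.Union(1, 3);
  assert(uf.Find(0) == uf.Find(2));  // same component
  assert(uf.Find(1) == uf.Find(3));  // same component
  assert(uf.Find(0) != uf.Find(1));  // distinct components
  return 0;
}
```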
@@ -13,11 +13,13 @@
 #include <queue>
 
+#include "../layout/utils.h"
 #include "../op/parallel.h"
 #include "arith/ir_mutator_with_analyzer.h"
 #include "arith/ir_visitor_with_analyzer.h"
 #include "common/loop_fusion_utils.h"
 #include "common/loop_parallel_transform_utils.h"
+#include "common/union_find.h"
 #include "loop_partition.h"
 #include "loop_vectorize.h"
 #include "runtime/thread_storage_scope.h"
@@ -60,45 +62,13 @@ public:
   BufferUseDefCollector(bool skip_thread_partition)
       : skip_thread_partition_(skip_thread_partition) {}
 
-  LayoutInferenceResult Run() {
-    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
-    // same size
-    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
-        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
-           "length.";
-    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
-        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
-           "length.";
-
-    // If needed, you can also check that annotated_layout_map_ is not empty, or
-    // anything else relevant to your setup.
-
-    // Copy the annotated layout map to a local variable
-    Map<Buffer, Layout> layout_map = annotated_layout_map_;
-    Map<Buffer, Layout> strict_layout_map;
-    int num_infer = infer_list_.size();
-
-    // Prepare BFS queue for iterative inference
-    std::queue<int> q;
-    std::vector<bool> in_queue(num_infer, true);
-    for (int i = 0; i < num_infer; i++) {
-      // Check that each infer_list_ entry is valid
-      ICHECK(infer_list_[i] != nullptr)
-          << "infer_list_[" << i
-          << "] is null. The inference object is not allocated properly.";
-      // Check that each thread_var_vec_ entry is defined
-      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
-        thread_var_vec_[i] = thread_var_;
-      }
-      q.push(i);
-    }
-
-    auto run_infer_step = [&](int cur_infer_id, InferLevel level,
-                              bool update_queue) {
+  void RunInferStep(int cur_infer_id, InferLevel level, bool update_queue,
+                    LayoutMap &layout_map, const LayoutMap &strict_layout_map,
+                    std::queue<int> &q, std::vector<bool> &in_queue) {
+    auto num_infer = infer_list_.size();
     // Range check for cur_infer_id
-      ICHECK_GE(cur_infer_id, 0)
-          << "cur_infer_id is negative, which is invalid.";
+    ICHECK_GE(cur_infer_id, 0) << "cur_infer_id is negative, which is invalid.";
     ICHECK_LT(cur_infer_id, num_infer)
         << "cur_infer_id " << cur_infer_id << " is out of range, must be < "
         << num_infer << ".";
@@ -109,8 +79,8 @@ public:
     auto iter_var = thread_var_vec_[cur_infer_id];
     auto thread_bounds = thread_bounds_vec_[cur_infer_id];
 
     // Double-check that 'next' is valid
-    ICHECK(next != nullptr) << "infer_list_[" << cur_infer_id
-                            << "] is null inside run_infer_step.";
+    ICHECK(next != nullptr)
+        << "infer_list_[" << cur_infer_id << "] is null inside run_infer_step.";
 
     // Check iter_var->dom and dom->extent
     ICHECK(iter_var.defined())
@@ -137,33 +107,37 @@ public:
     ICHECK(layout.defined()) << "InferLayout returned an undefined layout.";
 
     if (layout_map.count(buffer)) {
-      // If replicate size of this buffer is greater than the old one
+      // If the new layout contains the old one, update the map
       if (buffer.scope() == "local.fragment" &&
-          level != InferLevel::kStrict) {
-        const FragmentNode *dst_layout = layout.as<FragmentNode>();
-        const FragmentNode *src_layout =
-            layout_map[buffer].as<FragmentNode>();
-        if (as_const_int(dst_layout->ReplicateExtent()) &&
-            as_const_int(src_layout->ReplicateExtent()) &&
-            (*as_const_int(dst_layout->ReplicateExtent()) >
-             *as_const_int(src_layout->ReplicateExtent()))) {
-          // update map
+          level != InferLevel::kStrict && !strict_layout_map.count(buffer)) {
+        // Actually this test has been done in ParallelOp::InferLayout
+        // already. Just do it again to avoid missing implementations in other
+        // `Operator`s.
+        auto dst_layout = layout.as<Fragment>().value();
+        auto src_layout = layout_map[buffer].as<Fragment>().value();
+        ICHECK(dst_layout->InputDim() == src_layout->InputDim());
+        Array<PrimExpr> indices;
+        indices.reserve(dst_layout->InputDim());
+        arith::Analyzer inner_analyzer;
+        for (int i = 0; i < dst_layout->InputDim(); ++i) {
+          auto x = InputPlaceholder(i);
+          indices.push_back(x);
+          // Should be literal - literal = 0; any analyzer will work.
+          ICHECK(is_zero(inner_analyzer.Simplify(
+              dst_layout->InputShape()[i] - src_layout->InputShape()[i])));
+          inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
+        }
+        if (ProveFragmentContains(src_layout, dst_layout, indices, indices,
+                                  inner_analyzer)) {
           layout_map.Set(buffer, layout);
           continue;
         }
       }
       // If already in map, ensure they are structurally equal
+      // (zhengju) We cannot modify the strict layout map when the current
+      // level is not strict. This check should only be done under certain
+      // conditions, since the strict layout map is not updated in the code
+      // above when the current level is not strict.
+      if (level == InferLevel::kStrict ||
+          !strict_layout_map.count(buffer)) {
        ICHECK(StructuralEqual()(layout, layout_map[buffer]))
            << "Get different layout for " << buffer
            << "\n current layout: " << layout->DebugOutput()
            << "\n previous layout: " << layout_map[buffer]->DebugOutput();
+      }
     } else {
       // Otherwise, update map
       layout_map.Set(buffer, layout);
@@ -181,8 +155,8 @@ public:
       // Push back into BFS queue
       for (int idx : use_list_[buffer]) {
-        ICHECK_GE(idx, 0) << "Index in use_list_ for buffer " << buffer
-                          << " is negative.";
+        ICHECK_GE(idx, 0)
+            << "Index in use_list_ for buffer " << buffer << " is negative.";
        ICHECK_LT(idx, num_infer)
            << "Index in use_list_ for buffer " << buffer
            << " out of range: " << idx << " >= " << num_infer << ".";
@@ -196,7 +170,10 @@ public:
       }
     }
-  };
+  }
 
-  auto finish_infer_queue = [&]() {
+  void FinishInferQueue(InferLevel level, LayoutMap &layout_map,
+                        const LayoutMap &strict_layout_map, std::queue<int> &q,
+                        std::vector<bool> &in_queue) {
+    auto num_infer = infer_list_.size();
     while (!q.empty()) {
       int cur_infer_id = q.front();
       q.pop();
@@ -205,13 +182,49 @@ public:
       ICHECK_LT(cur_infer_id, num_infer);
       in_queue[cur_infer_id] = false;
-      run_infer_step(cur_infer_id, InferLevel::kCommon, true);
+      RunInferStep(cur_infer_id, level, true, layout_map, strict_layout_map, q,
+                   in_queue);
     }
-  };
+  }
 
+  LayoutInferenceResult Run() {
+    // Basic consistency check: infer_list_ and thread_var_vec_ should have the
+    // same size
+    ICHECK_EQ(infer_list_.size(), thread_var_vec_.size())
+        << "Size mismatch: infer_list_ and thread_var_vec_ must match in "
+           "length.";
+    ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
+        << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
+           "length.";
+
+    // If needed, you can also check that annotated_layout_map_ is not empty, or
+    // anything else relevant to your setup.
+
+    // Copy the annotated layout map to a local variable
+    Map<Buffer, Layout> layout_map = annotated_layout_map_;
+    Map<Buffer, Layout> strict_layout_map;
+    int num_infer = infer_list_.size();
+
+    // Prepare BFS queue for iterative inference
+    std::queue<int> q;
+    std::vector<bool> in_queue(num_infer, true);
+    for (int i = 0; i < num_infer; i++) {
+      // Check that each infer_list_ entry is valid
+      ICHECK(infer_list_[i] != nullptr)
+          << "infer_list_[" << i
+          << "] is null. The inference object is not allocated properly.";
+      // Check that each thread_var_vec_ entry is defined
+      if (!thread_var_vec_[i].defined() && skip_thread_partition_) {
+        thread_var_vec_[i] = thread_var_;
+      }
+      q.push(i);
+    }
+
     // step 1: infer strict layout
     for (int i = 0; i < num_infer; i++) {
-      run_infer_step(i, InferLevel::kStrict, false);
+      RunInferStep(i, InferLevel::kStrict, false, layout_map, strict_layout_map,
+                   q, in_queue);
     }
 
     for (const auto &[buffer, layout] : layout_map) {
@@ -219,13 +232,12 @@ public:
     }
 
     // step 2: infer common layout with BFS
-    finish_infer_queue();
+    FinishInferQueue(InferLevel::kCommon, layout_map, strict_layout_map, q,
+                     in_queue);
 
     // step 3: relax constraints to free and re-run
-    for (int i = 0; i < num_infer; i++) {
-      run_infer_step(i, InferLevel::kFree, true);
-      finish_infer_queue();
-    }
+    InferInFreeMode(layout_map, strict_layout_map);
 
     // Check that all local.fragment buffers have inferred layouts
     for (const auto &[buffer, _] : use_list_) {
       if (buffer.scope() == "local.fragment") {
@@ -291,6 +303,7 @@ private:
         addToUseList(buffer.value());
       }
     }
+    infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
     infer_list_.push_back(std::move(p));
     thread_var_vec_.push_back(thread_var_);
     if (analyzer_.const_int_bound.IsBound(thread_var_->var)) {
@@ -309,9 +322,14 @@ private:
   Optional<Buffer> getBufferFromAccessPtr(const PrimExpr &expr) {
     auto call = expr.as<CallNode>();
-    if (call && call->op.same_as(builtin::tvm_access_ptr())) {
+    if (!call) {
+      return std::nullopt;
+    }
+    if (call->op.same_as(builtin::tvm_access_ptr())) {
       auto var = call->args[1].as<Var>().value();
       return buffer_data_to_buffer_[var];
+    } else if (call->op.same_as(RegionOp::Get())) {
+      return call->args[0].as<BufferLoadNode>()->buffer;
     }
     return std::nullopt;
   }
@@ -330,6 +348,7 @@ private:
     for (const auto &[buffer, _] : infer->GetIndiceMap()) {
       addToUseList(buffer);
     }
+    infer_list_stmt_.push_back(GetRef<ObjectRef>(op));
    infer_list_.push_back(std::move(infer));
    thread_var_vec_.push_back(thread_var_);
    if (thread_var_.defined() &&
@@ -379,6 +398,7 @@ private:
   }
 
   Map<Var, Buffer> buffer_data_to_buffer_;
+  std::vector<ObjectRef> infer_list_stmt_;
   std::vector<std::unique_ptr<Operator>> infer_list_;
   std::unordered_map<Buffer, std::vector<int>, ObjectPtrHash, ObjectPtrEqual>
       use_list_;
@@ -391,6 +411,122 @@ private:
   Target target_;
   LayoutMap annotated_layout_map_;
   bool skip_thread_partition_{false};
+
+  std::vector<std::unique_ptr<Operator>> BackupInferList() {
+    std::vector<std::unique_ptr<Operator>> back_infer_list;
+    back_infer_list.reserve(infer_list_.size());
+    for (auto &&p : infer_list_) {
+      back_infer_list.push_back(p->Clone());
+    }
+    return back_infer_list;
+  }
+
+  void InferInFreeMode(LayoutMap &layout_map,
+                       const LayoutMap &strict_layout_map) {
+    // Group operators into connected components
+    UnionFind<int> uf;
+    for (int i = 0; i < infer_list_.size(); i++) {
+      uf.MakeSet(i);
+    }
+    for (const auto &[buffer, infer_indices] : use_list_) {
+      if (infer_indices.empty())
+        continue;
+      // Union all infer_list_ indices that share the same buffer
+      int first_idx = infer_indices[0];
+      for (size_t i = 1; i < infer_indices.size(); i++) {
+        uf.Union(first_idx, infer_indices[i]);
+      }
+    }
+    std::unordered_map<int, std::vector<int>> components;
+    for (int i = 0; i < infer_list_.size(); i++) {
+      int root = uf.Find(i);
+      components[root].push_back(i);
+    }
+    std::unordered_map<int, std::vector<Buffer>> components_buffers;
+    for (const auto &[buffer, infer_indices] : use_list_) {
+      int root = uf.Find(infer_indices[0]);
+      components_buffers[root].push_back(buffer);
+    }
+
+    // For each component, try each op as the root and keep the least
+    // replicated result
+    std::queue<int> q;
+    std::vector<bool> in_queue(infer_list_.size(), false);
+    for (auto &&[root, members] : components) {
+      decltype(infer_list_) best_infer_list;
+      LayoutMap best_layout_map;
+      int64_t min_reg_num = INT64_MAX;
+      for (int attempt_infer_root : members) {
+        // Back up the infer_list_ class member
+        auto back_infer_list = BackupInferList();
+        // Create a temporary layout_map; a new handle, so it copies on write
+        LayoutMap tmp_layout_map = layout_map;
+        // Infer from attempt_infer_root in free mode
+        bool do_update = true;
+        try {
+          RunInferStep(attempt_infer_root, InferLevel::kFree, true,
+                       tmp_layout_map, strict_layout_map, q, in_queue);
+          FinishInferQueue(InferLevel::kFree, tmp_layout_map, strict_layout_map,
+                           q, in_queue);
+          // Silly workaround: we have no clue whether a single root will
+          // iterate over the entire component, since the InferLayout
+          // implementations have complicated conditioning inside and we know
+          // nothing about it. That would constantly result in incomplete
+          // layouts for buffers in this component. Instead of trying all
+          // combinations of root selection order, we simply go through all
+          // other loops in order after the first search from
+          // attempt_infer_root.
+          for (int other_infer_root : members) {
+            if (other_infer_root != attempt_infer_root) {
+              RunInferStep(other_infer_root, InferLevel::kFree, true,
+                           tmp_layout_map, strict_layout_map, q, in_queue);
+              // Must also be kFree here to avoid conflicts.
+              FinishInferQueue(InferLevel::kFree, tmp_layout_map,
+                               strict_layout_map, q, in_queue);
+            }
+          }
+        } catch (const LayoutConflictException &e) {
+          // This order fails; try others
+          do_update = false;
+        } catch (const NormalizeIterException &e) {
+          // This order encounters iterators that are not normalizable
+          // (e.g. i * 576 % 2048); try others
+          do_update = false;
+        }
+        if (do_update) {
+          // Compute the total number of registers
+          int64_t reg_num = 0;
+          for (auto &&[buffer, layout] : tmp_layout_map) {
+            if (auto frag = layout.as<Fragment>()) {
+              int64_t frag_reg_num = 1;
+              for (auto i : frag.value()->OutputShape()) {
+                auto pci = as_const_int(i);
+                ICHECK(pci != nullptr);
+                frag_reg_num *= *pci;
+              }
+              reg_num += frag_reg_num;
+            }
+          }
+          // If it is any better, update the best_* storage
+          if (reg_num < min_reg_num) {
+            best_infer_list = std::move(infer_list_);
+            best_layout_map = tmp_layout_map;
+            min_reg_num = reg_num;
+          }
+        }
+        // Restore the stateful infer_list_ and move on to the next attempt
+        infer_list_ = std::move(back_infer_list);
+      }
+      if (min_reg_num < INT64_MAX) {
+        // Now apply the best plan for this component
+        infer_list_ = std::move(best_infer_list);
+        layout_map = best_layout_map;
+      }
+    }
+  }
 };
 
 class LayoutInferencer : public IRMutatorWithAnalyzer {
......