Commit 97d63fab authored by Lei Wang's avatar Lei Wang Committed by LeiWang1999
Browse files

[Layout] Enhance layout inference pass (#427)

* [Enhancement] Improve layout inference in Copy operation (#426)

* Updated the Copy operation to infer layouts at multiple levels (kCommon, kStrict, kFree) for enhanced flexibility in layout optimization.
* Added detailed documentation for layout inference levels in ParallelOp, clarifying their purposes and use cases.
* Refactored layout inference logic to accommodate new levels, improving overall robustness and performance in parallel operations.

* lint fix
parent afa74f4e
......@@ -160,9 +160,13 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(fused_loop);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
for (auto level : levels) {
par_op->InferLayout(
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap},
InferLevel::kFree);
{T.target, T.thread_bounds, T.layout_map, T.buffer_remap}, level);
}
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
vectorized_thread_loop = VectorizeLoop(thread_loop);
......
......@@ -122,6 +122,24 @@ bool ParallelOp::IsCommonAccessIndice(const Buffer &buffer) const {
return StructuralEqual()(indice_map_[buffer], common_indice);
}
/*! \brief Infer the layout for parallel operations based on different inference
* levels
*
* The inference level controls how aggressively we try to infer and optimize
* layouts:
* - kStrict (2): Most conservative level. Only allows explicitly defined
* layouts. Returns empty layout map if loop_layout_ is not already defined.
* Used when exact layout control is required.
*
* - kCommon (1): Intermediate level between strict and free.
* Allows common layout patterns while maintaining some
* constraints.
*
* - kFree (0): Most permissive level. Allows maximum optimization freedom.
* Will attempt layout inference even without source buffers.
* Can generate new layouts based on vectorization and thread
* bounds. Used when maximum performance optimization is desired.
*/
LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
if (loop_layout_.defined())
return {};
......@@ -163,6 +181,8 @@ LayoutMap ParallelOp::InferLayout(const LayoutInferArgs &T, InferLevel level) {
};
if (source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(source_buffer);
} else if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
} else if (level == InferLevel::kFree) {
if (read_source_buffer.defined()) {
loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer);
......
......@@ -35,6 +35,7 @@ extern "C" const char* get_last_error() {{
extern "C" int init() {{
error_buf[0] = '\\0';
{0}
return 0;
}}
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment