Unverified commit f2858fa1, authored by Lei Wang and committed by GitHub
Browse files

[Enhancement] Refactor inflight computing to support dynamic pipeline extents (#1399)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [Refactor] Update TVM subproject and enhance pipeline loop handling

- Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0.
- Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management.
- Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts.
- Simplified access index calculations and strengthened analyzer constraints for loop bounds.

* [Cleanup] Remove license block and unused includes from inject_pipeline.cc

- Eliminated the Apache license block from the top of the file to streamline the code.
- Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies.

* [Refactor] Enhance transformation pipeline and test execution

- Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization.
- Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation.
parent bc084aa4
Subproject commit 3a32b763e9d8393b14e4d0f824b2846f70041bc1 Subproject commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*! /*!
* \file inject_software_pipeline.cc * \file inject_software_pipeline.cc
* \brief Transform annotated loops into pipelined one that parallelize * \brief Transform annotated loops into pipelined one that parallelize
...@@ -79,6 +60,8 @@ struct PipelineAnnotation { ...@@ -79,6 +60,8 @@ struct PipelineAnnotation {
int stage; int stage;
int order; int order;
bool async; bool async;
// Index of the statement in the original loop body order (SeqStmt order)
int original_idx = -1;
}; };
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation, using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
...@@ -304,15 +287,17 @@ public: ...@@ -304,15 +287,17 @@ public:
} }
// Step 2: Emit the pipeline prologue, body and epilogue. // Step 2: Emit the pipeline prologue, body and epilogue.
Stmt prologue = EmitImpl(pipeline_loop_->min, Stmt prologue =
pipeline_loop_->min + max_stage_, true, true); EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
Stmt body = true, false);
EmitImpl(pipeline_loop_->min + max_stage_, Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false); pipeline_loop_->min + pipeline_loop_->extent, false,
Stmt epilogue = EmitImpl( false, false);
pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true); Stmt epilogue =
EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
true, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue}); SeqStmt stmt = SeqStmt({prologue, body, epilogue});
// Step 3: Make a new block that contains new buffer allocations after // Step 3: Make a new block that contains new buffer allocations after
...@@ -515,12 +500,16 @@ private: ...@@ -515,12 +500,16 @@ private:
// A symbolic expression representing the index the latest async operation // A symbolic expression representing the index the latest async operation
// associated with this stage has written into, at the "current" iteration. // associated with this stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head; Optional<PrimExpr> producer_head;
// the commit block's predicate
PrimExpr commit_predicate{nullptr};
}; };
/*! Structure holding intermediate information for pipeline loop rewriting. */ /*! Structure holding intermediate information for pipeline loop rewriting. */
struct RewrittenBlockInfo { struct RewrittenBlockInfo {
int stage; int stage;
int order; int order;
PrimExpr start;
PrimExpr end;
PrimExpr predicate; PrimExpr predicate;
Block block; Block block;
PrimExpr access_index; PrimExpr access_index;
...@@ -528,56 +517,103 @@ private: ...@@ -528,56 +517,103 @@ private:
}; };
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks, void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
std::map<int, AsyncStateLocal> *async_states_local) { std::map<int, AsyncStateLocal> *async_states_local,
bool is_epilogue = false) {
// Precompute which orders are present in this emit, and their access_index
std::unordered_map<int, PrimExpr> order_to_access_index;
std::unordered_set<int> present_orders;
for (const auto &nb : new_blocks) {
order_to_access_index[nb.order] = nb.access_index;
present_orders.insert(nb.order);
}
for (size_t i = 0; i < new_blocks.size(); ++i) { for (size_t i = 0; i < new_blocks.size(); ++i) {
// 1. Find the unique async producer stage
int producer_stage_idx = -1; int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) { for (const auto &read_region : new_blocks[i].block->reads) {
for (const auto &[stage, state] : async_states) { for (const auto &[stage, state] : async_states) {
if (stage <= new_blocks[i].stage && if (stage <= new_blocks[i].stage &&
state.writes(read_region->buffer)) { state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was // Currently only a single async stage dependency is supported
// asynchronously written
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage) ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported"; << "A dependency on multiple async stages is not supported";
producer_stage_idx = stage; producer_stage_idx = stage;
} }
} }
} }
if (producer_stage_idx == -1) if (producer_stage_idx == -1) {
// This block does not depend on any async producer
continue; continue;
}
const auto &state = async_states[producer_stage_idx]; const auto &state = async_states[producer_stage_idx];
auto &dep_local_state = (*async_states_local)[producer_stage_idx]; auto &dep_local_state = (*async_states_local)[producer_stage_idx];
PrimExpr in_flight_cnt = 0;
for (const auto &group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
PrimExpr producer_head;
if (dep_local_state.producer_head.defined()) {
producer_head = dep_local_state.producer_head.value();
// if the group is after the wait point, minus by 1
if (group.front() > new_blocks[i].order)
producer_head -= 1;
} else {
producer_head = state.producer_head;
}
in_flight_cnt += producer_head - consumer_head;
}
// We can relax the in-flight-count by the number of independent commit. // 2. Use buffer_to_commit_group_ to find all actually dependent commit
// groups
std::unordered_set<int> dependent_groups; std::unordered_set<int> dependent_groups;
for (const auto &read_region : new_blocks[i].block->reads) { for (const auto &read_region : new_blocks[i].block->reads) {
if (state.buffer_to_commit_group_.count(read_region->buffer.get())) auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
dependent_groups.insert( if (it != state.buffer_to_commit_group_.end()) {
state.buffer_to_commit_group_.at(read_region->buffer.get())); dependent_groups.insert(it->second);
}
} }
for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
if (dependent_groups.count(i) == 0) // If there is no dependent commit group, no wait needs to be inserted
in_flight_cnt += 1; if (dependent_groups.empty()) {
else continue;
break; // stop relaxing }
// 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
PrimExpr t_consumer = new_blocks[i].access_index;
PrimExpr wait_expr = make_zero(t_consumer.dtype());
PrimExpr current_head = dep_local_state.producer_head.defined()
? dep_local_state.producer_head.value()
: state.producer_head;
int consumer_order = new_blocks[i].order;
for (int g : dependent_groups) {
const auto &group = state.commit_groups[g];
if (group.empty())
continue;
int commit_order = group.back();
bool commit_present = present_orders.count(commit_order) > 0;
PrimExpr committed_before;
if (commit_present && commit_order <= consumer_order) {
// Commit point is in this iteration and earlier than the current
// consumer; this iteration's head is visible
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
// it means the commit block is not executed in this iteration
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = order_to_access_index.at(commit_order);
}
} else {
// Commit point is later than the current consumer or not in this
// iteration; only the previous iteration's head is visible
if (dep_local_state.producer_head.defined()) {
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = current_head - 1;
}
}
}
wait_expr = analyzer_.Simplify(committed_before - t_consumer);
} }
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back( wait_expr = analyzer_.Simplify(wait_expr);
{static_cast<int>(i), in_flight_cnt}); dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
} }
} }
...@@ -630,7 +666,7 @@ private: ...@@ -630,7 +666,7 @@ private:
* \return The result loop. * \return The result loop.
*/ */
Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop, Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
bool need_bound_check) { bool need_bound_check, bool is_epilogue = false) {
PrimExpr new_loop_var; PrimExpr new_loop_var;
PrimExpr extent = end - start; PrimExpr extent = end - start;
auto make_nop = []() { auto make_nop = []() {
...@@ -642,7 +678,20 @@ private: ...@@ -642,7 +678,20 @@ private:
new_loop_var = start; // use constants as the loop var for unit loops new_loop_var = start; // use constants as the loop var for unit loops
} else { } else {
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix(""); new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end)); // Bind the iteration domain [start, end) to strengthen analyzer facts.
analyzer_.Bind(Downcast<Var>(new_loop_var),
Range::FromMinExtent(start, end - start));
}
// Keep the bound constraints active for all analysis below.
// Only meaningful when the loop var is symbolic (non-unit loop).
std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
if (!is_unit_loop) {
Var loop_iter = Downcast<Var>(new_loop_var);
ctx_lb_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
ctx_ub_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
} }
std::vector<RewrittenBlockInfo> new_blocks; std::vector<RewrittenBlockInfo> new_blocks;
...@@ -653,15 +702,14 @@ private: ...@@ -653,15 +702,14 @@ private:
for (const Block &block : ordered_stmts_) { for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage; int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order; int order = pipeline_info_.at(block).order;
PrimExpr inbound = Bool(true); PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage; PrimExpr skewed_loop_var = new_loop_var - stage;
if (need_bound_check) if (need_bound_check)
inbound = inbound = And(
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && pipeline_loop_->min <= skewed_loop_var,
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));
if (analyzer_.CanProve(!inbound)) {
continue;
}
Block new_block = Downcast<Block>( Block new_block = Downcast<Block>(
PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_, PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
pipeline_loop_, max_stage_ != 1)(block)); pipeline_loop_, max_stage_ != 1)(block));
...@@ -674,6 +722,8 @@ private: ...@@ -674,6 +722,8 @@ private:
PrimExpr normalized_access_index = PrimExpr normalized_access_index =
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta; is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
normalized_access_index = analyzer_.Simplify(normalized_access_index);
// Adjust the block predicate and the body according to the final loop // Adjust the block predicate and the body according to the final loop
// bound // bound
// [pipeline_loop_->min, extent). // [pipeline_loop_->min, extent).
...@@ -701,17 +751,18 @@ private: ...@@ -701,17 +751,18 @@ private:
if (pipeline_info_[block].async) { if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage]; auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index; local_state.producer_head = normalized_access_index;
local_state.commit_predicate = inbound;
BlockNode *n = new_block.CopyOnWrite(); BlockNode *n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
1, n->body); 1, n->body);
} }
new_blocks.push_back({stage, order, inbound, new_block, new_blocks.push_back({stage, order, start, end, inbound, new_block,
normalized_access_index, normalized_access_index,
pipeline_info_[block].async}); pipeline_info_[block].async});
} }
PopulateWaitCounts(new_blocks, &async_states_local); PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local); auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
...@@ -1008,7 +1059,8 @@ private: ...@@ -1008,7 +1059,8 @@ private:
pipeline_async_stages.find(stage) != pipeline_async_stages.end(); pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{ PipelineAnnotation stage_order{
stage, stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async}; /*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
/*original_idx=*/static_cast<int>(i)};
pipeline_info.emplace(original_order[i], stage_order); pipeline_info.emplace(original_order[i], stage_order);
} }
......
...@@ -10,6 +10,7 @@ def _check(original, transformed): ...@@ -10,6 +10,7 @@ def _check(original, transformed):
mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod) mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod) mod = tl.transform.LowerOpaqueBlock()(mod)
mod = tl.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
......
...@@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: ...@@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectFenceProxy()(mod) mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod) mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod) mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer # ConfigIndexBitwidth must be applied after FlattenBuffer
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment