Unverified Commit f2858fa1 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Refactor in-flight computing to support dynamic pipeline extents (#1399)

* [Build] Update CMake configuration for tilelang_cython_wrapper installation

- Adjusted output directories for the tilelang_cython_wrapper to ensure that development builds place the extension in build/lib.
- Updated installation paths to place the extension in tilelang/lib within the wheel, improving organization and avoiding potential conflicts with other modules.
- Modified the internal library path exposure in env.py to prevent shadowing of common module names, enhancing compatibility and usability in user projects.

* [Build] Standardize output directories for tilelang libraries

- Set output directories for both tilelang and tilelang_module libraries to "${CMAKE_BINARY_DIR}/lib" for consistency in development builds.
- This change enhances organization and ensures that all build artifacts are located in a unified directory structure.

* [Refactor] Update TVM subproject and enhance pipeline loop handling

- Updated the TVM subproject to commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0.
- Added new fields to `PipelineAnnotation` and `RewrittenBlockInfo` structures to track original statement indices and improve async state management.
- Refactored `EmitImpl` and `PopulateWaitCounts` methods to enhance clarity and functionality, including better handling of commit groups and wait counts.
- Simplified access index calculations and strengthened analyzer constraints for loop bounds.

* [Cleanup] Remove license block and unused includes from inject_pipeline.cc

- Eliminated the Apache license block from the top of the file to streamline the code.
- Removed unused include directives for memory and stringstream to enhance code clarity and reduce unnecessary dependencies.

* [Refactor] Enhance transformation pipeline and test execution

- Added an additional Simplify transformation in the InjectSoftwarePipeline to improve optimization.
- Updated the test file to call `test_trival_pipeline()` directly, commenting out the previous main execution for better test isolation.
parent bc084aa4
Subproject commit 3a32b763e9d8393b14e4d0f824b2846f70041bc1
Subproject commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file inject_software_pipeline.cc
 * \brief Transform annotated loops into pipelined ones that parallelize
......@@ -79,6 +60,8 @@ struct PipelineAnnotation {
int stage;
int order;
bool async;
// Index of the statement in the original loop body order (SeqStmt order)
int original_idx = -1;
};
using PipelineInfo = std::unordered_map<Block, PipelineAnnotation,
......@@ -304,15 +287,17 @@ public:
}
// Step 2: Emit the pipeline prologue, body and epilogue.
Stmt prologue = EmitImpl(pipeline_loop_->min,
pipeline_loop_->min + max_stage_, true, true);
Stmt body =
EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(
pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
Stmt prologue =
EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true,
true, false);
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false,
false, false);
Stmt epilogue =
EmitImpl(pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
true, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue});
// Step 3: Make a new block that contains new buffer allocations after
......@@ -515,12 +500,16 @@ private:
// A symbolic expression representing the index the latest async operation
// associated with this stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head;
// the commit block's predicate
PrimExpr commit_predicate{nullptr};
};
/*! Structure holding intermediate information for pipeline loop rewriting. */
struct RewrittenBlockInfo {
int stage;
int order;
PrimExpr start;
PrimExpr end;
PrimExpr predicate;
Block block;
PrimExpr access_index;
......@@ -528,56 +517,103 @@ private:
};
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
std::map<int, AsyncStateLocal> *async_states_local) {
std::map<int, AsyncStateLocal> *async_states_local,
bool is_epilogue = false) {
// Precompute which orders are present in this emit, and their access_index
std::unordered_map<int, PrimExpr> order_to_access_index;
std::unordered_set<int> present_orders;
for (const auto &nb : new_blocks) {
order_to_access_index[nb.order] = nb.access_index;
present_orders.insert(nb.order);
}
for (size_t i = 0; i < new_blocks.size(); ++i) {
// 1. Find the unique async producer stage
int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) {
for (const auto &read_region : new_blocks[i].block->reads) {
for (const auto &[stage, state] : async_states) {
if (stage <= new_blocks[i].stage &&
state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was
// asynchronously written
// Currently only a single async stage dependency is supported
ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported";
producer_stage_idx = stage;
}
}
}
if (producer_stage_idx == -1)
if (producer_stage_idx == -1) {
// This block does not depend on any async producer
continue;
}
const auto &state = async_states[producer_stage_idx];
auto &dep_local_state = (*async_states_local)[producer_stage_idx];
PrimExpr in_flight_cnt = 0;
for (const auto &group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
PrimExpr producer_head;
if (dep_local_state.producer_head.defined()) {
producer_head = dep_local_state.producer_head.value();
// if the group is after the wait point, minus by 1
if (group.front() > new_blocks[i].order)
producer_head -= 1;
} else {
producer_head = state.producer_head;
}
in_flight_cnt += producer_head - consumer_head;
}
// We can relax the in-flight-count by the number of independent commit.
// 2. Use buffer_to_commit_group_ to find all actually dependent commit
// groups
std::unordered_set<int> dependent_groups;
for (const auto &read_region : new_blocks[i].block->reads) {
if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
dependent_groups.insert(
state.buffer_to_commit_group_.at(read_region->buffer.get()));
auto it = state.buffer_to_commit_group_.find(read_region->buffer.get());
if (it != state.buffer_to_commit_group_.end()) {
dependent_groups.insert(it->second);
}
}
for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
if (dependent_groups.count(i) == 0)
in_flight_cnt += 1;
else
break; // stop relaxing
// If there is no dependent commit group, no wait needs to be inserted
if (dependent_groups.empty()) {
continue;
}
// 3. Compute wait = max_g max(0, t_consumer - committed_before[g])
PrimExpr t_consumer = new_blocks[i].access_index;
PrimExpr wait_expr = make_zero(t_consumer.dtype());
PrimExpr current_head = dep_local_state.producer_head.defined()
? dep_local_state.producer_head.value()
: state.producer_head;
int consumer_order = new_blocks[i].order;
for (int g : dependent_groups) {
const auto &group = state.commit_groups[g];
if (group.empty())
continue;
int commit_order = group.back();
bool commit_present = present_orders.count(commit_order) > 0;
PrimExpr committed_before;
if (commit_present && commit_order <= consumer_order) {
// Commit point is in this iteration and earlier than the current
// consumer; this iteration's head is visible
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
// it means the commit block is not executed in this iteration
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = order_to_access_index.at(commit_order);
}
} else {
// Commit point is later than the current consumer or not in this
// iteration; only the previous iteration's head is visible
if (dep_local_state.producer_head.defined()) {
auto commit_predicate = dep_local_state.commit_predicate;
if (analyzer_.CanProve(!commit_predicate,
arith::ProofStrength::kSymbolicBound)) {
committed_before = new_blocks[i].start - 1;
} else if (is_epilogue) {
committed_before = new_blocks[i].start - 1;
} else {
committed_before = current_head - 1;
}
}
}
wait_expr = analyzer_.Simplify(committed_before - t_consumer);
}
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back(
{static_cast<int>(i), in_flight_cnt});
wait_expr = analyzer_.Simplify(wait_expr);
dep_local_state.pending_waits.push_back({static_cast<int>(i), wait_expr});
}
}
......@@ -630,7 +666,7 @@ private:
* \return The result loop.
*/
Stmt EmitImpl(const PrimExpr &start, const PrimExpr &end, bool unroll_loop,
bool need_bound_check) {
bool need_bound_check, bool is_epilogue = false) {
PrimExpr new_loop_var;
PrimExpr extent = end - start;
auto make_nop = []() {
......@@ -642,7 +678,20 @@ private:
new_loop_var = start; // use constants as the loop var for unit loops
} else {
new_loop_var = pipeline_loop_->loop_var.copy_with_suffix("");
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
// Bind the iteration domain [start, end) to strengthen analyzer facts.
analyzer_.Bind(Downcast<Var>(new_loop_var),
Range::FromMinExtent(start, end - start));
}
// Keep the bound constraints active for all analysis below.
// Only meaningful when the loop var is symbolic (non-unit loop).
std::unique_ptr<With<arith::ConstraintContext>> ctx_lb_guard;
std::unique_ptr<With<arith::ConstraintContext>> ctx_ub_guard;
if (!is_unit_loop) {
Var loop_iter = Downcast<Var>(new_loop_var);
ctx_lb_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter >= start));
ctx_ub_guard.reset(
new With<arith::ConstraintContext>(&analyzer_, loop_iter < end));
}
std::vector<RewrittenBlockInfo> new_blocks;
......@@ -653,15 +702,14 @@ private:
for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order;
PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage;
if (need_bound_check)
inbound =
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
if (analyzer_.CanProve(!inbound)) {
continue;
}
inbound = And(
pipeline_loop_->min <= skewed_loop_var,
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent));
Block new_block = Downcast<Block>(
PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
pipeline_loop_, max_stage_ != 1)(block));
......@@ -674,6 +722,8 @@ private:
PrimExpr normalized_access_index =
is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;
normalized_access_index = analyzer_.Simplify(normalized_access_index);
// Adjust the block predicate and the body according to the final loop
// bound
// [pipeline_loop_->min, extent).
......@@ -701,17 +751,18 @@ private:
if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage];
local_state.producer_head = normalized_access_index;
local_state.commit_predicate = inbound;
BlockNode *n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
1, n->body);
}
new_blocks.push_back({stage, order, inbound, new_block,
new_blocks.push_back({stage, order, start, end, inbound, new_block,
normalized_access_index,
pipeline_info_[block].async});
}
PopulateWaitCounts(new_blocks, &async_states_local);
PopulateWaitCounts(new_blocks, &async_states_local, is_epilogue);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
......@@ -1008,7 +1059,8 @@ private:
pipeline_async_stages.find(stage) != pipeline_async_stages.end();
PipelineAnnotation stage_order{
stage,
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async};
/*order=*/static_cast<int>(pipeline_orders[i]->value), is_async,
/*original_idx=*/static_cast<int>(i)};
pipeline_info.emplace(original_order[i], stage_order);
}
......
......@@ -10,6 +10,7 @@ def _check(original, transformed):
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod)
mod = tl.transform.LowerOpaqueBlock()(mod)
mod = tl.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True)
......
......@@ -217,6 +217,7 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.InjectFenceProxy()(mod)
mod = tilelang.transform.LowerOpaqueBlock()(mod)
mod = tilelang.transform.Simplify()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tilelang.transform.FlattenBuffer()(mod)
# ConfigIndexBitwidth must be applied after FlattenBuffer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment