Unverified Commit 73bf8346 authored by Lei Wang's avatar Lei Wang Committed by GitHub

[Refactor] Rebase pipeline injector from upstream tvm (#687)

* [Enhancement] Introduce software pipeline rewriter and refactor buffer access handling

- Added a new `PipelineOpaqueAccessRewriter` class to manage opaque buffer accesses in the software pipeline.
- Refactored the `PipelineBodyRewriter` to utilize the new rewriter for improved buffer access handling.
- Enhanced the `PipelineRewriter` to support additional fragment information and streamline pipeline construction.
- Updated tests to reflect changes in buffer management and access patterns, ensuring compatibility with the new structure.
- Removed obsolete code related to previous buffer access methods for clarity and maintainability.

* test fix
parent b45e9c45
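
For intuition about the transform being refactored here: software pipelining splits an annotated loop into a prologue, a steady-state body, and an epilogue so that stages of adjacent iterations overlap. A minimal Python sketch of the emitted schedule for a two-stage pipeline (the `load`/`compute` stage callbacks are hypothetical placeholders, not part of this patch):

    # Hypothetical two-stage pipeline: stage 0 loads, stage 1 computes.
    # The injector splits `for i in range(n)` so that load(i + 1)
    # overlaps compute(i).
    def pipelined(n, load, compute):
        load(0)                    # prologue: first load, nothing to compute
        for i in range(n - 1):     # steady-state body
            load(i + 1)
            compute(i)
        compute(n - 1)             # epilogue: drain the last in-flight load

    trace = []
    pipelined(3, lambda i: trace.append(("load", i)),
              lambda i: trace.append(("compute", i)))
    assert trace == [("load", 0), ("load", 1), ("compute", 0),
                     ("load", 2), ("compute", 1), ("compute", 2)]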
@@ -37,6 +37,8 @@ namespace tvm {
 namespace tl {
 using namespace tir;
+namespace software_pipeline {
+
 /*!
  * \brief Create a block and infer the access region with the given body.
  *
@@ -81,34 +83,137 @@ struct BufferAccessInfo {
   int use = -1; // the last using stage of the buffer
 };
-/*!
- * \brief Replace IfThenElse nodes with their then_case, preserving
- * attribute nodes
- * \param body The statement to process
- * \param condition The condition to match in IfThenElse nodes
- * \return The transformed statement
- */
-Stmt replace_if_then_else(Stmt body, PrimExpr condition) {
-  if (const auto *if_node = body.as<IfThenElseNode>()) {
-    // If this is an IfThenElse with the matching condition, replace it with
-    // its then_case
-    if (if_node->condition.same_as(condition)) {
-      return if_node->then_case;
-    }
-  } else if (const auto *attr_node = body.as<AttrStmtNode>()) {
-    // For attribute nodes, preserve the attribute but process its body
-    AttrStmt attr_stmt = GetRef<AttrStmt>(attr_node);
-    attr_stmt.CopyOnWrite()->body =
-        replace_if_then_else(attr_node->body, condition);
-    return attr_stmt;
-  } else if (const auto *block_node = body.as<BlockNode>()) {
-    // For block nodes, process the body
-    Block block = GetRef<Block>(block_node);
-    block.CopyOnWrite()->body =
-        replace_if_then_else(block_node->body, condition);
-    return block;
-  }
-  // For any other node type, return it unchanged
-  return body;
-}
+class PipelineOpaqueAccessRewriter {
+public:
+  /*!
+   * \brief Constructor
+   * \param buffer_data_to_buffer The map from buffer data to buffer.
+   * \param buffer_remap The map from original buffer to the buffer with
+   * updated shape for multi-versioning in the software pipeline.
+   * \param pipeline_loop The original loop to be software pipelined.
+   * \param fragment_info Information about tensor core fragment
+   */
+  PipelineOpaqueAccessRewriter(
+      const Map<Var, Buffer> &buffer_data_to_buffer,
+      const Map<Buffer, Buffer> &buffer_remap, const For &pipeline_loop,
+      const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info)
+      : buffer_data_to_buffer_(buffer_data_to_buffer),
+        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
+        fragment_info_(fragment_info) {}
+
+  PrimExpr Rewrite(const Call &call) {
+    // Intrinsic calls should be handled explicitly here as they are opaque
+    // accesses to buffer.
+    static const auto &load_matrix_sync = builtin::tvm_load_matrix_sync();
+    static const auto &store_matrix_sync = builtin::tvm_store_matrix_sync();
+    static const auto &mma_sync = builtin::tvm_mma_sync();
+    static const auto &access_ptr = builtin::tvm_access_ptr();
+    static const auto &ptx_ldmatrix = builtin::ptx_ldmatrix();
+    static const auto &ptx_mma = builtin::ptx_mma();
+    if (call->op.same_as(load_matrix_sync) ||
+        call->op.same_as(store_matrix_sync)) {
+      const Buffer &buffer =
+          buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        Array<PrimExpr> new_args = call->args;
+        const Buffer &new_buffer = (*it).second;
+        new_args.Set(
+            4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4]));
+        return Call(call->dtype, call->op, new_args, call->span);
+      }
+    } else if (call->op.same_as(mma_sync)) {
+      Array<PrimExpr> new_args = call->args;
+      for (int i = 0; i < 4; i++) {
+        const Var &buffer_var = Downcast<Var>(call->args[i * 2]);
+        const PrimExpr &index = call->args[i * 2 + 1];
+        const Buffer &buffer = buffer_data_to_buffer_.at(buffer_var);
+        auto it = buffer_remap_.find(buffer);
+        if (it != buffer_remap_.end()) {
+          PrimExpr new_index =
+              RewriteWmmaFragmentIndex(buffer, (*it).second, index);
+          new_args.Set(i * 2 + 1, new_index);
+        }
+      }
+      return Call(call->dtype, call->op, new_args, call->span);
+    } else if (call->op.same_as(access_ptr)) {
+      return RewriteBufferAccess(call, {1});
+    } else if (call->op.same_as(ptx_mma)) {
+      return RewriteBufferAccess(call, {6, 8, 10});
+    } else if (call->op.same_as(ptx_ldmatrix)) {
+      return RewriteBufferAccess(call, {3});
+    }
+    return call;
+  }
+
+private:
+  int GetWmmaFragmentSize(const Buffer &buffer) {
+    auto it = fragment_info_.find(buffer->data.get());
+    ICHECK(it != fragment_info_.end());
+    const FragmentInfo &info = (*it).second;
+    return info.GetSize();
+  }
+
+  PrimExpr RewriteWmmaFragmentIndex(const Buffer &old_buffer,
+                                    const Buffer &new_buffer,
+                                    const PrimExpr &old_index) {
+    PrimExpr new_buffer_offset = old_index;
+    int fragment_size = GetWmmaFragmentSize(old_buffer);
+    PrimExpr offset = floordiv(
+        foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+              make_const(DataType::Int(32), 1), old_buffer->shape),
+        fragment_size);
+    new_buffer_offset +=
+        floormod(pipeline_loop_->loop_var - pipeline_loop_->min,
+                 new_buffer->shape[0]) *
+        offset;
+    return new_buffer_offset;
+  }
+
+  PrimExpr RewriteBufferAccess(const Call &call,
+                               const std::vector<int> arg_indices) {
+    auto product = [](const Array<PrimExpr> &input) {
+      return foldl(
+          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
+          make_const(DataType::Int(32), 1), input);
+    };
+    Array<PrimExpr> new_args = call->args;
+    for (int i : arg_indices) {
+      const Buffer &buffer =
+          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
+      auto it = buffer_remap_.find(buffer);
+      if (it != buffer_remap_.end()) {
+        const Buffer &new_buffer = (*it).second;
+        const PrimExpr &old_index = call->args[i + 1];
+        PrimExpr offset;
+        if (new_buffer->strides.empty()) {
+          offset = product(buffer->shape);
+        } else {
+          offset = new_buffer->strides[0];
+        }
+        if (buffer.scope() == "m16n8k8.matrixA" ||
+            buffer.scope() == "m16n8k8.matrixB") {
+          // mma scope size will shrink by warp size
+          // @see transform_mma_buffer_layout
+          ICHECK_EQ(Downcast<IntImm>(floormod(offset, 32))->value, 0)
+              << "mma scope size should be multiple of warp size";
+          offset = floordiv(offset, 32);
+        }
+        PrimExpr new_index =
+            old_index +
+            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
+        new_args.Set(i + 1, new_index);
+      }
+    }
+    return Call(call->dtype, call->op, new_args, call->span);
+  }
+
+  const Map<Var, Buffer> &buffer_data_to_buffer_;
+  const Map<Buffer, Buffer> &buffer_remap_;
+  const For &pipeline_loop_;
+  const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info_;
+};
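
The class above rewrites opaque intrinsic accesses so that each pipeline iteration lands in its own version of a multi-versioned buffer. A sketch of the index arithmetic, simplified from `RewriteBufferAccess`, assuming a flat element index and that the loop minimum has already been folded in (`remap_flat_index` is an illustrative name, not part of the patch):

    # A buffer of `size` elements expanded to `num_versions * size`:
    # iteration i accesses version (i - loop_min) % num_versions.
    def remap_flat_index(idx, i, loop_min, num_versions, size):
        version = (i - loop_min) % num_versions
        return version * size + idx

    # With 2 versions of a 16-element buffer, iterations alternate halves:
    assert remap_flat_index(5, 0, 0, 2, 16) == 5   # version 0
    assert remap_flat_index(5, 1, 0, 2, 16) == 21  # version 1
    assert remap_flat_index(5, 2, 0, 2, 16) == 5   # version 0 again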
/*!
 * \brief Rewriter for the body of the software pipeline. This pass inserts
@@ -126,14 +231,19 @@ public:
   * Whether all versions the buffers in the software pipeline are accessed.
   * This will be used to update block access region. In the prologue and
   * epilogue of a two-stage software pipeline, only one version of these
-   * buffers are accessed.
+   * buffers are accessed.
+   * \param fragment_info Information about tensor core fragment
   */
-  PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
-                       const Map<Buffer, Buffer> &buffer_remap,
-                       For pipeline_loop, bool access_all_versions)
+  PipelineBodyRewriter(
+      const Map<Var, Buffer> &buffer_data_to_buffer,
+      const Map<Buffer, Buffer> &buffer_remap, For pipeline_loop,
+      bool access_all_versions,
+      const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info)
      : buffer_data_to_buffer_(buffer_data_to_buffer),
        buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
-        access_all_versions_(access_all_versions) {}
+        access_all_versions_(access_all_versions),
+        opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_,
+                                pipeline_loop_, fragment_info) {}

private:
  BufferRegion
@@ -157,36 +267,6 @@ private:
    return buffer_region;
  }
-  PrimExpr RewriteBufferAccess(const Call &call,
-                               const std::vector<int> arg_indices) {
-    auto product = [](const Array<PrimExpr> &input) {
-      return foldl(
-          [](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-          make_const(DataType::Int(32), 1), input);
-    };
-    Array<PrimExpr> new_args = call->args;
-    for (int i : arg_indices) {
-      const Buffer &buffer =
-          buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
-      auto it = buffer_remap_.find(buffer);
-      if (it != buffer_remap_.end()) {
-        const Buffer &new_buffer = (*it).second;
-        const PrimExpr &old_index = call->args[i + 1];
-        PrimExpr offset;
-        if (new_buffer->strides.empty()) {
-          offset = product(buffer->shape);
-        } else {
-          offset = new_buffer->strides[0];
-        }
-        PrimExpr new_index =
-            old_index +
-            floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
-        new_args.Set(i + 1, new_index);
-      }
-    }
-    return Call(call->dtype, call->op, new_args, call->span);
-  }
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
@@ -202,14 +282,14 @@ private:
    for (const Buffer &alloc_buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(alloc_buffer->data);
    }
-    return std::move(block);
+    return block;
  }

  Stmt VisitStmt_(const BufferStoreNode *op) final {
    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    auto it = buffer_remap_.find(store->buffer);
    if (it == buffer_remap_.end()) {
-      return std::move(store);
+      return store;
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = store.CopyOnWrite();
@@ -217,14 +297,14 @@ private:
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
-    return std::move(store);
+    return store;
  }

  PrimExpr VisitExpr_(const BufferLoadNode *op) final {
    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    auto it = buffer_remap_.find(load->buffer);
    if (it == buffer_remap_.end()) {
-      return std::move(load);
+      return load;
    }
    const Buffer &new_buffer = (*it).second;
    auto *n = load.CopyOnWrite();
@@ -232,21 +312,19 @@ private:
    PrimExpr version = floormod(
        (pipeline_loop_->loop_var - pipeline_loop_->min), new_buffer->shape[0]);
    n->indices.insert(n->indices.begin(), version);
-    return std::move(load);
+    return load;
  }

  PrimExpr VisitExpr_(const CallNode *op) final {
    Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
-    if (call->op.same_as(builtin::tvm_access_ptr())) {
-      return RewriteBufferAccess(call, {1});
-    }
-    return call;
+    return opaque_access_rewriter_.Rewrite(call);
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
  Map<Buffer, Buffer> buffer_remap_;
  For pipeline_loop_;
  bool access_all_versions_;
+  PipelineOpaqueAccessRewriter opaque_access_rewriter_;
};
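
For `BufferStore`/`BufferLoad`, the rewriter instead prepends a version coordinate rather than flattening the index. A sketch of the effect (`version_indices` is an illustrative helper, not in the patch):

    # The remapped buffer gains a leading "version" axis; accesses get
    # (loop_var - loop_min) % num_versions prepended to their indices.
    def version_indices(indices, loop_var, loop_min, num_versions):
        return [(loop_var - loop_min) % num_versions] + list(indices)

    # B[tx, 0] at iteration 3 of a double-buffered loop becomes B[1, tx, 0],
    # matching the version-indexed B accesses in the test file below.
    assert version_indices(["tx", 0], 3, 0, 2) == [1, "tx", 0]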
/*!
@@ -255,14 +333,35 @@ private:
 */
class PipelineRewriter : public StmtExprMutator {
public:
-  PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
-                   const Array<Buffer> &pipeline_allocs,
-                   const For &pipeline_loop, const PipelineInfo &pipeline_info,
-                   PrimExpr predicate_condition = PrimExpr())
+  static Stmt Rewrite(
+      Map<Var, Buffer> buffer_data_to_buffer,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
+          &double_buffers,
+      const Array<Buffer> pipeline_allocs, const For &pipeline_loop,
+      const PipelineInfo &pipeline_info,
+      const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info,
+      const Map<String, ffi::Any> preserved_annotations) {
+    PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers,
+                              pipeline_allocs, pipeline_loop, pipeline_info,
+                              fragment_info, preserved_annotations);
+    return rewriter.BuildPipeline();
+  }
+
+private:
+  PipelineRewriter(
+      Map<Var, Buffer> buffer_data_to_buffer,
+      const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
+          &double_buffers,
+      const Array<Buffer> &pipeline_allocs, const For &pipeline_loop,
+      const PipelineInfo &pipeline_info,
+      const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info,
+      const Map<String, ffi::Any> preserved_annotations)
      : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
-        pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
-        pipeline_info_(pipeline_info),
-        predicate_condition_(predicate_condition) {}
+        double_buffers_(double_buffers), pipeline_allocs_(pipeline_allocs),
+        pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info),
+        fragment_info_(fragment_info),
+        preserved_annotations_(preserved_annotations) {}

  Stmt BuildPipeline() {
    // Step 1: Analyze accesses to the buffers in the pipeline and compute the
@@ -277,61 +376,36 @@ public:
    }
    ordered_stmts_.resize(pipeline_info_.size());
-    for (const auto &[block, anno] : pipeline_info_) {
-      ordered_stmts_.Set(anno.order, block);
-    }
-
-    for (const Block &block : ordered_stmts_) {
-      int stage = pipeline_info_[block].stage;
-      if (pipeline_info_[block].async) {
-        auto &state = async_states[stage];
-        state.producer_head = pipeline_loop_->min - 1;
-        for (auto write_region : block->writes) {
-          auto buffer = write_region->buffer;
-          state.dst_buffers.insert(buffer.get());
-          if (buffer_remap_.count(buffer))
-            state.dst_buffers.insert(buffer_remap_[buffer].get());
-        }
-      }
-    }
-    std::unordered_set<int> consumed;
-    for (const Block &block : ordered_stmts_) {
-      int stage = pipeline_info_[block].stage;
-      if (pipeline_info_[block].async) {
-        auto &state = async_states[stage];
-        if (state.commit_groups.empty() || consumed.count(stage)) {
-          state.commit_groups.push_back({});
-        }
-        state.commit_groups.back().push_back(pipeline_info_[block].order);
-        consumed.erase(stage);
-        for (auto write_region : block->writes) {
-          auto buffer = buffer_remap_.count(write_region->buffer)
-                            ? buffer_remap_[write_region->buffer]
-                            : write_region->buffer;
-          state.buffer_to_commit_group_[buffer.get()] =
-              state.commit_groups.size() - 1;
-        }
-      }
-      for (auto read_region : block->reads) {
-        for (const auto &[producer_stage_id, producer_state] : async_states) {
-          if (producer_stage_id <= stage &&
-              producer_state.writes(read_region->buffer)) {
-            consumed.insert(producer_stage_id);
-          }
-        }
-      }
-    }
-
-    // Step 2: Emit the pipeline prologue, body and epilogue.
-    Stmt prologue = EmitImpl(pipeline_loop_->min,
-                             pipeline_loop_->min + max_stage_, true, true);
-    Stmt body =
-        EmitImpl(pipeline_loop_->min + max_stage_,
-                 pipeline_loop_->min + pipeline_loop_->extent, false, false);
-    Stmt epilogue = EmitImpl(
-        pipeline_loop_->min + pipeline_loop_->extent,
-        pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
+    for (const auto &pair : pipeline_info_) {
+      const Block &block = pair.first;
+      int order = pair.second.order;
+      ordered_stmts_.Set(order, block);
+    }
+
+    // Step 2: Emit the pipeline prologue, body and epilogue.
+    Stmt prologue =
+        EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true);
+    Stmt body = EmitImpl(pipeline_loop_->min + max_stage_,
+                         pipeline_loop_->min + pipeline_loop_->extent, false);
+    // Introduce an extra lower bound when the loop length is smaller than the
+    // number of stages, to ensure the epilogue interval does not overlap the
+    // prologue interval.
+    PrimExpr epilogue_start = pipeline_loop_->min + pipeline_loop_->extent;
+    Optional<PrimExpr> extra_epilogue_lower_bound = std::nullopt;
+    if (max_stage_ > 1 &&
+        !analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) {
+      if (is_const_int(epilogue_start)) {
+        epilogue_start = max(epilogue_start, pipeline_loop_->min + max_stage_);
+      } else {
+        // For the dynamic case, introduce the extra lower bound as a loop
+        // predicate to keep the epilogue part unrollable.
+        extra_epilogue_lower_bound = pipeline_loop_->min + max_stage_;
+      }
+    }
+    Stmt epilogue =
+        EmitImpl(epilogue_start,
+                 pipeline_loop_->min + pipeline_loop_->extent + max_stage_,
+                 true, extra_epilogue_lower_bound);

    SeqStmt stmt = SeqStmt({prologue, body, epilogue});
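
The epilogue lower-bound logic above can be checked with small numbers. A sketch of the static (constant-extent) case, using an illustrative helper name:

    # Prologue covers [min, min + max_stage); the epilogue must not start
    # before min + max_stage even when min + extent is smaller.
    def epilogue_range(loop_min, extent, max_stage):
        start = loop_min + extent
        if max_stage > 1 and extent < max_stage:
            start = max(start, loop_min + max_stage)  # static clamp
        return (start, loop_min + extent + max_stage)

    # extent=1 with a 2-stage pipeline: the unclamped start 1 would overlap
    # the prologue [0, 2); the clamp moves it to 2.
    assert epilogue_range(0, 1, 2) == (2, 3)
    assert epilogue_range(0, 8, 2) == (8, 10)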
@@ -434,7 +508,7 @@ private:
    // We optimize a few cases where the number of versions can be smaller
    // than the upper bound
    int num_versions = buffer_info.use - buffer_info.def + 1;
-    if (num_versions >= 2) {
+    if (num_versions == 2) {
      // A special case when `use - def + 1 == 2`. Double buffering is only
      // needed in this case when there exists a reader block_i and a writer
      // block_j such that order(block_i) < order(block_j) and stage(block_i) <
@@ -473,8 +547,11 @@ private:
      }
    }
    if (!need_multi_version) {
-      num_versions--;
+      num_versions = 1;
    }
+    if (num_versions == 1 && double_buffers_.count(buffer)) {
+      num_versions = 2;
+    }
  }
  return num_versions;
}
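
A condensed sketch of the resulting version-count rule, folding in the new `double_buffers_` override (the boolean inputs stand in for the hazard analysis elided by the diff):

    def num_versions(def_stage, use_stage, need_multi_version, double_buffered):
        n = use_stage - def_stage + 1
        if n == 2 and not need_multi_version:
            n = 1                      # no cross-stage RAW hazard
        if n == 1 and double_buffered:
            n = 2                      # forced by double_buffer_scope
        return n

    assert num_versions(0, 1, False, False) == 1
    assert num_versions(0, 1, True, False) == 2
    assert num_versions(0, 0, False, True) == 2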
@@ -507,16 +584,15 @@ private:
    // valid, it is the "sum of extents of loops that have been executed" - 1,
    // e.g. for epilogue it is prologue extent + body extent - 1. This is only
    // needed to compute wait count for epilogue without async producers.
-    PrimExpr producer_head;
-    std::vector<std::vector<int>> commit_groups;
-    std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
+    Optional<PrimExpr> producer_head{PrimExpr(-1)};

    bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
  };

  // Per-stage states that are local to each of pipeline prologue, body, and
  // epilogue.
  struct AsyncStateLocal {
-    struct PendingWait {
+    struct {
@@ -525,76 +601,187 @@ private:
      PrimExpr wait_count{nullptr};

      bool valid() const { return wait_count.defined(); }
-    };
-    std::vector<PendingWait> pending_waits;
+    } pending_wait;
+
+    // Destination buffers of async operations that have been encountered so
+    // far in the loop
+    //
+    // for (size_t i = 0; i < new_blocks.size(); ++i) {
+    //   ...
+    // }
+    //
+    // This is for tracking which async operations have been issued at the
+    // "current" iteration, up until a point where we encounter a consumer of
+    // async result buffers. This is used to decide if the producer_head of
+    // each buffer points to a copy written in the current or previous
+    // iteration.
+    std::unordered_set<const BufferNode *> seen;

    // A symbolic expression representing the index the latest async operation
    // associated with this stage has written into, at the "current" iteration.
    Optional<PrimExpr> producer_head;
+    // The predicate of BlockRealize containing the async operation of this
+    // stage.
+    Optional<PrimExpr> predicate;
+    // Indices into a list of blocks, where async_commit_queue scope should be
+    // attached. If multiple async producers are interleaved with their
+    // consumer in between, we need separate async_commit_queue for each
+    // producer. Thus, we need multiple sets of indices.
+    std::vector<std::vector<size_t>> commit_groups;
+    // This is set to true when we reach a stage that consumes this async
+    // stage.
+    bool consumed{false};
  };

  /*! Structure holding intermediate information for pipeline loop rewriting. */
  struct RewrittenBlockInfo {
    int stage;
-    int order;
    PrimExpr predicate;
    Block block;
    PrimExpr access_index;
    bool is_async;
  };
-  void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
-                          std::map<int, AsyncStateLocal> *async_states_local) {
-    for (size_t i = 0; i < new_blocks.size(); ++i) {
-      int producer_stage_idx = -1;
-      for (auto read_region : new_blocks[i].block->reads) {
-        for (const auto &[stage, state] : async_states) {
-          if (stage <= new_blocks[i].stage &&
-              state.writes(read_region->buffer)) {
-            // Found an earlier stage where read_region->buffer was
-            // asynchronously written
-            ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
-                << "A dependency on multiple async stages is not supported";
-            producer_stage_idx = stage;
-          }
-        }
-      }
-      if (producer_stage_idx == -1)
-        continue;
-      const auto &state = async_states[producer_stage_idx];
-      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
-      PrimExpr in_flight_cnt = 0;
-      for (const auto &group : state.commit_groups) {
-        PrimExpr consumer_head = new_blocks[i].access_index;
-        PrimExpr producer_head;
-        if (dep_local_state.producer_head.defined()) {
-          producer_head = dep_local_state.producer_head.value();
-          // if the group is after the wait point, minus by 1
-          if (group.front() > new_blocks[i].order)
-            producer_head -= 1;
-        } else {
-          producer_head = state.producer_head;
-        }
-        in_flight_cnt += producer_head - consumer_head;
-      }
-      // We can relax the in-flight-count by the number of independent commit.
-      std::unordered_set<int> dependent_groups;
-      for (const auto &read_region : new_blocks[i].block->reads) {
-        if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
-          dependent_groups.insert(
-              state.buffer_to_commit_group_.at(read_region->buffer.get()));
-      }
-      for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
-        if (dependent_groups.count(i) == 0)
-          in_flight_cnt += 1;
-        else
-          break; // stop relaxing
-      }
-      in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
-      dep_local_state.pending_waits.push_back(
-          {static_cast<int>(i), in_flight_cnt});
-    }
-  }
+  // Determine where to insert async_wait and the corresponding wait count.
+  void PopulateWaitCounts(
+      const std::vector<RewrittenBlockInfo> &new_blocks,
+      arith::Analyzer *ana_normalized,
+      const std::unordered_map<const BufferNode *, int> &buffer_to_commit_group,
+      std::map<int, AsyncStateLocal> *async_states_local) {
+    for (size_t i = 0; i < new_blocks.size(); ++i) {
+      if (new_blocks[i].is_async) {
+        // Record the fact that we have encountered these write buffers.
+        for (auto write_region : new_blocks[i].block->writes) {
+          (*async_states_local)[new_blocks[i].stage].seen.insert(
+              write_region->buffer.get());
+        }
+      }
+
+      int producer_stage_idx = -1;
+      for (auto read_region : new_blocks[i].block->reads) {
+        for (auto kv : async_states) {
+          if (kv.first <= new_blocks[i].stage &&
+              kv.second.writes(read_region->buffer)) {
+            // Found an earlier stage where read_region->buffer was
+            // asynchronously written
+            ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first)
+                << "A dependency on multiple async stages is not supported";
+            producer_stage_idx = kv.first;
+          }
+        }
+      }
+
+      if (producer_stage_idx == -1)
+        continue;
+
+      // The following logic has become complicated to handle cases like this:
+      //
+      // for i in range(13):
+      //     # Stage 0
+      //     async_commit_queue(0):
+      //         async_scope:
+      //             A_shared[(i + 3) % 4] = A[...]
+      //
+      //     # Stage 1
+      //     async_wait_queue(0, 5):
+      //         compute(A_shared[i], B_shared[i])
+      //
+      //     # Stage 0
+      //     async_commit_queue(0):
+      //         async_scope:
+      //             B_shared[(i + 3) % 4] = B[...]
+      //
+      // Here, multiple async producers in the same stage are interleaved with
+      // their consumer in between. Since each buffer is associated with
+      // different commit groups, the wait_count before the consumer should be
+      // bigger than the simpler case:
+      //
+      // for i in range(13):
+      //     # Stage 0
+      //     async_commit_queue(0):
+      //         async_scope:
+      //             A_shared[(i + 3) % 4] = A[...]
+      //             B_shared[(i + 3) % 4] = B[...]
+      //
+      //     # Stage 1
+      //     async_wait_queue(0, 3):
+      //         compute(A_shared[i], B_shared[i])
+      //
+      // The correct wait_count can be determined by considering each commit
+      // group separately, and summing "per-commit" wait_counts.
+      //
+      // From A_shared's perspective, it allows for (i + 3) - i async commit
+      // groups to be in flight while from B_shared's perspective, the
+      // producer head at compute points to the copy done by the previous
+      // iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The
+      // sum of the two wait_counts gives 5.
+      auto &dep_local_state = (*async_states_local)[producer_stage_idx];
+      const auto num_commit_group = dep_local_state.commit_groups.size();
+      std::vector<Optional<PrimExpr>> producer_head_per_commit;
+
+      if (num_commit_group == 0) {
+        // Epilogue, no async producer. Since "local" producer_head is not
+        // available, use "global" producer_head.
+        ICHECK(!dep_local_state.producer_head);
+        producer_head_per_commit.push_back(
+            async_states[producer_stage_idx].producer_head);
+      } else {
+        ICHECK(dep_local_state.producer_head);
+        std::vector<bool> need_wait_count(num_commit_group, true);
+        for (auto read_region : new_blocks[i].block->reads) {
+          if (!async_states[producer_stage_idx].writes(read_region->buffer))
+            continue;
+          auto commit_group_id =
+              buffer_to_commit_group.at(read_region->buffer.get());
+          if (!need_wait_count[commit_group_id])
+            continue;
+          if (!dep_local_state.seen.count(read_region->buffer.get())) {
+            // Multiple async producers interleaved: the most recent async
+            // write is from the previous iteration. This is the B_shared
+            // case above.
+            producer_head_per_commit.push_back(
+                dep_local_state.producer_head.value() - 1);
+          } else {
+            // Normal case
+            producer_head_per_commit.push_back(
+                dep_local_state.producer_head.value());
+          }
+          need_wait_count[commit_group_id] = false;
+        }
+      }
+
+      auto wait_count = [=, &ana_normalized]() {
+        auto sum = PrimExpr(0);
+        for (auto producer_head : producer_head_per_commit) {
+          if (producer_head &&
+              ana_normalized->CanProve(producer_head.value() >= 0)) {
+            // Here, new_blocks[i].access_index corresponds to
+            // "consumer_head". The difference of producer_head and
+            // consumer_head is precisely the number of async commit groups
+            // that can still be in flight after this wait.
+            sum += analyzer_.Simplify(producer_head.value() -
+                                      new_blocks[i].access_index);
+          } else {
+            // The precise count cannot be determined, give up.
+            return PrimExpr(0);
+          }
+        }
+        return sum;
+      }();
+
+      auto &pending_wait = dep_local_state.pending_wait;
+      if (!pending_wait.valid()) {
+        pending_wait = {static_cast<int>(i), wait_count};
+      } else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) {
+        // Coalesce multiple wait_queue if the later one allows fewer
+        // in-flight ops.
+        pending_wait = {pending_wait.insert_before, wait_count};
+      }
+    }
+  }
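
The A_shared/B_shared example in the comment above can be verified numerically. A sketch, assuming one commit group per producer buffer:

    # A_shared's producer head at iteration i is i + 3 (written this
    # iteration); B_shared's latest write is from the previous iteration,
    # (i - 1) + 3. The consumer head is i; per-group counts are summed.
    def wait_count(i):
        a_in_flight = (i + 3) - i        # committed this iteration
        b_in_flight = ((i - 1) + 3) - i  # committed last iteration
        return a_in_flight + b_in_flight

    assert wait_count(5) == 5  # matches async_wait_queue(0, 5) above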
@@ -602,38 +789,85 @@ private:
  // statements with async scopes (if any).
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo> &blocks,
-      const std::map<int, AsyncStateLocal> &async_states_local) const {
+      const std::map<int, AsyncStateLocal> &async_states_local,
+      arith::Analyzer *ana_normalized) const {
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
+    std::vector<int> commit_group_indices(new_blocks.size(), -1);
    for (const auto &[stage_id, state] : async_states_local) {
-      for (const auto &pw : state.pending_waits) {
-        auto &block = new_blocks[pw.insert_before].block;
-        BlockNode *n = block.CopyOnWrite();
-        auto zero = make_zero(DataType::Int(32));
-        n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
-                           AttrStmt(zero, tir::attr::async_wait_inflight_count,
-                                    pw.wait_count, n->body));
-      }
-    }
-
-    // mark the last async stmt as commit
-    std::unordered_set<int> commit_group_indices;
-    for (const auto &[stage_id, state] : async_states) {
-      for (size_t i = 0; i < state.commit_groups.size(); ++i) {
-        commit_group_indices.insert(state.commit_groups[i].back());
-      }
-    }
+      if (!state.commit_groups.empty()) {
+        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
+          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
+            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
+            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
+          }
+        }
+      }
+
+      if (state.pending_wait.valid()) {
+        auto attach_wait_scope = [&new_blocks](int i, int stage_id,
+                                               PrimExpr wait_count) {
+          auto &block = new_blocks[i].block;
+          BlockNode *n = block.CopyOnWrite();
+          auto zero = make_zero(DataType::Int(32));
+          n->body =
+              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
+                       AttrStmt(zero, tir::attr::async_wait_inflight_count,
+                                wait_count, n->body));
+        };
+
+        if (state.predicate &&
+            !ana_normalized->CanProve(state.predicate.value())) {
+          // If the async operation that this wait_queue is waiting on is
+          // predicated, and we cannot prove that the predicate is always
+          // true, the precise wait count is only valid at iterations where
+          // the predicate is true.
+          auto wait_count =
+              Call(DataType::Int(32), builtin::if_then_else(),
+                   {state.predicate.value(), state.pending_wait.wait_count, 0});
+          attach_wait_scope(state.pending_wait.insert_before, stage_id,
+                            wait_count);
+        } else {
+          attach_wait_scope(state.pending_wait.insert_before, stage_id,
+                            state.pending_wait.wait_count);
+        }
+      }
+    }

    Array<Stmt> stmts;
-    for (size_t i = 0; i < new_blocks.size(); i++) {
-      Block block = new_blocks[i].block;
-      if (commit_group_indices.count(new_blocks[i].order)) {
-        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
-                                           tir::attr::async_commit_queue_scope,
-                                           new_blocks[i].stage, block->body);
-        block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
-      }
-      stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
-    }
+    for (size_t i = 0; i < new_blocks.size();) {
+      if (commit_group_indices[i] == -1) {
+        // A synchronous block, not part of any commit group
+        stmts.push_back(
+            BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
+        ++i;
+      } else {
+        Array<Stmt> group_bodies;
+        auto stage_id = commit_group_indices[i];
+        auto predicate = new_blocks[i].predicate;
+        for (; i < commit_group_indices.size() &&
+               commit_group_indices[i] == stage_id;
+             ++i) {
+          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
+              << "Predicates in the same stage are expected to be identical";
+          group_bodies.push_back(new_blocks[i].block->body);
+        }
+        if (group_bodies.size() > 1) {
+          auto merged_bodies = SeqStmt(group_bodies);
+          group_bodies.clear();
+          group_bodies.push_back(merged_bodies);
+        }
+        for (auto body : group_bodies) {
+          auto commit_queue_scope =
+              AttrStmt(make_zero(DataType::Int(32)),
+                       tir::attr::async_commit_queue_scope, stage_id, body);
+          auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
+          stmts.push_back(BlockRealize({}, predicate, new_block));
+        }
+      }
+    }

    return stmts;
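
A sketch of how the loop above partitions blocks by `commit_group_indices` (-1 marks synchronous blocks; `group_commits` is an illustrative name, not in the patch):

    def group_commits(commit_group_indices):
        groups, i = [], 0
        while i < len(commit_group_indices):
            if commit_group_indices[i] == -1:
                groups.append(("sync", [i]))
                i += 1
            else:
                stage, members = commit_group_indices[i], []
                while (i < len(commit_group_indices)
                       and commit_group_indices[i] == stage):
                    members.append(i)
                    i += 1
                groups.append(("commit", members))
        return groups

    # Two consecutive async copies in stage 0 followed by a synchronous
    # consumer are merged under one async_commit_queue scope:
    assert group_commits([0, 0, -1]) == [("commit", [0, 1]), ("sync", [2])]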
@@ -644,16 +878,21 @@ private:
   * \param start The start of the range
   * \param end The end of the range
   * \param unroll_loop Whether the loop should be unrolled.
+   * \param extra_loop_lower_bound Extra loop lower bound.
   * \return The result loop.
   */
  Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
-                bool need_bound_check) {
+                Optional<PrimExpr> extra_loop_lower_bound = std::nullopt) {
    PrimExpr new_loop_var;
    PrimExpr extent = end - start;

    auto make_nop = []() {
      return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
    };

+    if (analyzer_.CanProve(extent <= 0)) {
+      return make_nop();
+    }
+
    bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
    if (is_unit_loop) {
      new_loop_var = start; // use constants as the loop var for unit loops
@@ -662,34 +901,43 @@ private:
      analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
    }

+    // In contrast to analyzer_ which is bound to [start, end), this one is
+    // bound to the "normalized" range, [pipeline_loop_->min, extent).
+    arith::Analyzer ana_normalized;
+    if (!is_unit_loop) {
+      ana_normalized.Bind(Downcast<Var>(new_loop_var),
+                          Range(pipeline_loop_->min, extent));
+    }
+
    std::vector<RewrittenBlockInfo> new_blocks;

    // Async related
    std::map<int, AsyncStateLocal> async_states_local;
-    PrimExpr normalized_access_index;
+    std::unordered_map<const BufferNode *, int> buffer_to_commit_group;

    for (const Block &block : ordered_stmts_) {
      int stage = pipeline_info_.at(block).stage;
-      int order = pipeline_info_.at(block).order;
-      PrimExpr inbound = Bool(true);
      PrimExpr skewed_loop_var = new_loop_var - stage;
-      if (need_bound_check)
-        inbound =
-            analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
-            (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
+      PrimExpr inbound =
+          analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
+          (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
+      if (extra_loop_lower_bound.defined()) {
+        inbound = analyzer_.Simplify(
+            inbound && new_loop_var >= extra_loop_lower_bound.value());
+      }
      if (analyzer_.CanProve(!inbound)) {
        continue;
      }
-      Block new_block = Downcast<Block>(
-          PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
-                               pipeline_loop_, max_stage_ != 1)(block));
+      Block new_block = Downcast<Block>(PipelineBodyRewriter(
          buffer_data_to_buffer_, buffer_remap_, pipeline_loop_,
+          max_stage_ != 1, fragment_info_)(block));
      PrimExpr delta = start - pipeline_loop_->min;
      // This variable corresponds to
      // - "producer_head" if this stage is an async producer
      // - "consumer_head" if this stage reads from asynchronously written
      //   buffers.
-      normalized_access_index =
+      PrimExpr normalized_access_index =
          is_unit_loop ? skewed_loop_var : skewed_loop_var + delta;

      // Adjust the block predicate and the body according to the final loop
@@ -699,38 +947,76 @@ private:
        Var loop_iter = Downcast<Var>(new_loop_var);
        inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
      }

      new_block = Downcast<Block>(Substitute(
          new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
-      if (predicate_condition_.defined()) {
-        BlockNode *n = new_block.CopyOnWrite();
-        n->body = IfThenElse(
-            Substitute(predicate_condition_,
-                       {{pipeline_loop_->loop_var, normalized_access_index}}),
-            n->body);
-      }

      if (pipeline_info_[block].async) {
        auto &local_state = async_states_local[stage];
+
+        int commit_group_id = -1;
+        if (local_state.commit_groups.empty() || local_state.consumed) {
+          // consumed == true means there is already a consumer stage waiting
+          // for an earlier async operation of this stage. In such cases, we
+          // make multiple commit_queue for this stage.
+          commit_group_id = local_state.commit_groups.size();
+          local_state.commit_groups.push_back({new_blocks.size()});
+        } else {
+          // This is the case when one commit_queue groups multiple async
+          // blocks:
+          // with commit_queue(stage):
+          //   async_scope:
+          //     A_shared[...] = ...
+          //   async_scope:
+          //     B_shared[...] = ...
+          commit_group_id = local_state.commit_groups.size() - 1;
+          local_state.commit_groups.back().push_back(new_blocks.size());
+        }
+
+        for (auto write_region : new_block->writes) {
+          async_states[stage].dst_buffers.insert(write_region->buffer.get());
+          buffer_to_commit_group[write_region->buffer.get()] = commit_group_id;
+        }
+
        local_state.producer_head = normalized_access_index;
+
+        if (!local_state.predicate ||
+            ana_normalized.CanProve(local_state.predicate.value())) {
+          local_state.predicate = inbound;
+        } else if (local_state.predicate) {
+          local_state.predicate =
+              ana_normalized.Simplify(local_state.predicate.value() & inbound);
+        }

        BlockNode *n = new_block.CopyOnWrite();
        n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
                           1, n->body);
      }
-      new_blocks.push_back({stage, order, inbound, new_block,
-                            normalized_access_index,
+
+      new_blocks.push_back({stage, inbound, new_block, normalized_access_index,
                            pipeline_info_[block].async});
+
+      for (auto read_region : new_block->reads) {
+        for (auto kv : async_states) {
+          int producer_stage_id = kv.first;
+          if (producer_stage_id <= stage &&
+              kv.second.writes(read_region->buffer)) {
+            async_states_local[producer_stage_id].consumed = true;
+          }
+        }
+      }
    }

-    PopulateWaitCounts(new_blocks, &async_states_local);
-
-    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
+    PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group,
+                       &async_states_local);
+
+    auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local,
+                                                &ana_normalized);

    Stmt new_loop{nullptr};

    if (stmts.empty()) {
      return make_nop();
    }

    if (stmts.size() == 1) {
      new_loop = stmts[0];
    } else {
@@ -738,22 +1024,26 @@ private:
    }

    if (!is_unit_loop) {
-      Map<String, Any> preserved_annotations;
-      for (const auto &kv : pipeline_loop_->annotations) {
-        const String &key = kv.first;
-        if (kv.first != tir::attr::software_pipeline_stage &&
-            kv.first != tir::attr::software_pipeline_order &&
-            kv.first != tir::attr::software_pipeline_async_stages) {
-          preserved_annotations.Set(key, kv.second);
-        }
-      }
      new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
                     unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
-                     std::move(new_loop), std::nullopt, preserved_annotations);
+                     std::move(new_loop), std::nullopt, preserved_annotations_);
    }

    // Update producer heads in the global async states.
-    for (const auto &[stage_id, state] : async_states_local) {
-      async_states[stage_id].producer_head += extent;
+    for (const auto &kv : async_states_local) {
+      const int stage_id = kv.first;
+      const AsyncStateLocal &state = kv.second;
+
+      if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
+          async_states[stage_id].producer_head) {
+        // Advance the "global" producer head if it is still valid and we know
+        // exactly how much we can increment
+        async_states[stage_id].producer_head =
+            async_states[stage_id].producer_head.value() + extent;
+      } else {
+        // Otherwise, invalidate the global producer head
+        async_states[stage_id].producer_head = std::nullopt;
+      }
    }

    return BlockRealize({}, Bool(true),
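
A sketch of the producer-head bookkeeping above, with None playing the role of an undefined Optional:

    def update_producer_head(head, extent, predicate_always_true):
        if head is not None and predicate_always_true:
            return head + extent  # advance: exact issue count is known
        return None               # invalidate: count cannot be determined

    assert update_producer_head(-1, 2, True) == 1
    assert update_producer_head(1, 8, False) is None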
@@ -762,14 +1052,17 @@ private:
  arith::Analyzer analyzer_;
  Map<Var, Buffer> buffer_data_to_buffer_;
+  const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
+      &double_buffers_;
  Array<Buffer> pipeline_allocs_;
  For pipeline_loop_;
  PipelineInfo pipeline_info_;
-  PrimExpr predicate_condition_;
+  const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info_;
  int max_stage_ = -1;
  Map<Buffer, Buffer> buffer_remap_;
  Array<Block> ordered_stmts_;
  std::map<int, AsyncStateGlobal> async_states;
+  Map<String, ffi::Any> preserved_annotations_;
};
/*!
@@ -784,8 +1077,7 @@ void BuildDependencyGraph(const Array<Block> &blocks,
                          ObjectPtrEqual> *dep_src2dst,
                          std::unordered_map<Block, Array<Block>, ObjectPtrHash,
                                             ObjectPtrEqual> *dep_dst2src) {
-  std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
-      buffer_writers;
+  std::unordered_map<Var, Array<Block>> buffer_writers;
  for (const Block &block : blocks) {
    for (const BufferRegion &read : block->reads) {
@@ -816,6 +1108,7 @@ public:
      const Buffer &buffer = kv.second;
      injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
+    injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body);
    return injector(func->body);
  }
@@ -880,7 +1173,6 @@ private:
    // can be direct child of the for-loop. If the for-loop has BlockRealize as
    // its child, the pipeline body will be the child of the block.
    Stmt pipeline_body{nullptr};
-    PrimExpr predicate_condition{nullptr};
    Array<Buffer> pipeline_allocs;
    if (const auto *realize = for_node->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
@@ -888,15 +1180,7 @@ private:
        ICHECK(buffer->IsInstance<BufferNode>());
        buffer_data_to_buffer_.Set(buffer->data, buffer);
      }
-      if (const auto *if_then_else = block->body.as<IfThenElseNode>()) {
-        ICHECK(!if_then_else->else_case.defined())
-            << "Pipeline_Planning: Can't handle the body of the loop because "
-               "it is not a SeqStmt";
-        pipeline_body = if_then_else->then_case;
-        predicate_condition = if_then_else->condition;
-      } else {
-        pipeline_body = block->body;
-      }
+      pipeline_body = block->body;
      pipeline_allocs = block->alloc_buffers;
    } else {
      pipeline_body = for_node->body;
@@ -961,6 +1245,16 @@ private:
      }
    }

+    Map<String, ffi::Any> preserved_annotations;
+    for (const auto &kv : op->annotations) {
+      const String &key = kv.first;
+      if (kv.first != tir::attr::software_pipeline_stage &&
+          kv.first != tir::attr::software_pipeline_order &&
+          kv.first != tir::attr::software_pipeline_async_stages) {
+        preserved_annotations.Set(key, kv.second);
+      }
+    }
+
    for (size_t i = 0; i < pipeline_stages.size(); i++) {
      int stage = static_cast<int>(pipeline_stages[i]->value);
      bool is_async =
@@ -974,10 +1268,9 @@ private:
    ValidatePipelineBody(pipeline_info, original_order);

    // Step 4: Rewrite the pipeline body.
-    Stmt pipeline =
-        PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
-                         GetRef<For>(op), pipeline_info, predicate_condition)
-            .BuildPipeline();
+    Stmt pipeline = PipelineRewriter::Rewrite(
+        buffer_data_to_buffer_, double_buffers, pipeline_allocs,
+        GetRef<For>(op), pipeline_info, fragment_info_, preserved_annotations);

    if (const auto *realize = op->body.as<BlockRealizeNode>()) {
      const auto &block = realize->block;
@@ -988,17 +1281,44 @@ private:
    return pipeline;
  }

+  /*!
+   * \brief Add buffer allocations to a block and update the write region of
+   * the block.
+   * \param n The block pointer to which the buffer allocations are added.
+   * \param alloc_buffers The buffer allocations to be added.
+   */
+  void AddAllocBuffers(BlockNode *n, const Array<Buffer> alloc_buffers) {
+    for (const Buffer &alloc_buffer : alloc_buffers) {
+      n->alloc_buffers.push_back(alloc_buffer);
+      Region region;
+      region.reserve(alloc_buffer->shape.size());
+      for (const PrimExpr &dim : alloc_buffer->shape) {
+        region.push_back(Range::FromMinExtent(0, dim));
+      }
+      n->writes.push_back(BufferRegion(alloc_buffer, region));
+    }
+  }
+
  Stmt VisitStmt_(const BlockNode *op) final {
    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }

+    auto it = op->annotations.find(tir::attr::double_buffer_scope);
+    if (it != op->annotations.end()) {
+      int buffer_index = Downcast<Integer>((*it).second).IntValue();
+      CHECK(buffer_index >= 0 &&
+            static_cast<size_t>(buffer_index) < op->writes.size())
+          << "ValueError: Index of the buffer exceeds the size of the write "
+             "regions of the block. ("
+          << buffer_index << " vs. " << op->writes.size() << ")";
+      double_buffers.insert(op->writes[buffer_index]->buffer);
+    }
+
    Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

    for (const auto &buffer : op->alloc_buffers) {
      buffer_data_to_buffer_.erase(buffer->data);
    }
-    return std::move(block);
+    return block;
  }

  bool HasPipelineAnnotation(const ForNode *op) const {
@@ -1011,19 +1331,23 @@ private:
    }

    if (has_stage) {
      LOG(FATAL)
-          << "ValueError: Stage of the software pipeline is not defined.";
+          << "ValueError: Order of the software pipeline is not defined.";
    }
    if (has_order) {
      LOG(FATAL)
-          << "ValueError: Order of the software pipeline is not defined.";
+          << "ValueError: Stage of the software pipeline is not defined.";
    }

    return false;
  }

  Map<Var, Buffer> buffer_data_to_buffer_;
+  std::unordered_map<const VarNode *, FragmentInfo> fragment_info_;
+  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
  Optional<String> global_symbol_;
};
+} // namespace software_pipeline
+
/*!
 * \brief Transform annotated loops into pipelined ones that parallelize
 * producers and consumers.
 * \return The IR transform pass.
@@ -1032,7 +1356,7 @@ tir::transform::Pass InjectSoftwarePipeline() {
  using namespace tir::transform;
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    auto *fptr = f.CopyOnWrite();
-    fptr->body = PipelineInjector::Inject(f);
+    fptr->body = software_pipeline::PipelineInjector::Inject(f);
    fptr->body = ConvertSSA(std::move(fptr->body));
    return f;
  };
...
 import torch
+import torch.backends
 import tilelang.testing
 from tilelang import tvm as tvm
 import tilelang.language as T
...
@@ -9,6 +9,7 @@ def _check(original, transformed):
     mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
     mod = tl.transform.InjectSoftwarePipeline()(mod)
     mod = tl.transform.Simplify()(mod)
+    print(mod["main"])
     tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
                                    True)
@@ -41,30 +42,22 @@ def test_trival_pipeline():
     @T.prim_func
     def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None:
         for tx in T.thread_binding(16, thread="threadIdx.x"):
-            with T.block():
+            with T.block(""):
                 T.reads(A[tx, 0])
                 T.writes(C[tx, 0])
-                B = T.alloc_buffer([2, 16, 1], dtype="float32", scope="shared")
-                with T.block():
+                B = T.alloc_buffer((2, 16, 1), scope="shared")
+                with T.block(""):
                     T.reads(A[tx, 0])
                     T.writes(B[0, tx, 0])
-                    B[0, tx, 0] = A[tx, 0] * T.float32(2)
-                with T.block():
-                    T.reads(A[tx, 1:1], B[0:2, tx, 0])
-                    T.writes(B[1:1, tx, 0], C[tx, 0:0])
-                    for i in range(0):
-                        with T.block(""):
-                            T.reads(A[tx, i + 1])
-                            T.writes(B[i + 1, tx, 0])
-                            B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2)
-                        with T.block(""):
-                            T.reads(B[i, tx, 0])
-                            T.writes(C[tx, i])
-                            C[tx, i] = B[i, tx, 0] + T.float32(1)
-                with T.block():
+                    B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
+                with T.block(""):
+                    T.reads()
+                    T.writes()
+                    T.evaluate(0)
+                with T.block(""):
                     T.reads(B[0, tx, 0])
                     T.writes(C[tx, 0])
-                    C[tx, 0] = B[0, tx, 0] + T.float32(1)
+                    C[tx, 0] = B[0, tx, 0] + T.float32(1.0)

     _check(before, expected)
...