Unverified Commit 49d5d80e authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Pipeline] Phaseout fragment and double buffer info from pipeline pass (#711)

* Update submodule 'tvm' to commit e11521e6936a827efa334588d29571fbb4620107

* Refactor inject_pipeline.cc to enhance pipeline body rewriting and condition handling

- Introduced a new function to replace IfThenElse nodes with their then_case while preserving attributes.
- Streamlined the PipelineBodyRewriter to improve buffer access rewriting and async state management.
- Enhanced the handling of pipeline loop conditions and added support for predicate conditions in the pipeline body.
- Removed obsolete code and improved overall code clarity and maintainability.

* lint fix

* Refactor return statements in inject_pipeline.cc to remove unnecessary std::move calls

- Updated return statements in multiple methods to return objects directly instead of using std::move, improving code clarity and potentially avoiding unnecessary moves.
- Ensured consistent handling of BufferStore and BufferLoad nodes during pipeline transformations.

* test fix
parent 64bd0651
...@@ -22,7 +22,6 @@ ...@@ -22,7 +22,6 @@
* \brief Transform annotated loops into pipelined one that parallelize * \brief Transform annotated loops into pipelined one that parallelize
* producers and consumers * producers and consumers
*/ */
#include <tvm/ffi/reflection/registry.h>
#include <tvm/target/target.h> #include <tvm/target/target.h>
#include <tvm/tir/builtin.h> #include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h> #include <tvm/tir/transform.h>
...@@ -83,138 +82,6 @@ struct BufferAccessInfo { ...@@ -83,138 +82,6 @@ struct BufferAccessInfo {
int use = -1; // the last using stage of the buffer int use = -1; // the last using stage of the buffer
}; };
class PipelineOpaqueAccessRewriter {
public:
/*!
* \brief Constructor
* \param buffer_data_to_buffer The map from buffer data to buffer.
* \param buffer_remap The map from original buffer to the buffer with updated
* shape for multi-versioning in the software pipeline. \param pipeline_loop
* The original loop to be software pipelined. \param fragment_info
* Information about tensor core fragment
*/
PipelineOpaqueAccessRewriter(
const Map<Var, Buffer> &buffer_data_to_buffer,
const Map<Buffer, Buffer> &buffer_remap, const For &pipeline_loop,
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info)
: buffer_data_to_buffer_(buffer_data_to_buffer),
buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
fragment_info_(fragment_info) {}
PrimExpr Rewrite(const Call &call) {
// Intrinsic calls should be handled explicitly here as they are opaque
// accesses to buffer.
static const auto &load_matrix_sync = builtin::tvm_load_matrix_sync();
static const auto &store_matrix_sync = builtin::tvm_store_matrix_sync();
static const auto &mma_sync = builtin::tvm_mma_sync();
static const auto &access_ptr = builtin::tvm_access_ptr();
static const auto &ptx_ldmatrix = builtin::ptx_ldmatrix();
static const auto &ptx_mma = builtin::ptx_mma();
if (call->op.same_as(load_matrix_sync) ||
call->op.same_as(store_matrix_sync)) {
const Buffer &buffer =
buffer_data_to_buffer_.at(Downcast<Var>(call->args[0]));
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
Array<PrimExpr> new_args = call->args;
const Buffer &new_buffer = (*it).second;
new_args.Set(
4, RewriteWmmaFragmentIndex(buffer, new_buffer, call->args[4]));
return Call(call->dtype, call->op, new_args, call->span);
}
} else if (call->op.same_as(mma_sync)) {
Array<PrimExpr> new_args = call->args;
for (int i = 0; i < 4; i++) {
const Var &buffer_var = Downcast<Var>(call->args[i * 2]);
const PrimExpr &index = call->args[i * 2 + 1];
const Buffer &buffer = buffer_data_to_buffer_.at(buffer_var);
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
PrimExpr new_index =
RewriteWmmaFragmentIndex(buffer, (*it).second, index);
new_args.Set(i * 2 + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
} else if (call->op.same_as(access_ptr)) {
return RewriteBufferAccess(call, {1});
} else if (call->op.same_as(ptx_mma)) {
return RewriteBufferAccess(call, {6, 8, 10});
} else if (call->op.same_as(ptx_ldmatrix)) {
return RewriteBufferAccess(call, {3});
}
return call;
}
private:
int GetWmmaFragmentSize(const Buffer &buffer) {
auto it = fragment_info_.find(buffer->data.get());
ICHECK(it != fragment_info_.end());
const FragmentInfo &info = (*it).second;
return info.GetSize();
}
PrimExpr RewriteWmmaFragmentIndex(const Buffer &old_buffer,
const Buffer &new_buffer,
const PrimExpr &old_index) {
PrimExpr new_buffer_offset = old_index;
int fragment_size = GetWmmaFragmentSize(old_buffer);
PrimExpr offset = floordiv(
foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), old_buffer->shape),
fragment_size);
new_buffer_offset +=
floormod(pipeline_loop_->loop_var - pipeline_loop_->min,
new_buffer->shape[0]) *
offset;
return new_buffer_offset;
}
PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr> &input) {
return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
const Buffer &buffer =
buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer &new_buffer = (*it).second;
const PrimExpr &old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
} else {
offset = new_buffer->strides[0];
}
if (buffer.scope() == "m16n8k8.matrixA" ||
buffer.scope() == "m16n8k8.matrixB") {
// mma scope size will shrink by warp size
// @see transform_mma_buffer_layout
ICHECK_EQ(Downcast<IntImm>(floormod(offset, 32))->value, 0)
<< "mma scope size should be multiple of warp size";
offset = floordiv(offset, 32);
}
PrimExpr new_index =
old_index +
floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
new_args.Set(i + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
}
const Map<Var, Buffer> &buffer_data_to_buffer_;
const Map<Buffer, Buffer> &buffer_remap_;
const For &pipeline_loop_;
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info_;
};
/*! /*!
* \brief Rewriter for the body of the software pipeline. This pass inserts * \brief Rewriter for the body of the software pipeline. This pass inserts
* `floormod` to indices of the remapped buffer to select the version * `floormod` to indices of the remapped buffer to select the version
...@@ -231,19 +98,14 @@ public: ...@@ -231,19 +98,14 @@ public:
* Whether all versions the buffers in the software pipeline are accessed. * Whether all versions the buffers in the software pipeline are accessed.
* This will be used to update block access region. In the prologue and * This will be used to update block access region. In the prologue and
* epilogue of a two-stage software pipeline, only one version of these * epilogue of a two-stage software pipeline, only one version of these
* buffers are accessed. \param fragment_info Information about tensor core * buffers are accessed.
* fragment
*/ */
PipelineBodyRewriter( PipelineBodyRewriter(const Map<Var, Buffer> &buffer_data_to_buffer,
const Map<Var, Buffer> &buffer_data_to_buffer, const Map<Buffer, Buffer> &buffer_remap,
const Map<Buffer, Buffer> &buffer_remap, For pipeline_loop, For pipeline_loop, bool access_all_versions)
bool access_all_versions,
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info)
: buffer_data_to_buffer_(buffer_data_to_buffer), : buffer_data_to_buffer_(buffer_data_to_buffer),
buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop), buffer_remap_(buffer_remap), pipeline_loop_(pipeline_loop),
access_all_versions_(access_all_versions), access_all_versions_(access_all_versions) {}
opaque_access_rewriter_(buffer_data_to_buffer_, buffer_remap_,
pipeline_loop_, fragment_info) {}
private: private:
BufferRegion BufferRegion
...@@ -267,6 +129,36 @@ private: ...@@ -267,6 +129,36 @@ private:
return buffer_region; return buffer_region;
} }
PrimExpr RewriteBufferAccess(const Call &call,
const std::vector<int> arg_indices) {
auto product = [](const Array<PrimExpr> &input) {
return foldl(
[](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
make_const(DataType::Int(32), 1), input);
};
Array<PrimExpr> new_args = call->args;
for (int i : arg_indices) {
const Buffer &buffer =
buffer_data_to_buffer_.at(Downcast<Var>(call->args[i]));
auto it = buffer_remap_.find(buffer);
if (it != buffer_remap_.end()) {
const Buffer &new_buffer = (*it).second;
const PrimExpr &old_index = call->args[i + 1];
PrimExpr offset;
if (new_buffer->strides.empty()) {
offset = product(buffer->shape);
} else {
offset = new_buffer->strides[0];
}
PrimExpr new_index =
old_index +
floormod(pipeline_loop_->loop_var, new_buffer->shape[0]) * offset;
new_args.Set(i + 1, new_index);
}
}
return Call(call->dtype, call->op, new_args, call->span);
}
Stmt VisitStmt_(const BlockNode *op) final { Stmt VisitStmt_(const BlockNode *op) final {
for (const Buffer &alloc_buffer : op->alloc_buffers) { for (const Buffer &alloc_buffer : op->alloc_buffers) {
buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer); buffer_data_to_buffer_.Set(alloc_buffer->data, alloc_buffer);
...@@ -317,14 +209,16 @@ private: ...@@ -317,14 +209,16 @@ private:
PrimExpr VisitExpr_(const CallNode *op) final { PrimExpr VisitExpr_(const CallNode *op) final {
Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op)); Call call = Downcast<Call>(StmtExprMutator::VisitExpr_(op));
return opaque_access_rewriter_.Rewrite(call); if (call->op.same_as(builtin::tvm_access_ptr())) {
return RewriteBufferAccess(call, {1});
}
return call;
} }
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Buffer> buffer_remap_; Map<Buffer, Buffer> buffer_remap_;
For pipeline_loop_; For pipeline_loop_;
bool access_all_versions_; bool access_all_versions_;
PipelineOpaqueAccessRewriter opaque_access_rewriter_;
}; };
/*! /*!
...@@ -333,35 +227,12 @@ private: ...@@ -333,35 +227,12 @@ private:
*/ */
class PipelineRewriter : public StmtExprMutator { class PipelineRewriter : public StmtExprMutator {
public: public:
static Stmt Rewrite( PipelineRewriter(Map<Var, Buffer> buffer_data_to_buffer,
Map<Var, Buffer> buffer_data_to_buffer, const Array<Buffer> &pipeline_allocs,
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> const For &pipeline_loop, const PipelineInfo &pipeline_info)
&double_buffers,
const Array<Buffer> pipeline_allocs, const For &pipeline_loop,
const PipelineInfo &pipeline_info,
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info,
const Map<String, ffi::Any> preserved_annotations) {
PipelineRewriter rewriter(buffer_data_to_buffer, double_buffers,
pipeline_allocs, pipeline_loop, pipeline_info,
fragment_info, preserved_annotations);
return rewriter.BuildPipeline();
}
private:
PipelineRewriter(
Map<Var, Buffer> buffer_data_to_buffer,
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
&double_buffers,
const Array<Buffer> &pipeline_allocs, const For &pipeline_loop,
const PipelineInfo &pipeline_info,
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info,
const Map<String, ffi::Any> preserved_annotations)
: buffer_data_to_buffer_(std::move(buffer_data_to_buffer)), : buffer_data_to_buffer_(std::move(buffer_data_to_buffer)),
double_buffers_(double_buffers), pipeline_allocs_(pipeline_allocs), pipeline_allocs_(pipeline_allocs), pipeline_loop_(pipeline_loop),
pipeline_loop_(pipeline_loop), pipeline_info_(pipeline_info), pipeline_info_(pipeline_info) {}
fragment_info_(fragment_info),
preserved_annotations_(preserved_annotations) {}
Stmt BuildPipeline() { Stmt BuildPipeline() {
// Step 1: Analyze accesses to the buffers in the pipeline and compute the // Step 1: Analyze accesses to the buffers in the pipeline and compute the
...@@ -376,36 +247,61 @@ private: ...@@ -376,36 +247,61 @@ private:
} }
ordered_stmts_.resize(pipeline_info_.size()); ordered_stmts_.resize(pipeline_info_.size());
for (const auto &pair : pipeline_info_) { for (const auto &[block, anno] : pipeline_info_) {
const Block &block = pair.first; ordered_stmts_.Set(anno.order, block);
int order = pair.second.order;
ordered_stmts_.Set(order, block);
} }
// Step 2: Emit the pipeline prologue, body and epilogue. for (const Block &block : ordered_stmts_) {
Stmt prologue = int stage = pipeline_info_[block].stage;
EmitImpl(pipeline_loop_->min, pipeline_loop_->min + max_stage_, true); if (pipeline_info_[block].async) {
Stmt body = EmitImpl(pipeline_loop_->min + max_stage_, auto &state = async_states[stage];
pipeline_loop_->min + pipeline_loop_->extent, false); state.producer_head = pipeline_loop_->min - 1;
// introduce extra lowerbound when the loop length is smaller than num for (auto write_region : block->writes) {
// stages to ensure the epilogue interval do not overlap the prologue auto buffer = write_region->buffer;
// interval. state.dst_buffers.insert(buffer.get());
PrimExpr epigogue_start = pipeline_loop_->min + pipeline_loop_->extent; if (buffer_remap_.count(buffer))
Optional<PrimExpr> extra_epilogue_lower_bound = std::nullopt; state.dst_buffers.insert(buffer_remap_[buffer].get());
if (max_stage_ > 1 && }
!analyzer_.CanProveGreaterEqual(pipeline_loop_->extent, max_stage_)) { }
if (is_const_int(epigogue_start)) { }
epigogue_start = max(epigogue_start, pipeline_loop_->min + max_stage_); std::unordered_set<int> consumed;
} else { for (const Block &block : ordered_stmts_) {
// for dynamic case, introduce extra lowerbound as loop predicate int stage = pipeline_info_[block].stage;
// to ensure the epilogue part unrollable. if (pipeline_info_[block].async) {
extra_epilogue_lower_bound = pipeline_loop_->min + max_stage_; auto &state = async_states[stage];
if (state.commit_groups.empty() || consumed.count(stage)) {
state.commit_groups.push_back({});
}
state.commit_groups.back().push_back(pipeline_info_[block].order);
consumed.erase(stage);
for (auto write_region : block->writes) {
auto buffer = buffer_remap_.count(write_region->buffer)
? buffer_remap_[write_region->buffer]
: write_region->buffer;
state.buffer_to_commit_group_[buffer.get()] =
state.commit_groups.size() - 1;
}
}
for (auto read_region : block->reads) {
for (const auto &[producer_stage_id, producer_state] : async_states) {
if (producer_stage_id <= stage &&
producer_state.writes(read_region->buffer)) {
consumed.insert(producer_stage_id);
}
}
} }
} }
Stmt epilogue =
EmitImpl(epigogue_start, // Step 2: Emit the pipeline prologue, body and epilogue.
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, Stmt prologue = EmitImpl(pipeline_loop_->min,
true, extra_epilogue_lower_bound); pipeline_loop_->min + max_stage_, true, true);
Stmt body =
EmitImpl(pipeline_loop_->min + max_stage_,
pipeline_loop_->min + pipeline_loop_->extent, false, false);
Stmt epilogue = EmitImpl(
pipeline_loop_->min + pipeline_loop_->extent,
pipeline_loop_->min + pipeline_loop_->extent + max_stage_, true, true);
SeqStmt stmt = SeqStmt({prologue, body, epilogue}); SeqStmt stmt = SeqStmt({prologue, body, epilogue});
...@@ -550,9 +446,6 @@ private: ...@@ -550,9 +446,6 @@ private:
num_versions--; num_versions--;
} }
} }
if (num_versions == 1 && double_buffers_.count(buffer)) {
num_versions = 2;
}
return num_versions; return num_versions;
} }
...@@ -584,15 +477,16 @@ private: ...@@ -584,15 +477,16 @@ private:
// valid, it is the "sum of extents of loops that have been executed" - 1, // valid, it is the "sum of extents of loops that have been executed" - 1,
// e.g. for epilogue it is prologue extent + body extent - 1. This is only // e.g. for epilogue it is prologue extent + body extent - 1. This is only
// needed to compute wait count for epilogue without async producers. // needed to compute wait count for epilogue without async producers.
Optional<PrimExpr> producer_head{PrimExpr(-1)}; PrimExpr producer_head;
std::vector<std::vector<int>> commit_groups;
std::unordered_map<const BufferNode *, int> buffer_to_commit_group_;
bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; } bool writes(Buffer buf) const { return dst_buffers.count(buf.get()) > 0; }
}; };
// Per-stage states that are local to each of pipeline prologue, body, and // Per-stage states that are local to each of pipeline prologue, body, and
// epilogue. // epilogue.
struct AsyncStateLocal { struct AsyncStateLocal {
struct { struct PendingWait {
// The index into a list of blocks, where async_wait_queue should be // The index into a list of blocks, where async_wait_queue should be
// attached at the beginning. // attached at the beginning.
int insert_before; int insert_before;
...@@ -601,198 +495,76 @@ private: ...@@ -601,198 +495,76 @@ private:
PrimExpr wait_count{nullptr}; PrimExpr wait_count{nullptr};
bool valid() const { return wait_count.defined(); } bool valid() const { return wait_count.defined(); }
} pending_wait; };
// Destination buffers of async operations that have been encountered so far std::vector<PendingWait> pending_waits;
// in the loop
//
// for (size_t i = 0; i < new_blocks.size(); ++i) {
// ...
// }
//
// This is for tracking which async operations have been issued at the
// "current" iteration, up until a point where we encounter a consumer of
// async result buffers. This is used to decide if the producer_head of each
// buffer points to a copy written in the current or previous iteration.
std::unordered_set<const BufferNode *> seen;
// A symbolic expression representing the index the latest async operation // A symbolic expression representing the index the latest async operation
// associated with this stage has written into, at the "current" iteration. // associated with this stage has written into, at the "current" iteration.
Optional<PrimExpr> producer_head; Optional<PrimExpr> producer_head;
// The predicate of BlockRealize containing the async operation of this
// stage.
Optional<PrimExpr> predicate;
// Indices into a list of blocks, where async_commit_queue scope should be
// attached. If multiple async producers are interleaved with their consumer
// in between, we need separate async_commit_queue for each producer. Thus,
// we need multiple sets of indices.
std::vector<std::vector<size_t>> commit_groups;
// This is set to true when we reach a stage that consumes this async stage.
bool consumed{false};
}; };
/*! Structure holding intermediate information for pipeline loop rewriting. */ /*! Structure holding intermediate information for pipeline loop rewriting. */
struct RewrittenBlockInfo { struct RewrittenBlockInfo {
int stage; int stage;
int order;
PrimExpr predicate; PrimExpr predicate;
Block block; Block block;
PrimExpr access_index; PrimExpr access_index;
bool is_async; bool is_async;
}; };
// Determine where to insert async_wait and the corresponding wait count. void PopulateWaitCounts(const std::vector<RewrittenBlockInfo> &new_blocks,
void PopulateWaitCounts( std::map<int, AsyncStateLocal> *async_states_local) {
const std::vector<RewrittenBlockInfo> &new_blocks,
arith::Analyzer *ana_normalized,
const std::unordered_map<const BufferNode *, int> &buffer_to_commit_group,
std::map<int, AsyncStateLocal> *async_states_local) {
for (size_t i = 0; i < new_blocks.size(); ++i) { for (size_t i = 0; i < new_blocks.size(); ++i) {
if (new_blocks[i].is_async) {
// Record the fact that we have encountered these write buffers.
for (auto write_region : new_blocks[i].block->writes) {
(*async_states_local)[new_blocks[i].stage].seen.insert(
write_region->buffer.get());
}
}
int producer_stage_idx = -1; int producer_stage_idx = -1;
for (auto read_region : new_blocks[i].block->reads) { for (auto read_region : new_blocks[i].block->reads) {
for (auto kv : async_states) { for (const auto &[stage, state] : async_states) {
if (kv.first <= new_blocks[i].stage && if (stage <= new_blocks[i].stage &&
kv.second.writes(read_region->buffer)) { state.writes(read_region->buffer)) {
// Found an earlier stage where read_region->buffer was // Found an earlier stage where read_region->buffer was
// asynchronously written // asynchronously written
ICHECK(producer_stage_idx == -1 || producer_stage_idx == kv.first) ICHECK(producer_stage_idx == -1 || producer_stage_idx == stage)
<< "A dependency on multiple async stages is not supported"; << "A dependency on multiple async stages is not supported";
producer_stage_idx = kv.first; producer_stage_idx = stage;
} }
} }
} }
if (producer_stage_idx == -1) if (producer_stage_idx == -1)
continue; continue;
const auto &state = async_states[producer_stage_idx];
// The following logic has become complicated to handle case like this:
//
// for i in range(13):
// # Stage 0
// async_commit_queue(0):
// async_scope:
// A_shared[(i + 3) % 4] = A[...]
//
//
// # Stage 1
// async_wait_queue(0, 5):
// compute(A_shared[i], B_shared[i])
//
// # Stage 0
// async_commit_queue(0)
// async_scope:
// B_shared[(i + 3) % 4] = B[...]
//
//
// Here, multiple async producers in the same stage are interleaved with
// their consumer in between. Since each buffer is associated with
// different commit groups, the wait_count before the consumer should be
// bigger than the simpler case:
//
// for i in range(13):
// # Stage 0
// async_commit_queue(0):
// async_scope:
// A_shared[(i + 3) % 4] = A[...]
// B_shared[(i + 3) % 4] = B[...]
//
// # Stage 1
// async_wait_queue(0, 3):
// compute(A_shared[i], B_shared[i])
//
// The correct wait_count can be determined by considering each commit
// group separately, and summing "per-commit" wait_counts.
//
// From A_shared's perspective, it allows for (i + 3) - i async commit
// groups to be in flight while from B_shared's perspective, the producer
// head at compute points to the copy done by the previous iteration, so
// its wait_count is calculated as ((i - 1) + 3) - i. The sum of the two
// wait_counts gives 5.
// print async_states_local
auto &dep_local_state = (*async_states_local)[producer_stage_idx]; auto &dep_local_state = (*async_states_local)[producer_stage_idx];
const auto num_commit_group = dep_local_state.commit_groups.size(); PrimExpr in_flight_cnt = 0;
std::vector<Optional<PrimExpr>> producer_head_per_commit; for (const auto &group : state.commit_groups) {
PrimExpr consumer_head = new_blocks[i].access_index;
auto add_unique_producer_head = PrimExpr producer_head;
[&](const Optional<PrimExpr> &producer_head) { if (dep_local_state.producer_head.defined()) {
// if producer_head already in producer_head_per_commit, return producer_head = dep_local_state.producer_head.value();
for (const auto &head : producer_head_per_commit) { // if the group is after the wait point, minus by 1
if (StructuralEqual()(head, producer_head)) { if (group.front() > new_blocks[i].order)
return; producer_head -= 1;
} } else {
} producer_head = state.producer_head;
producer_head_per_commit.push_back(producer_head);
};
if (num_commit_group == 0) {
// Epilogue, no async producer. Since "local" producer_head is not
// available, use "global" producer_head.
ICHECK(!dep_local_state.producer_head);
add_unique_producer_head(
async_states[producer_stage_idx].producer_head);
} else {
ICHECK(dep_local_state.producer_head);
std::vector<bool> need_wait_count(num_commit_group, true);
for (auto read_region : new_blocks[i].block->reads) {
if (!async_states[producer_stage_idx].writes(read_region->buffer))
continue;
auto commit_group_id =
buffer_to_commit_group.at(read_region->buffer.get());
if (!need_wait_count[commit_group_id])
continue;
if (!dep_local_state.seen.count(read_region->buffer.get())) {
// Multiple async producers interleaved: The most recent async write
// is from the previous iteration. This is the B_shared case above.
add_unique_producer_head(dep_local_state.producer_head.value() - 1);
} else {
// Normal case
add_unique_producer_head(dep_local_state.producer_head.value());
}
need_wait_count[commit_group_id] = false;
} }
in_flight_cnt += producer_head - consumer_head;
} }
auto wait_count = [=, &ana_normalized]() { // We can relax the in-flight-count by the number of independent commit.
auto sum = PrimExpr(0); std::unordered_set<int> dependent_groups;
for (const auto &producer_head : producer_head_per_commit) { for (const auto &read_region : new_blocks[i].block->reads) {
if (producer_head && if (state.buffer_to_commit_group_.count(read_region->buffer.get()))
ana_normalized->CanProve(producer_head.value() >= 0)) { dependent_groups.insert(
// Here, new_blocks[i].access_index corresponds to "consumer_head". state.buffer_to_commit_group_.at(read_region->buffer.get()));
// The difference of producer_head and consumer_head is precisely }
// the number of async commit groups that can still be in flight for (int i = int(state.commit_groups.size()) - 1; i >= 0; i--) {
// after this wait. if (dependent_groups.count(i) == 0)
sum += analyzer_.Simplify(producer_head.value() - in_flight_cnt += 1;
new_blocks[i].access_index); else
} else { break; // stop relaxing
// The precise count cannot be determined, give up.
return PrimExpr(0);
}
}
return sum;
}();
auto &pending_wait = dep_local_state.pending_wait;
if (!pending_wait.valid()) {
pending_wait = {static_cast<int>(i), wait_count};
} else if (analyzer_.CanProve(wait_count < pending_wait.wait_count)) {
// Coalesce multiple wait_queue if the later one allows fewer in-flight
// ops.
pending_wait = {pending_wait.insert_before, wait_count};
} }
in_flight_cnt = analyzer_.Simplify(in_flight_cnt);
dep_local_state.pending_waits.push_back(
{static_cast<int>(i), in_flight_cnt});
} }
} }
...@@ -800,85 +572,38 @@ private: ...@@ -800,85 +572,38 @@ private:
// statements with async scopes (if any). // statements with async scopes (if any).
Array<Stmt> CompletePipelineLoopStatements( Array<Stmt> CompletePipelineLoopStatements(
const std::vector<RewrittenBlockInfo> &blocks, const std::vector<RewrittenBlockInfo> &blocks,
const std::map<int, AsyncStateLocal> &async_states_local, const std::map<int, AsyncStateLocal> &async_states_local) const {
arith::Analyzer *ana_normalized) const {
std::vector<RewrittenBlockInfo> new_blocks = blocks; std::vector<RewrittenBlockInfo> new_blocks = blocks;
std::vector<int> commit_group_indices(new_blocks.size(), -1);
for (const auto &[stage_id, state] : async_states_local) { for (const auto &[stage_id, state] : async_states_local) {
if (!state.commit_groups.empty()) { for (const auto &pw : state.pending_waits) {
for (size_t i = 0; i < state.commit_groups.size(); ++i) { auto &block = new_blocks[pw.insert_before].block;
for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { BlockNode *n = block.CopyOnWrite();
ICHECK(state.commit_groups[i][0] + j < new_blocks.size()); auto zero = make_zero(DataType::Int(32));
commit_group_indices[state.commit_groups[i][0] + j] = stage_id; n->body = AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
} AttrStmt(zero, tir::attr::async_wait_inflight_count,
} pw.wait_count, n->body));
} }
}
if (state.pending_wait.valid()) { // mark the last async stmt as commit
auto attach_wait_scope = [&new_blocks](int i, int stage_id, std::unordered_set<int> commit_group_indices;
PrimExpr wait_count) { for (const auto &[stage_id, state] : async_states) {
auto &block = new_blocks[i].block; for (size_t i = 0; i < state.commit_groups.size(); ++i) {
BlockNode *n = block.CopyOnWrite(); commit_group_indices.insert(state.commit_groups[i].back());
auto zero = make_zero(DataType::Int(32));
n->body =
AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count,
wait_count, n->body));
};
if (state.predicate &&
!ana_normalized->CanProve(state.predicate.value())) {
// If the async operation that this wait_queue is waiting on is
// predicated, and we cannot prove that the predicate is always true,
// the precise wait count is only valid at iterations where the
// predicate is true;
auto wait_count =
Call(DataType::Int(32), builtin::if_then_else(),
{state.predicate.value(), state.pending_wait.wait_count, 0});
attach_wait_scope(state.pending_wait.insert_before, stage_id,
wait_count);
} else {
attach_wait_scope(state.pending_wait.insert_before, stage_id,
state.pending_wait.wait_count);
}
} }
} }
Array<Stmt> stmts; Array<Stmt> stmts;
for (size_t i = 0; i < new_blocks.size();) { for (size_t i = 0; i < new_blocks.size(); i++) {
if (commit_group_indices[i] == -1) { Block block = new_blocks[i].block;
// A synchrnous block, not part of any commit group if (commit_group_indices.count(new_blocks[i].order)) {
stmts.push_back( auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block)); tir::attr::async_commit_queue_scope,
++i; new_blocks[i].stage, block->body);
} else { block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
Array<Stmt> group_bodies;
auto stage_id = commit_group_indices[i];
auto predicate = new_blocks[i].predicate;
for (; i < commit_group_indices.size() &&
commit_group_indices[i] == stage_id;
++i) {
ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
<< "Predicates in the same stage are expected to be identical";
group_bodies.push_back(new_blocks[i].block->body);
}
if (group_bodies.size() > 1) {
auto merged_bodies = SeqStmt(group_bodies);
group_bodies.clear();
group_bodies.push_back(merged_bodies);
}
for (auto body : group_bodies) {
auto commit_queue_scope =
AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_commit_queue_scope, stage_id, body);
auto new_block =
MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
stmts.push_back(BlockRealize({}, predicate, new_block));
}
} }
stmts.push_back(BlockRealize({}, new_blocks[i].predicate, block));
} }
return stmts; return stmts;
...@@ -889,21 +614,16 @@ private: ...@@ -889,21 +614,16 @@ private:
* \param start The start of the range * \param start The start of the range
* \param end The end of the range * \param end The end of the range
* \param unroll_loop Whether the loop should be unrolled. * \param unroll_loop Whether the loop should be unrolled.
* \param extra_loop_lower_bound Extra loop lower bound.
* \return The result loop. * \return The result loop.
*/ */
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop, Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop,
Optional<PrimExpr> extra_loop_lower_bound = std::nullopt) { bool need_bound_check) {
PrimExpr new_loop_var; PrimExpr new_loop_var;
PrimExpr extent = end - start; PrimExpr extent = end - start;
auto make_nop = []() { auto make_nop = []() {
return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {})); return BlockRealize({}, Bool(true), MakeBlock(Evaluate(0), {}));
}; };
if (analyzer_.CanProve(extent <= 0)) {
return make_nop();
}
bool is_unit_loop = analyzer_.CanProveEqual(extent, 1); bool is_unit_loop = analyzer_.CanProveEqual(extent, 1);
if (is_unit_loop) { if (is_unit_loop) {
new_loop_var = start; // use constants as the loop var for unit loops new_loop_var = start; // use constants as the loop var for unit loops
...@@ -912,36 +632,26 @@ private: ...@@ -912,36 +632,26 @@ private:
analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end)); analyzer_.Bind(Downcast<Var>(new_loop_var), Range(start, end));
} }
// In contrast to analyzer_ which is bound to [start, end), this one is
// bound to the "normalized" range, [pipeline_loop_->min, extent).
arith::Analyzer ana_normalized;
if (!is_unit_loop) {
ana_normalized.Bind(Downcast<Var>(new_loop_var),
Range(pipeline_loop_->min, extent));
}
std::vector<RewrittenBlockInfo> new_blocks; std::vector<RewrittenBlockInfo> new_blocks;
// Async related // Async related
std::map<int, AsyncStateLocal> async_states_local; std::map<int, AsyncStateLocal> async_states_local;
std::unordered_map<const BufferNode *, int> buffer_to_commit_group;
for (const Block &block : ordered_stmts_) { for (const Block &block : ordered_stmts_) {
int stage = pipeline_info_.at(block).stage; int stage = pipeline_info_.at(block).stage;
int order = pipeline_info_.at(block).order;
PrimExpr inbound = Bool(true);
PrimExpr skewed_loop_var = new_loop_var - stage; PrimExpr skewed_loop_var = new_loop_var - stage;
PrimExpr inbound = if (need_bound_check)
analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) && inbound =
(skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent); analyzer_.Simplify(pipeline_loop_->min <= skewed_loop_var) &&
if (extra_loop_lower_bound.defined()) { (skewed_loop_var < pipeline_loop_->min + pipeline_loop_->extent);
inbound = analyzer_.Simplify(
inbound && new_loop_var >= extra_loop_lower_bound.value());
}
if (analyzer_.CanProve(!inbound)) { if (analyzer_.CanProve(!inbound)) {
continue; continue;
} }
Block new_block = Downcast<Block>(PipelineBodyRewriter( Block new_block = Downcast<Block>(
buffer_data_to_buffer_, buffer_remap_, pipeline_loop_, PipelineBodyRewriter(buffer_data_to_buffer_, buffer_remap_,
max_stage_ != 1, fragment_info_)(block)); pipeline_loop_, max_stage_ != 1)(block));
PrimExpr delta = start - pipeline_loop_->min; PrimExpr delta = start - pipeline_loop_->min;
// This variable corresponds to // This variable corresponds to
...@@ -958,76 +668,31 @@ private: ...@@ -958,76 +668,31 @@ private:
Var loop_iter = Downcast<Var>(new_loop_var); Var loop_iter = Downcast<Var>(new_loop_var);
inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}}); inbound = Substitute(inbound, {{loop_iter, loop_iter + delta}});
} }
new_block = Downcast<Block>(Substitute( new_block = Downcast<Block>(Substitute(
new_block, {{pipeline_loop_->loop_var, normalized_access_index}})); new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));
if (pipeline_info_[block].async) { if (pipeline_info_[block].async) {
auto &local_state = async_states_local[stage]; auto &local_state = async_states_local[stage];
int commit_group_id = -1;
if (local_state.commit_groups.empty() || local_state.consumed) {
// consumed == true means there is already a consumer stage waiting
// for an eariler async operation of this stage. In such cases, we
// make multiple commit_queue for this stage.
commit_group_id = local_state.commit_groups.size();
local_state.commit_groups.push_back({new_blocks.size()});
} else {
// This is the case when one commit_queue groups multiple async
// blocks. with commit_queue(stage):
// async_scope:
// A_shared[...] = ...
// async_scope:
// B_shared[...] = ...
commit_group_id = local_state.commit_groups.size() - 1;
local_state.commit_groups.back().push_back(new_blocks.size());
}
for (auto write_region : new_block->writes) {
async_states[stage].dst_buffers.insert(write_region->buffer.get());
buffer_to_commit_group[write_region->buffer.get()] = commit_group_id;
}
local_state.producer_head = normalized_access_index; local_state.producer_head = normalized_access_index;
if (!local_state.predicate ||
ana_normalized.CanProve(local_state.predicate.value())) {
local_state.predicate = inbound;
} else if (local_state.predicate) {
local_state.predicate =
ana_normalized.Simplify(local_state.predicate.value() & inbound);
}
BlockNode *n = new_block.CopyOnWrite(); BlockNode *n = new_block.CopyOnWrite();
n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope, n->body = AttrStmt(make_zero(DataType::Int(32)), tir::attr::async_scope,
1, n->body); 1, n->body);
} }
new_blocks.push_back({stage, inbound, new_block, normalized_access_index, new_blocks.push_back({stage, order, inbound, new_block,
normalized_access_index,
pipeline_info_[block].async}); pipeline_info_[block].async});
for (auto read_region : new_block->reads) {
for (auto kv : async_states) {
int producer_stage_id = kv.first;
if (producer_stage_id <= stage &&
kv.second.writes(read_region->buffer)) {
async_states_local[producer_stage_id].consumed = true;
}
}
}
} }
PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, PopulateWaitCounts(new_blocks, &async_states_local);
&async_states_local);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local);
&ana_normalized);
Stmt new_loop{nullptr}; Stmt new_loop{nullptr};
if (stmts.empty()) { if (stmts.empty()) {
return make_nop(); return make_nop();
} }
if (stmts.size() == 1) { if (stmts.size() == 1) {
new_loop = stmts[0]; new_loop = stmts[0];
} else { } else {
...@@ -1035,26 +700,22 @@ private: ...@@ -1035,26 +700,22 @@ private:
} }
if (!is_unit_loop) { if (!is_unit_loop) {
Map<String, Any> preserved_annotations;
for (const auto &kv : pipeline_loop_->annotations) {
const String &key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage &&
kv.first != tir::attr::software_pipeline_order &&
kv.first != tir::attr::software_pipeline_async_stages) {
preserved_annotations.Set(key, kv.second);
}
}
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent, new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind,
std::move(new_loop), std::nullopt, preserved_annotations_); std::move(new_loop), std::nullopt, preserved_annotations);
} }
// Update producer heads in the global async states. // Update producer heads in the global async states.
for (const auto &kv : async_states_local) { for (const auto &[stage_id, state] : async_states_local) {
const int stage_id = kv.first; async_states[stage_id].producer_head += extent;
const AsyncStateLocal &state = kv.second;
if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
async_states[stage_id].producer_head) {
// Advance the "global" producer head if it is still valid and we know
// exactly how much we can increment
async_states[stage_id].producer_head =
async_states[stage_id].producer_head.value() + extent;
} else {
// Otherwise, invalidate the global producer head
async_states[stage_id].producer_head = std::nullopt;
}
} }
return BlockRealize({}, Bool(true), return BlockRealize({}, Bool(true),
...@@ -1063,17 +724,13 @@ private: ...@@ -1063,17 +724,13 @@ private:
arith::Analyzer analyzer_; arith::Analyzer analyzer_;
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
const std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual>
&double_buffers_;
Array<Buffer> pipeline_allocs_; Array<Buffer> pipeline_allocs_;
For pipeline_loop_; For pipeline_loop_;
PipelineInfo pipeline_info_; PipelineInfo pipeline_info_;
const std::unordered_map<const VarNode *, FragmentInfo> &fragment_info_;
int max_stage_ = -1; int max_stage_ = -1;
Map<Buffer, Buffer> buffer_remap_; Map<Buffer, Buffer> buffer_remap_;
Array<Block> ordered_stmts_; Array<Block> ordered_stmts_;
std::map<int, AsyncStateGlobal> async_states; std::map<int, AsyncStateGlobal> async_states;
Map<String, ffi::Any> preserved_annotations_;
}; };
/*! /*!
...@@ -1088,7 +745,8 @@ void BuildDependencyGraph(const Array<Block> &blocks, ...@@ -1088,7 +745,8 @@ void BuildDependencyGraph(const Array<Block> &blocks,
ObjectPtrEqual> *dep_src2dst, ObjectPtrEqual> *dep_src2dst,
std::unordered_map<Block, Array<Block>, ObjectPtrHash, std::unordered_map<Block, Array<Block>, ObjectPtrHash,
ObjectPtrEqual> *dep_dst2src) { ObjectPtrEqual> *dep_dst2src) {
std::unordered_map<Var, Array<Block>> buffer_writers; std::unordered_map<Var, Array<Block>, ObjectPtrHash, ObjectPtrEqual>
buffer_writers;
for (const Block &block : blocks) { for (const Block &block : blocks) {
for (const BufferRegion &read : block->reads) { for (const BufferRegion &read : block->reads) {
...@@ -1119,7 +777,6 @@ public: ...@@ -1119,7 +777,6 @@ public:
const Buffer &buffer = kv.second; const Buffer &buffer = kv.second;
injector.buffer_data_to_buffer_.Set(buffer->data, buffer); injector.buffer_data_to_buffer_.Set(buffer->data, buffer);
} }
injector.fragment_info_ = GetTensorCoreFragmentInfo(func->body);
return injector(func->body); return injector(func->body);
} }
...@@ -1178,7 +835,7 @@ private: ...@@ -1178,7 +835,7 @@ private:
// Step 1: Recursively rewrite the children first. // Step 1: Recursively rewrite the children first.
For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op)); For for_node = Downcast<For>(StmtExprMutator::VisitStmt_(op));
if (!HasPipelineAnnotation(op)) { if (!HasPipelineAnnotation(op)) {
return std::move(for_node); return for_node;
} }
// Step 2: Find the body and buffer allocations of the pipeline. The body // Step 2: Find the body and buffer allocations of the pipeline. The body
// can be direct child of the for-loop. If the for-loop has BlockRealize as // can be direct child of the for-loop. If the for-loop has BlockRealize as
...@@ -1256,16 +913,6 @@ private: ...@@ -1256,16 +913,6 @@ private:
} }
} }
Map<String, ffi::Any> preserved_annotations;
for (const auto &kv : op->annotations) {
const String &key = kv.first;
if (kv.first != tir::attr::software_pipeline_stage &&
kv.first != tir::attr::software_pipeline_order &&
kv.first != tir::attr::software_pipeline_async_stages) {
preserved_annotations.Set(key, kv.second);
}
}
for (size_t i = 0; i < pipeline_stages.size(); i++) { for (size_t i = 0; i < pipeline_stages.size(); i++) {
int stage = static_cast<int>(pipeline_stages[i]->value); int stage = static_cast<int>(pipeline_stages[i]->value);
bool is_async = bool is_async =
...@@ -1279,9 +926,9 @@ private: ...@@ -1279,9 +926,9 @@ private:
ValidatePipelineBody(pipeline_info, original_order); ValidatePipelineBody(pipeline_info, original_order);
// Step 4: Rewrite the pipeline body. // Step 4: Rewrite the pipeline body.
Stmt pipeline = PipelineRewriter::Rewrite( Stmt pipeline = PipelineRewriter(buffer_data_to_buffer_, pipeline_allocs,
buffer_data_to_buffer_, double_buffers, pipeline_allocs, GetRef<For>(op), pipeline_info)
GetRef<For>(op), pipeline_info, fragment_info_, preserved_annotations); .BuildPipeline();
if (const auto *realize = op->body.as<BlockRealizeNode>()) { if (const auto *realize = op->body.as<BlockRealizeNode>()) {
const auto &block = realize->block; const auto &block = realize->block;
...@@ -1297,16 +944,6 @@ private: ...@@ -1297,16 +944,6 @@ private:
buffer_data_to_buffer_.Set(buffer->data, buffer); buffer_data_to_buffer_.Set(buffer->data, buffer);
} }
auto it = op->annotations.find(tir::attr::double_buffer_scope);
if (it != op->annotations.end()) {
int buffer_index = Downcast<Integer>((*it).second).IntValue();
CHECK(buffer_index >= 0 &&
static_cast<size_t>(buffer_index) < op->writes.size())
<< "ValueError: Index of the buffer exceeds the size of the write "
"regions of the block. ("
<< buffer_index << " vs. " << op->writes.size() << ")";
double_buffers.insert(op->writes[buffer_index]->buffer);
}
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op)); Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));
for (const auto &buffer : op->alloc_buffers) { for (const auto &buffer : op->alloc_buffers) {
...@@ -1325,21 +962,18 @@ private: ...@@ -1325,21 +962,18 @@ private:
} }
if (has_stage) { if (has_stage) {
LOG(FATAL) LOG(FATAL)
<< "ValueError: Order of the software pipeline is not defined."; << "ValueError: Stage of the software pipeline is not defined.";
} }
if (has_order) { if (has_order) {
LOG(FATAL) LOG(FATAL)
<< "ValueError: Stage of the software pipeline is not defined."; << "ValueError: Order of the software pipeline is not defined.";
} }
return false; return false;
} }
Map<Var, Buffer> buffer_data_to_buffer_; Map<Var, Buffer> buffer_data_to_buffer_;
std::unordered_map<const VarNode *, FragmentInfo> fragment_info_;
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> double_buffers;
Optional<String> global_symbol_; Optional<String> global_symbol_;
}; };
} // namespace software_pipeline } // namespace software_pipeline
/*! /*!
......
...@@ -9,7 +9,6 @@ def _check(original, transformed): ...@@ -9,7 +9,6 @@ def _check(original, transformed):
mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main")) mod = tvm.IRModule.from_expr(func.with_attr("global_symbol", "main"))
mod = tl.transform.InjectSoftwarePipeline()(mod) mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tl.transform.Simplify()(mod) mod = tl.transform.Simplify()(mod)
print(mod["main"])
tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"), tvm.ir.assert_structural_equal(mod["main"], transformed.with_attr("global_symbol", "main"),
True) True)
...@@ -40,21 +39,29 @@ def test_trival_pipeline(): ...@@ -40,21 +39,29 @@ def test_trival_pipeline():
C[tx, i] = B[tx, 0] + T.float32(1) C[tx, i] = B[tx, 0] + T.float32(1)
@T.prim_func @T.prim_func
def expected(A: T.Tensor((16, 1), "float32"), C: T.Tensor((16, 1), "float32")) -> None: def expected(A: T.Buffer((16, 1), "float32"), C: T.Buffer((16, 1), "float32")):
for tx in T.thread_binding(16, thread="threadIdx.x"): for tx in T.thread_binding(16, thread="threadIdx.x"):
with T.block(""): with T.block():
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
B = T.alloc_buffer((2, 16, 1), scope="shared") B = T.alloc_buffer((2, 16, 1), scope="shared")
with T.block(""): with T.block():
T.reads(A[tx, 0]) T.reads(A[tx, 0])
T.writes(B[0, tx, 0]) T.writes(B[0, tx, 0])
B[0, tx, 0] = A[tx, 0] * T.float32(2.0) B[0, tx, 0] = A[tx, 0] * T.float32(2.0)
with T.block(""): with T.block():
T.reads() T.reads(A[tx, 1:1], B[0:2, tx, 0])
T.writes() T.writes(B[1:1, tx, 0], C[tx, 0:0])
T.evaluate(0) for i in range(0):
with T.block(""): with T.block():
T.reads(A[tx, i + 1])
T.writes(B[i + 1, tx, 0])
B[i + 1, tx, 0] = A[tx, i + 1] * T.float32(2.0)
with T.block():
T.reads(B[i, tx, 0])
T.writes(C[tx, i])
C[tx, i] = B[i, tx, 0] + T.float32(1.0)
with T.block():
T.reads(B[0, tx, 0]) T.reads(B[0, tx, 0])
T.writes(C[tx, 0]) T.writes(C[tx, 0])
C[tx, 0] = B[0, tx, 0] + T.float32(1.0) C[tx, 0] = B[0, tx, 0] + T.float32(1.0)
......
...@@ -79,7 +79,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: ...@@ -79,7 +79,6 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule:
mod = tilelang.transform.LegalizeVectorizedLoop()(mod) mod = tilelang.transform.LegalizeVectorizedLoop()(mod)
# Add safety checks for memory accesses # Add safety checks for memory accesses
mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod)
# Align dynamic shared memory allocations
# Simplify again to clean up any duplicated conditions # Simplify again to clean up any duplicated conditions
# that may have been introduced by safety checks # that may have been introduced by safety checks
# use an enhanced pass to simplify the dynamic symbolics # use an enhanced pass to simplify the dynamic symbolics
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment