"ts/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "0e8a9f8272affb1b41d576bc06c9eb59e763bbc3"
Unverified commit 376ba9eb, authored by Lei Wang and committed by GitHub

[Pipeline] Optimize inject software pipeline and pipeline planning passes (#706)

* Refactor inject_pipeline.cc to improve version handling and add unique producer head tracking

- Updated version check to allow for cases with two or more versions.
- Adjusted logic to decrement num_versions when multi-versioning is not needed.
- Introduced a helper function to ensure only unique producer heads are added per commit group (a standalone sketch of this deduplication idea follows this list).
- Removed obsolete AddAllocBuffers method to streamline code.

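As a rough illustration of the deduplication idea behind that helper, the following standalone C++ sketch uses plain `std::string` values and `==` in place of the pass's `Optional<PrimExpr>` operands and `StructuralEqual()` comparator; the names and values are illustrative only, not the actual TVM API.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Stand-in for the per-commit-group list of producer heads; the real pass
  // stores Optional<PrimExpr> and compares entries with StructuralEqual().
  std::vector<std::string> producer_head_per_commit;

  // Mirrors the intent of add_unique_producer_head: push a head only if an
  // equivalent entry has not been recorded yet.
  auto add_unique_producer_head = [&](const std::string &producer_head) {
    for (const auto &head : producer_head_per_commit) {
      if (head == producer_head) {  // stands in for StructuralEqual()(...)
        return;
      }
    }
    producer_head_per_commit.push_back(producer_head);
  };

  add_unique_producer_head("i - 1");
  add_unique_producer_head("i - 1");  // duplicate, ignored
  add_unique_producer_head("i");

  // Only distinct producer heads remain to contribute to the wait count.
  std::cout << producer_head_per_commit.size() << "\n";  // prints 2
  return 0;
}
```

Deduplicating here keeps the later wait-count summation from counting the same producer twice.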
* lint fix

* Refactor pipeline planning logic to enhance copy stage dependency management

- Removed obsolete conditional expression handling from the pipeline planning code.
- Introduced a new structure to manage copy stage dependency reads, improving clarity and efficiency.
- Updated the logic to correctly identify producer stages for copy stages, ensuring accurate pipeline stage assignment (see the propagation sketch after this list).
- Added a new block sparse matrix multiplication function in the testing suite to validate the pipeline planning changes.

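To make the producer-for-copy propagation concrete, here is a self-contained C++ sketch of the two-step idea: collect the buffers read by copy stages, then iterate to a fixed point, marking stages that write into that set and folding their own reads back in. Buffer regions are reduced to plain strings, and the statement-order and iteration-limit checks of the real pass are omitted, so this is a simplified illustration rather than the pass itself.

```cpp
#include <cstddef>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Simplified stage record: buffers are plain strings instead of BufferRegions.
struct StageInfo {
  std::vector<std::string> reads, writes;
  bool copy_stage = false;
  bool producer_for_copy = false;
};

int main() {
  // Stage 0 computes "idx", stage 1 is a copy that reads "idx", stage 2 consumes.
  std::vector<StageInfo> stages = {
      {{"mask"}, {"idx"}, false},
      {{"idx", "A_global"}, {"A_shared"}, true},
      {{"A_shared"}, {"C_local"}, false},
  };

  // Step 1: collect every buffer read by a copy stage.
  std::set<std::string> copy_reads;
  for (const auto &s : stages)
    if (s.copy_stage)
      copy_reads.insert(s.reads.begin(), s.reads.end());

  // Step 2: fixed-point propagation - any non-copy stage writing one of those
  // buffers becomes a producer for the copy, and its own reads join the set.
  bool updated = true;
  while (updated) {
    updated = false;
    for (auto &s : stages) {
      if (s.copy_stage || s.producer_for_copy)
        continue;
      for (const auto &w : s.writes) {
        if (copy_reads.count(w)) {
          s.producer_for_copy = true;
          for (const auto &r : s.reads)
            copy_reads.insert(r);
          updated = true;
          break;
        }
      }
    }
  }

  for (std::size_t i = 0; i < stages.size(); ++i)
    std::cout << "stage " << i << " producer_for_copy="
              << stages[i].producer_for_copy << "\n";  // only stage 0 is marked
  return 0;
}
```

Such producer stages are then scheduled together with the copy stages they feed, which is what the pipeline-stage assignment in the pass relies on.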
* Update ci.yml

* Fix structural equality checks in AddUnique and Contains methods to compare buffer references instead of entire regions in pipeline planning.

* Refactor pipeline planning logic to improve copy stage dependency propagation

- Updated the structural equality checks in the AddUnique and Contains methods to use buffer reference comparison (illustrated in the sketch after this list).
- Enhanced the iteration logic for managing copy stage dependencies, ensuring accurate identification of producer stages.
- Added safeguards against exceeding maximum iterations during dependency propagation.
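To show what the buffer-reference comparison buys, the following standalone C++ sketch deduplicates regions by the identity of their underlying buffer: two regions over the same buffer collapse to one entry even though their index ranges differ. The `Buffer`/`BufferRegion` structs and the `AddUnique` helper below are toy stand-ins for the TVM types and the pass's method, with `shared_ptr` identity standing in for `same_as`.

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins: a "buffer" is identified by its allocation, and a
// "buffer region" pairs a buffer with an (ignored) index range.
struct Buffer {
  std::string name;
};
struct BufferRegion {
  std::shared_ptr<Buffer> buffer;
  std::pair<int, int> region;
};

// Deduplicate by buffer reference: two regions over the same underlying
// buffer count once, regardless of their index ranges.
void AddUnique(std::vector<BufferRegion> &regions, const BufferRegion &r) {
  for (const auto &existing : regions) {
    if (existing.buffer == r.buffer) {  // stands in for buffer.same_as(...)
      return;
    }
  }
  regions.push_back(r);
}

int main() {
  auto a_shared = std::make_shared<Buffer>(Buffer{"A_shared"});
  std::vector<BufferRegion> reads;
  AddUnique(reads, {a_shared, {0, 64}});
  AddUnique(reads, {a_shared, {64, 128}});  // same buffer, different region
  std::cout << reads.size() << "\n";        // prints 1
  return 0;
}
```

Comparing the full regions structurally would have kept both entries and treated the second read as a separate dependency.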
parent 407117e1
@@ -111,11 +111,11 @@ jobs:
           source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
           cd examples
           unset PYTHONPATH
-          python -m pytest -n 8 **/test*.py
+          python -m pytest -n 4 **/test*.py
       - name: Run tests
         run: |
           source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
           cd testing/python
           unset PYTHONPATH
-          python -m pytest -n 8
+          python -m pytest -n 4
@@ -508,7 +508,7 @@ private:
     // We optimize a few case where the number of versions can be smaller than
     // the upper bound
     int num_versions = buffer_info.use - buffer_info.def + 1;
-    if (num_versions == 2) {
+    if (num_versions >= 2) {
       // A special case when `use - def + 1 == 2`. Double buffering is only
       // needed in this case when these exists a reader block_i and a writer
       // block_j such that order(block_i) < order(block_j) and stage(block_i) <
@@ -547,7 +547,7 @@ private:
         }
       }
       if (!need_multi_version) {
-        num_versions = 1;
+        num_versions--;
       }
     }
     if (num_versions == 1 && double_buffers_.count(buffer)) {
@@ -647,6 +647,7 @@ private:
                   arith::Analyzer *ana_normalized,
                   const std::unordered_map<const BufferNode *, int> &buffer_to_commit_group,
                   std::map<int, AsyncStateLocal> *async_states_local) {
+
     for (size_t i = 0; i < new_blocks.size(); ++i) {
       if (new_blocks[i].is_async) {
         // Record the fact that we have encountered these write buffers.
@@ -716,16 +717,28 @@ private:
       // head at compute points to the copy done by the previous iteration, so
       // its wait_count is calculated as ((i - 1) + 3) - i. The sum of the two
       // wait_counts gives 5.
+      // print async_states_local
       auto &dep_local_state = (*async_states_local)[producer_stage_idx];
       const auto num_commit_group = dep_local_state.commit_groups.size();
       std::vector<Optional<PrimExpr>> producer_head_per_commit;
+      auto add_unique_producer_head =
+          [&](const Optional<PrimExpr> &producer_head) {
+            // if producer_head already in producer_head_per_commit, return
+            for (const auto &head : producer_head_per_commit) {
+              if (StructuralEqual()(head, producer_head)) {
+                return;
+              }
+            }
+            producer_head_per_commit.push_back(producer_head);
+          };
       if (num_commit_group == 0) {
         // Epilogue, no async producer. Since "local" producer_head is not
         // available, use "global" producer_head.
         ICHECK(!dep_local_state.producer_head);
-        producer_head_per_commit.push_back(
+        add_unique_producer_head(
             async_states[producer_stage_idx].producer_head);
       } else {
         ICHECK(dep_local_state.producer_head);
@@ -742,12 +755,10 @@ private:
           if (!dep_local_state.seen.count(read_region->buffer.get())) {
             // Multiple async producers interleaved: The most recent async write
             // is from the previous iteration. This is the B_shared case above.
-            producer_head_per_commit.push_back(
-                dep_local_state.producer_head.value() - 1);
+            add_unique_producer_head(dep_local_state.producer_head.value() - 1);
           } else {
             // Normal case
-            producer_head_per_commit.push_back(
-                dep_local_state.producer_head.value());
+            add_unique_producer_head(dep_local_state.producer_head.value());
           }
           need_wait_count[commit_group_id] = false;
@@ -756,7 +767,7 @@ private:
       auto wait_count = [=, &ana_normalized]() {
         auto sum = PrimExpr(0);
-        for (auto producer_head : producer_head_per_commit) {
+        for (const auto &producer_head : producer_head_per_commit) {
           if (producer_head &&
               ana_normalized->CanProve(producer_head.value() >= 0)) {
             // Here, new_blocks[i].access_index corresponds to "consumer_head".
@@ -1281,23 +1292,6 @@ private:
     return pipeline;
   }
 
-  /*!
-   * \brief Add buffer allocations to a block and update the write region of the
-   * block. \param n The block pointer to which the buffer allocations are
-   * added. \param alloc_buffers The buffer allocations to be added.
-   */
-  void AddAllocBuffers(BlockNode *n, const Array<Buffer> alloc_buffers) {
-    for (const Buffer &alloc_buffer : alloc_buffers) {
-      n->alloc_buffers.push_back(alloc_buffer);
-      Region region;
-      region.reserve(alloc_buffer->shape.size());
-      for (const PrimExpr &dim : alloc_buffer->shape) {
-        region.push_back(Range::FromMinExtent(0, dim));
-      }
-      n->writes.push_back(BufferRegion(alloc_buffer, region));
-    }
-  }
-
   Stmt VisitStmt_(const BlockNode *op) final {
     for (const auto &buffer : op->alloc_buffers) {
       buffer_data_to_buffer_.Set(buffer->data, buffer);
......
@@ -49,8 +49,6 @@ public:
   bool GetGlobalCopyPattern() const { return is_global_copy_pattern_; }
 
-  PrimExpr GetConditonalExpr() const { return conditonal_expr; }
-
 private:
   void VisitStmt_(const BufferStoreNode *op) final {
     Buffer store_buffer = op->buffer;
@@ -105,31 +103,11 @@ private:
         // because we only care about the buffer itself instead of indices
         reads_.push_back(buffer_region);
       }
-    } else if (op->op.same_as(tir::builtin::if_then_else())) {
-      // Simplify nested if_then_else
-      // if (cond) { if (inner_cond) { inner_then_expr } else { inner_else_expr
-      // } } else { else_expr }
-      // => if (cond && inner_cond) { inner_then_expr } else { else_expr }
-      const PrimExpr &cond = op->args[0];
-      const PrimExpr &then_expr = op->args[1];
-      const PrimExpr &else_expr = op->args[2];
-      conditonal_expr = cond;
-      this->VisitExpr(then_expr);
-      this->VisitExpr(else_expr);
     } else {
       StmtExprVisitor::VisitExpr_(op);
    }
  }
 
-  void VisitStmt_(const IfThenElseNode *op) final {
-    // Skip condition
-    this->VisitStmt(op->then_case);
-    conditonal_expr = op->condition;
-    if (op->else_case.defined()) {
-      this->VisitStmt(op->else_case.value());
-    }
-  }
-
 private:
   Map<Var, Buffer> buffer_data_to_buffer_;
   Array<BufferRegion> reads_;
@@ -137,7 +115,6 @@ private:
   bool is_global_read_ = false;
   bool under_buffer_store_ = false;
   bool is_global_copy_pattern_ = false;
-  PrimExpr conditonal_expr;
 };
 
 class PipelinePlanner : public StmtExprMutator {
@@ -162,23 +139,38 @@ private:
    *
    * \param reads Array of buffer regions read by this stage
    * \param writes Array of buffer regions written by this stage
-   * \param original_order Original position of this stage in the pipeline
+   * \param original_stmt_index Original position of this stage in the pipeline
    * before reordering \param order Current position of this stage in the
    * pipeline after reordering (-1 if not yet assigned) \param stage Pipeline
    * stage number this operation belongs to (-1 if not yet assigned) \param
    * copy_stage Whether this stage is a memory copy operation \param
-   * last_use_stage Last pipeline stage that uses the results of this stage (-1
-   * if not yet determined)
+   * last_use_stmt_index Index of the last statement (in original order) that
+   * uses the results of this stage (-1 if not yet determined). This field is
+   * crucial for pipeline optimization:
+   * - For copy stages: indicates the index of the last statement that reads
+   *   from the copied data, helping determine optimal placement of copy
+   *   operations
+   * - Used to ensure copy operations are scheduled before their consumers
+   * - A value of -1 means no subsequent statement uses this stage's output
+   * - This information enables better pipeline scheduling by minimizing data
+   *   dependencies and maximizing parallelism
    */
   struct PipelineStageInfo {
     Array<BufferRegion> reads, writes;
-    int original_order;
+    int original_stmt_index;
     int order = -1, stage = -1;
     bool copy_stage = false;
-    bool prepare_for_condition = false;
-    int last_use_stage = -1;
-    // represent the stage is used in a conditional statement
-    PrimExpr conditonal_expr;
+    bool producer_for_copy = false;
+    int last_use_stmt_index =
+        -1; // Initialized to -1, indicating no consumers found yet
+
+  public:
+    bool is_first_stage() const { return copy_stage || producer_for_copy; }
+    bool is_copy_stage() const { return copy_stage; }
+    bool is_producer_for_copy() const { return producer_for_copy; }
+    bool is_last_use_stmt_index_valid() const {
+      return last_use_stmt_index != -1;
+    }
   };
 
   PipelineStageInfo MakePipelineStageInfo(Stmt stmt, int idx) {
@@ -191,9 +183,8 @@ private:
     PipelineStageInfo pinfo;
     pinfo.reads = std::move(collector.GetReads());
     pinfo.writes = std::move(collector.GetWrites());
-    pinfo.original_order = idx;
+    pinfo.original_stmt_index = idx;
     pinfo.copy_stage = collector.GetGlobalCopyPattern();
-    pinfo.conditonal_expr = collector.GetConditonalExpr();
     return std::move(pinfo);
   }
@@ -287,40 +278,135 @@ private:
       pipeline_stage_infos.push_back(std::move(pinfo));
     }
 
-    // process the conditional stage
-    // assign conditional stage (analysis the copy stage)
-    for (auto &pinfo : pipeline_stage_infos) {
-      for (const auto &write : pinfo.writes) {
-        for (const auto &other : pipeline_stage_infos) {
-          if (other.conditonal_expr.defined()) {
-            auto check_var = [&](const ObjectRef &n) {
-              if (const auto *buffer_load = n.as<BufferLoadNode>()) {
-                if (buffer_load->buffer == write->buffer) {
-                  pinfo.prepare_for_condition = true;
-                }
-              }
-            };
-            PostOrderVisit(other.conditonal_expr, check_var);
-          }
-        }
-      }
-    }
+    // For every copy stage, mark all its dependency stages as producer_for_copy
+    // Helper struct to manage copy stage dependency reads
+    struct CopyStageDependencyReadsManager {
+      std::vector<BufferRegion> regions;
+
+      // Add a region if not already present (by structural equality)
+      void AddUnique(const BufferRegion &region) {
+        for (const BufferRegion &copy_read : regions) {
+          if (region->buffer.same_as(copy_read->buffer)) {
+            return;
+          }
+        }
+        regions.push_back(region);
+      }
+      // Check if a region is present (by structural equality)
+      bool Contains(const BufferRegion &region) const {
+        for (const BufferRegion &copy_read : regions) {
+          if (region->buffer.same_as(copy_read->buffer)) {
+            return true;
+          }
+        }
+        return false;
+      }
+      size_t Size() const { return regions.size(); }
+    };
+
+    CopyStageDependencyReadsManager copy_stage_dependency_reads_mgr;
+    // Step 1. Collect Copy reads
+    for (const auto &pinfo : pipeline_stage_infos) {
+      if (pinfo.is_copy_stage()) {
+        for (const BufferRegion &read : pinfo.reads) {
+          copy_stage_dependency_reads_mgr.AddUnique(read);
+        }
+      }
+    }
+    // Step 2. find if pinfo write the copy reads, then update the
+    // copy_stage_dependency_reads To prevent infinite loops, we set a maximum
+    // number of iterations. In theory, the number of possible updates is
+    // bounded by the number of pipeline stages, since each stage can only be
+    // marked as producer_for_copy once, and each read can only be added once.
+    // But for safety, we add a hard limit.
+    const size_t max_iterations = (pipeline_stage_infos.size() * 4) + 16;
+    size_t iter_count = 0;
+    for (auto &pinfo : pipeline_stage_infos) {
+      if (!pinfo.is_copy_stage()) {
+        continue;
+      }
+      auto original_copy_stmt_index = pinfo.original_stmt_index;
+      bool updated = true;
+      while (updated) {
+        updated = false;
+        for (auto &pinfo_inner : pipeline_stage_infos) {
+          if (pinfo_inner.is_copy_stage()) {
+            continue;
+          }
+          if (pinfo_inner.original_stmt_index >= original_copy_stmt_index) {
+            break;
+          }
+          bool should_prepare = false;
+          for (const BufferRegion &write : pinfo_inner.writes) {
+            if (copy_stage_dependency_reads_mgr.Contains(write)) {
+              should_prepare = true;
+              break;
+            }
+          }
+          if (should_prepare && !pinfo_inner.is_producer_for_copy()) {
+            pinfo_inner.producer_for_copy = true;
+            updated = true;
+          }
+          if (should_prepare) {
+            for (const BufferRegion &read : pinfo_inner.reads) {
+              size_t before = copy_stage_dependency_reads_mgr.Size();
+              copy_stage_dependency_reads_mgr.AddUnique(read);
+              if (copy_stage_dependency_reads_mgr.Size() > before) {
+                updated = true;
+              }
            }
          }
        }
+        iter_count++;
+        if (iter_count > max_iterations) {
+          LOG(FATAL)
+              << "Pipeline planning: Exceeded maximum iterations ("
+              << max_iterations << ") in copy stage dependency propagation. "
+              << "This may indicate a cyclic or pathological dependency graph.";
+        }
+      }
+    }
 
-    // analysis use-def chain
+    // Analysis use-def chain to determine last_use_stmt_index for copy
+    // operations This step is critical for pipeline optimization as it
+    // identifies the index of the last statement that consumes data produced by
+    // copy stages, enabling optimal placement of copy operations in the
+    // pipeline schedule.
     for (auto &pinfo : pipeline_stage_infos) {
-      for (int i = pinfo.original_order + 1;
-           i < static_cast<int>(pipeline_body_seq->size()); i++) {
-        if (!pinfo.copy_stage)
-          continue;
+      // Only analyze copy stages (memory copy operations)
+      if (!pinfo.is_first_stage())
+        continue;
+      // Check all subsequent statements to find the latest consumer
+      for (int i = pinfo.original_stmt_index + 1;
+           i < static_cast<int>(pipeline_body_seq->size()); i++) {
+        // Check if any read operation in statement 'i' uses data written by
+        // this copy stage
        for (const BufferRegion &read : pipeline_stage_infos[i].reads) {
+          // Look for overlapping buffer regions between this stage's writes and
+          // stage 'i's reads
          if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                           [&](const BufferRegion &r) {
                             return r->buffer == read->buffer &&
                                    MayConflict(r->region, read->region);
                           }) != pinfo.writes.end()) {
-            pinfo.last_use_stage = std::max(pinfo.last_use_stage, i);
+            // Update last_use_stmt_index to the maximum (latest) statement
+            // index that uses this data This ensures we capture the final
+            // consumer of the copied data
+            pinfo.last_use_stmt_index = std::max(pinfo.last_use_stmt_index, i);
          }
        }
+        // Check for write-after-write conflicts (multiple stages writing to
+        // same buffer region) This is important for pipeline correctness and
+        // affects last_use_stmt_index analysis
+        if (pinfo.is_copy_stage()) {
          for (const BufferRegion &write : pipeline_stage_infos[i].writes) {
            if (std::find_if(pinfo.writes.begin(), pinfo.writes.end(),
                             [&](const BufferRegion &r) {
@@ -329,25 +415,30 @@ private:
                             }) != pinfo.writes.end()) {
              LOG(FATAL) << "Pipeline planning error: Multiple writes to "
                            "overlapping buffer regions detected. "
-                         << "Stage " << pinfo.original_order << " and stage " << i
-                         << " are both writing to buffer '" << write->buffer->name
+                         << "Stage " << pinfo.original_stmt_index
+                         << " and stage " << i
+                         << " are both writing to buffer '"
+                         << write->buffer->name
                          << "' with overlapping regions. This is not supported "
                            "in pipeline planning.";
            }
          }
+        }
      }
    }
  }
 
    // Making stages and orders
    int order_idx = 0;
-    // Create pipeline stages and assign order
+    // Stage 1. Create pipeline stages and assign order
    for (auto &pinfo : pipeline_stage_infos) {
      // Skip elements that must be in first stage:
-      // 1. Copy stages (with active last_use_stage)
-      // 2. Condition preparation stages
-      if ((pinfo.copy_stage && pinfo.last_use_stage != -1) ||
-          pinfo.prepare_for_condition)
+      // 1. Copy stages (with active last_use_stmt_index) - these need special
+      // handling
+      // because they have consumers that depend on their data
+      // 2. All Producer stages for copy stages.
+      if (pinfo.is_first_stage() && pinfo.is_last_use_stmt_index_valid()) {
        continue;
+      }
 
      // Main logic stage assignment:
      // - Increment order index
@@ -355,34 +446,15 @@ private:
      pinfo.order = order_idx++;
      pinfo.stage = num_stages;
 
+      // Schedule copy stages that have this stage as their last consumer
+      // This ensures copy operations are placed right before their final
+      // consumer for optimal pipeline efficiency
      for (auto &pinfo_1 : pipeline_stage_infos) {
-        if ((pinfo_1.copy_stage &&
-             pinfo_1.last_use_stage == pinfo.original_order)) {
+        if ((pinfo_1.is_first_stage() &&
+             pinfo_1.last_use_stmt_index == pinfo.original_stmt_index)) {
          pinfo_1.order = order_idx++;
-          pinfo_1.stage = 0;
+          pinfo_1.stage = 0; // Copy stages are typically assigned to stage 0
        }
      }
    }
 
-    // Handle trailing unassigned copy stages:
-    // These are typically final copy operations needing post-main-stage
-    // insertion
-    auto &head_pinfo = pipeline_stage_infos.at(0);
-    int unassigned_order_elem = -1;
-    // Process dependent copy stages:
-    // Insert copy stages after current stage but assign to stage 0
-    // and adjust the order index
-    for (auto &pinfo : pipeline_stage_infos) {
-      if (pinfo.order == unassigned_order_elem) {
-        pinfo.order = unassigned_order_elem++;
-        // traverse the from the next info
-        for (auto it = pipeline_stage_infos.begin() + unassigned_order_elem;
-             it != pipeline_stage_infos.end(); it++) {
-          it->order += 1;
-        }
-        pinfo.stage = 0;
-        order_idx++;
-      }
-    }
@@ -392,14 +464,14 @@ private:
        << "Got " << order_idx << " stages and " << pipeline_stage_infos.size()
        << " pipeline stages.";
 
-    // if all the copy is at the end of the order, we can move these copy to the
-    // beginning of the order and shrink the stage offset by 1.
+    // Step 2. if all the copy is at the end of the order, we can move these
+    // copy to the beginning of the order and shrink the stage offset by 1.
    int copy_stage_at_end = [&]() {
      int copy_stage_cnt = 0;
      int copy_order_min = pipeline_stage_infos.size();
      int non_copy_order_max = 0;
      for (auto &pinfo : pipeline_stage_infos) {
-        if (pinfo.copy_stage || pinfo.prepare_for_condition) {
+        if (pinfo.is_first_stage()) {
          copy_stage_cnt++;
          copy_order_min = std::min(copy_order_min, pinfo.order);
        } else {
@@ -414,7 +486,7 @@ private:
      for (auto &pinfo : pipeline_stage_infos) { // move copy to the beginning
        pinfo.order =
            (pinfo.order + copy_stage_at_end) % pipeline_stage_infos.size();
-        if (!pinfo.copy_stage && !pinfo.prepare_for_condition)
+        if (!pinfo.is_copy_stage() && !pinfo.is_producer_for_copy())
          pinfo.stage--;
      }
    }
......
from tilelang import tvm as tvm
import tilelang.testing


def matmul(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    trans_A,
    trans_B,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    A_shape = (K, M) if trans_A else (M, K)
    B_shape = (N, K) if trans_B else (K, N)
    A_shared_shape = (block_K, block_M) if trans_A else (block_M, block_K)
    B_shared_shape = (block_N, block_K) if trans_B else (block_K, block_N)

    import tilelang.language as T

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                if trans_A:
                    T.copy(A[k * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                if trans_B:
                    T.copy(B[bx * block_N, k * block_K], B_shared)
                else:
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local, trans_A, trans_B)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm(
    order,
    stage,
):
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    trans_A = False
    trans_B = False
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128

    program = matmul(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        trans_A,
        trans_B,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
    )

    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        import torch

        if trans_A:
            A = A.T
        if trans_B:
            B = B.T
        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically, -0x1000 meas
            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


def test_pipeline_order_stage():
    run_gemm(order=[0, 1, 2], stage=[0, 0, 1])
    run_gemm(order=[0, 1, 2], stage=[0, 0, 2])
    run_gemm(order=[1, 2, 0], stage=[0, 0, 2])
    run_gemm(order=[1, 2, 0], stage=[0, 0, 1])


@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
    })
def blocksparse_matmul(M,
                       N,
                       K,
                       block_M,
                       block_N,
                       block_K,
                       num_stages,
                       dtype="float16",
                       accum_dtype="float"):
    block_mask_shape = (M // block_M, N // block_N, K // block_K)

    import tilelang.language as T

    @T.prim_func
    def block_sparse_matmul(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            BlockMask: T.Tensor(block_mask_shape, "bool"),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            block_mask = T.alloc_local((1,), "bool")
            C_shared = T.alloc_shared((block_M, block_N), dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                block_mask[0] = BlockMask[by, bx, k]
                if block_mask[0]:
                    T.copy(A[by * block_M, k * block_K], A_shared)
                    T.copy(B[k * block_K, bx * block_N], B_shared)
                    T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M, bx * block_N])

    return block_sparse_matmul


def run_blocksparse_matmul(num_stages):
    import torch

    M = 256
    N = 256
    K = 256
    block_M = 128
    block_N = 128
    block_K = 32
    sparsity = 0.5

    # Initialize input matrices A and B on the GPU with half precision
    a = torch.randn(M, K).cuda().half()
    b = torch.randn(K, N).cuda().half()

    kernel = blocksparse_matmul(
        M, N, K, block_M=block_M, block_N=block_N, block_K=block_K, num_stages=num_stages)
    print(kernel.get_kernel_source())

    # Create block mask with desired sparsity
    mask_shape = (M // block_M, N // block_N, K // block_K)
    block_mask = torch.rand(mask_shape).cuda() > sparsity

    # Run the compiled kernel (either tuned or default) with the inputs
    c = kernel(a, b, block_mask)

    def ref_program(A, B, BlockMask, block_M, block_N, block_K):
        ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device)
        for i in range(M // block_M):
            for j in range(N // block_N):
                accu = torch.zeros((block_M, block_N), dtype=torch.float32, device=A.device)
                for k in range(K // block_K):
                    if BlockMask[i, j, k]:
                        accu += (
                            A[i * block_M:(i + 1) * block_M, k * block_K:(k + 1) * block_K].to(
                                torch.float32) @ B[k * block_K:(k + 1) * block_K,
                                                   j * block_N:(j + 1) * block_N].to(torch.float32))
                ref_c[i * block_M:(i + 1) * block_M,
                      j * block_N:(j + 1) * block_N] = accu.to(torch.float16)
        return ref_c

    # Compute the reference result using the naive PyTorch implementation
    ref_c = ref_program(a, b, block_mask, block_M, block_N, block_K)
    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)


def test_blocksparse_matmul():
    run_blocksparse_matmul(num_stages=1)
    run_blocksparse_matmul(num_stages=2)
    run_blocksparse_matmul(num_stages=3)


if __name__ == "__main__":
    tilelang.testing.main()