Commit 44243542 authored by Yu Cheng, committed by LeiWang1999
Browse files

[Refactor] Refactor warp_specialized_rewriter to support multiple acquire/release patterns. (#391)

Updated SyncPatternMap to store a vector of pattern indices per statement for acquire, release, and release_after, so that a single statement can participate in multiple synchronization patterns. Updated the barrier handling logic for both the producer and consumer cases to iterate over these lists when emitting parity waits and barrier arrives.
parent bf0032f8
...@@ -632,17 +632,19 @@ private: ...@@ -632,17 +632,19 @@ private:
} }
} }
if (map.acquire[i] != -1) { for (int pattern_idx : map.acquire[i]) {
PrimExpr acquire_barrier_id = PrimExpr acquire_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.acquire[i]; stage_ + num_barriers_ + num_stages_ * pattern_idx;
PrimExpr parity = map.is_loop_dependency(map.acquire[i]) PrimExpr parity = map.is_loop_dependency(pattern_idx)
? bitwise_xor(parity_, 1) ? bitwise_xor(parity_, 1)
: parity_; : parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
} }
ICHECK(map.release[i] >= 0); ICHECK(map.release[i].size() > 0);
for (size_t j = 0; j < map.release[i].size(); j++) {
int pattern_idx = map.release[i][j];
PrimExpr release_barrier_id = PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.release[i]; stage_ + num_barriers_ + num_stages_ * pattern_idx;
auto stmt = auto stmt =
MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id); MbarrierRewriter::Rewrite(seq_transformed[i], release_barrier_id);
collector.Collect(stmt); collector.Collect(stmt);
...@@ -657,45 +659,45 @@ private: ...@@ -657,45 +659,45 @@ private:
if (collector.HasSimtCopy() > 0) { if (collector.HasSimtCopy() > 0) {
block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id)); block_stmt.push_back(makeCpAsyncBarrier(release_barrier_id));
} }
if (map.release_after[i]) { if (map.release_after[i][j]) {
block_stmt.push_back(makeArriveBarrier(release_barrier_id)); block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) { for (int s = 0; s < num_stages_; s++) {
released_barrier_.insert(j + num_barriers_ + released_barrier_.insert(s + num_barriers_ +
num_stages_ * map.release[i]); num_stages_ * pattern_idx);
} }
} }
collector.Clear(); collector.Clear();
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 new_body.push_back(MakeGroupBlock(
? block_stmt[0] block_stmt.size() == 1 ? block_stmt[0]
: SeqStmt(std::move(block_stmt)), : SeqStmt(std::move(block_stmt)),
annotations)); annotations));
} }
}
} else { // consumer case } else { // consumer case
for (int i = 0; i < static_cast<int>(op->seq.size()); i++) { for (int i = 0; i < static_cast<int>(op->seq.size()); i++) {
Array<Stmt> block_stmt = {}; Array<Stmt> block_stmt = {};
if (marker_.GetRole(op->seq[i]) == Role::kProducer) if (marker_.GetRole(op->seq[i]) == Role::kProducer)
continue; continue;
if (map.acquire[i] != -1) { for (int pattern_idx : map.acquire[i]) {
PrimExpr acquire_barrier_id = PrimExpr acquire_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.acquire[i]; stage_ + num_barriers_ + num_stages_ * pattern_idx;
PrimExpr parity = map.is_loop_dependency(map.acquire[i]) PrimExpr parity = map.is_loop_dependency(pattern_idx)
? bitwise_xor(parity_, 1) ? bitwise_xor(parity_, 1)
: parity_; : parity_;
block_stmt.push_back(makeParityWait(acquire_barrier_id, parity)); block_stmt.push_back(makeParityWait(acquire_barrier_id, parity));
} }
block_stmt.push_back(seq_transformed[i]); block_stmt.push_back(seq_transformed[i]);
// new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 ? for (size_t j = 0; j < map.release[i].size(); j++) {
// block_stmt[0] : SeqStmt(std::move(block_stmt)), annotations)); if (map.release_after[i][j]) {
if (map.release_after[i]) { int pattern_idx = map.release[i][j];
PrimExpr release_barrier_id = PrimExpr release_barrier_id =
stage_ + num_barriers_ + num_stages_ * map.release[i]; stage_ + num_barriers_ + num_stages_ * pattern_idx;
block_stmt.push_back(makeArriveBarrier(release_barrier_id)); block_stmt.push_back(makeArriveBarrier(release_barrier_id));
for (int j = 0; j < num_stages_; j++) { for (int s = 0; s < num_stages_; s++) {
released_barrier_.insert(j + num_barriers_ + released_barrier_.insert(s + num_barriers_ +
num_stages_ * map.release[i]); num_stages_ * pattern_idx);
}
} }
// Update the pipeline info
// Todo: handle sync
} }
new_body.push_back(MakeGroupBlock(block_stmt.size() == 1 new_body.push_back(MakeGroupBlock(block_stmt.size() == 1
? block_stmt[0] ? block_stmt[0]
...@@ -828,13 +830,20 @@ private: ...@@ -828,13 +830,20 @@ private:
}; };
struct SyncPatternMap { struct SyncPatternMap {
std::vector<int> acquire; std::vector<std::vector<int>> acquire;
std::vector<int> release; std::vector<std::vector<int>> release;
std::vector<bool> release_after; std::vector<std::vector<bool>> release_after;
std::vector<SyncPattern> patterns; std::vector<SyncPattern> patterns;
bool is_loop_dependency(int i) {
// return if the acquire is based on release in the previous iteration void resize(size_t n) {
return patterns[i].release_idx > patterns[i].acquire_idx; acquire.resize(n);
release.resize(n);
release_after.resize(n);
}
bool is_loop_dependency(int pattern_idx) {
return patterns[pattern_idx].release_idx >
patterns[pattern_idx].acquire_idx;
} }
}; };
...@@ -960,29 +969,41 @@ private: ...@@ -960,29 +969,41 @@ private:
// } // }
SyncPatternMap map; SyncPatternMap map;
map.resize(num_stmts);
map.patterns = sync_patterns; map.patterns = sync_patterns;
map.acquire.resize(num_stmts, -1);
map.release.resize(num_stmts, -1);
map.release_after.resize(num_stmts, false);
for (size_t i = 0; i < sync_patterns.size(); i++) { for (size_t i = 0; i < sync_patterns.size(); i++) {
map.acquire[sync_patterns[i].acquire_idx] = i; int acquire_idx = sync_patterns[i].acquire_idx;
map.release[sync_patterns[i].release_idx] = i; int release_idx = sync_patterns[i].release_idx;
map.release_after[sync_patterns[i].release_idx] = true;
map.acquire[acquire_idx].push_back(i);
map.release[release_idx].push_back(i);
map.release_after[release_idx].push_back(true);
} }
int cur_consumer_barrier = -1, cur_producer_barrier = -1; std::vector<int> cur_consumer_barrier, cur_producer_barrier;
for (int i = num_stmts - 1; i >= 0; i--) { for (int i = num_stmts - 1; i >= 0; i--) {
if (is_producer[i]) { if (is_producer[i]) {
if (map.release[i] == -1) { if (map.release[i].size() == 0) {
map.release[i] = cur_producer_barrier; for (auto pattern_idx : cur_producer_barrier) {
map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false);
}
} else { } else {
cur_producer_barrier = map.release[i]; for (auto pattern_idx : map.release[i]) {
cur_producer_barrier.push_back(pattern_idx);
}
} }
} else { } else {
if (map.release[i] == -1) { if (map.release[i].size() == 0) {
map.release[i] = cur_consumer_barrier; for (auto pattern_idx : cur_consumer_barrier) {
map.release[i].push_back(pattern_idx);
map.release_after[i].push_back(false);
}
} else { } else {
cur_consumer_barrier = map.release[i]; for (auto pattern_idx : map.release[i]) {
cur_consumer_barrier.push_back(pattern_idx);
}
} }
} }
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment