Unverified Commit a16f0cf5 authored by Lei Wang's avatar Lei Wang Committed by GitHub
Browse files

[Enhancement] Improve buffer conflict detection in thread storage synchronization (#658)

* [Enhancement] Improve buffer conflict detection in thread storage synchronization

- Added a new boolean variable `range_is_overlap` to accurately determine if buffer indices overlap, enhancing the conflict detection logic in `thread_storage_sync.cc`.
- Updated the return logic to reflect the overlap status, ensuring correct conflict resolution based on buffer index comparisons.
- Removed an unnecessary comment in `OptimizeForTarget` to streamline the code and improve clarity.

* example fix

* enhancement

* improve ci
parent c8edb957
......@@ -94,11 +94,11 @@ jobs:
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd examples
unset PYTHONPATH
python -m pytest -n 4 **/test*.py
python -m pytest -n 8 **/test*.py
- name: Run tests
run: |
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
cd testing/python
unset PYTHONPATH
python -m pytest -n 4
\ No newline at end of file
python -m pytest -n 8
......@@ -20,10 +20,7 @@ from transformers import (
from vllm import LLM, SamplingParams
from vllm.assets.image import ImageAsset
from vllm.config import TokenizerPoolConfig
from vllm.distributed import (
destroy_distributed_environment,
destroy_model_parallel,
)
from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel)
from vllm.inputs import TextPrompt
from vllm.logger import init_logger
from vllm.sequence import SampleLogprobs
......
......@@ -49,11 +49,9 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
scores_max_1 = T.alloc_fragment([block_H], accum_dtype)
scores_max = T.alloc_shared([block_H], accum_dtype)
# TODO(lei): this is a workaround for the bug of replicate if stmt.
# have to be optimized in future with index aware sync thread pass injection.
# scores_max_prev_0 and scores_max_prev_1 should be allocated in fragment.
scores_max_prev_0 = T.alloc_shared([block_H], accum_dtype)
scores_max_prev_1 = T.alloc_shared([block_H], accum_dtype)
scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)
scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
......@@ -395,7 +393,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
return out
def main(batch=132, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
......@@ -414,7 +412,7 @@ def main(batch=132, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=132, help='batch size')
parser.add_argument('--batch', type=int, default=1, help='batch size')
parser.add_argument('--heads', type=int, default=128, help='q heads number')
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
......
......@@ -258,6 +258,8 @@ private:
// TODO(tqchen) more standard set based testing.
bool has_same_index = true;
bool range_is_equal = true;
bool range_is_overlap = true;
for (const auto &kv : prev.thread_range) {
if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) {
range_is_equal = false;
......@@ -275,6 +277,40 @@ private:
const auto &curr_indice = curr.buffer_indices[i];
if (!ExprDeepEqual()(prev_indice, curr_indice)) {
has_same_index = false;
// If both are const, we can check if they are disjoint
// by checking if the bounds are disjoint
// [1024, 2048], [2048, 3072] are disjoint
// [1024, 2048], [1024, 1024] are not disjoint
auto prev_bound = analyzer_.const_int_bound(prev_indice);
auto curr_bound = analyzer_.const_int_bound(curr_indice);
if (prev_bound.defined() && curr_bound.defined()) {
if (prev_bound->min_value > curr_bound->max_value ||
curr_bound->min_value > prev_bound->max_value) {
range_is_overlap = false;
break;
}
}
// if we can prove prev_indice < curr_indice or prev_indice >
// curr_indice, then they are not overlap
auto prev_dtype = prev_indice.dtype();
auto curr_dtype = curr_indice.dtype();
if (prev_dtype.lanes() != curr_dtype.lanes()) {
// can not support different lanes binary op like <, >, <=, >=
// skip otherwise it will lead to error
continue;
}
bool provably_disjoint =
analyzer_.CanProve(prev_indice < curr_indice,
arith::ProofStrength::kSymbolicBound) ||
analyzer_.CanProve(prev_indice > curr_indice,
arith::ProofStrength::kSymbolicBound);
if (provably_disjoint) {
range_is_overlap = false;
break;
}
}
if (!(has_same_index)) {
......@@ -291,9 +327,13 @@ private:
if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
return false;
}
// If nothing else allows sharing the same buffer, then they are
// in conflict.
return true;
// if range_is_overlap is true, then they are in conflict, we should return
// true. if range_is_overlap is false, then they are not in conflict, we
// should return false.
return range_is_overlap;
}
void VisitStmt_(const AttrStmtNode *op) final {
......
......@@ -175,7 +175,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
mod)
mod = tilelang.transform.ThreadSync("shared")(mod)
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
# Inject PTX async copy must behind the thread sync pass
# as ptx async copy won't be recognized as a valid buffer load
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment