Unverified Commit 39ac6404 authored by Wentao Ye, committed by GitHub
Browse files

[Bug] Fix batch invariant test issue, bs=1 with `max_seq_num = 1` (#39320)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 0b790a25
...@@ -36,10 +36,10 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -36,10 +36,10 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
using the high-level v1 LLM() API only (no manual batching). using the high-level v1 LLM() API only (no manual batching).
Strategy: Strategy:
- Create two LLM engines with identical config except max_num_seqs: 1 vs N. - Create a single LLM engine configured for the larger batch limit (N).
- Compute a baseline output for the needle prompt with the bs=1 engine. - Compute a baseline output for the needle prompt when it is run alone.
- For many trials, generate a batch (size N) where the needle appears at a - For many trials, generate a mixed batch (size N) where the needle appears
random position among random filler prompts using the bs=N engine. at a random position among random filler prompts using the same engine.
- Track how many trials match vs mismatch, and report totals at the end. - Track how many trials match vs mismatch, and report totals at the end.
The test fails if any mismatches occur, but we still dump pass/fail The test fails if any mismatches occur, but we still dump pass/fail
counts. counts.
...@@ -83,11 +83,9 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -83,11 +83,9 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
needle_prompt = "There once was a " needle_prompt = "There once was a "
llm_bs1 = None llm = None
llm_bsN = None
try: try:
# Engine with bs=1 behavior llm = LLM_with_max_seqs(
llm_bs1 = LLM_with_max_seqs(
model=model, model=model,
max_num_seqs=max_batch_size, max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util, gpu_memory_utilization=gpu_mem_util,
...@@ -96,20 +94,11 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -96,20 +94,11 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
) )
# Baseline generation for the needle prompt alone. # Baseline generation for the needle prompt alone.
baseline_out = llm_bs1.generate([needle_prompt], sampling) baseline_out = llm.generate([needle_prompt], sampling)
assert len(baseline_out) == 1 assert len(baseline_out) == 1
assert len(baseline_out[0].outputs) >= 1 assert len(baseline_out[0].outputs) >= 1
baseline_text = baseline_out[0].outputs[0].text baseline_text = baseline_out[0].outputs[0].text
# Engine with larger batch limit (e.g., 64)
llm_bsN = LLM_with_max_seqs(
model=model,
max_num_seqs=max_batch_size,
gpu_memory_utilization=gpu_mem_util,
max_model_len=max_model_len,
attention_config=attention_config,
)
mismatches = 0 mismatches = 0
for trial in range(num_trials): for trial in range(num_trials):
...@@ -124,8 +113,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -124,8 +113,8 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
else: else:
prompts.append(_random_prompt(min_random_prompt, max_random_prompt)) prompts.append(_random_prompt(min_random_prompt, max_random_prompt))
# Generate with the larger-batch engine # Generate with the same engine but in a larger batch.
outputs = llm_bsN.generate(prompts, sampling) outputs = llm.generate(prompts, sampling)
# Find the needle output by position # Find the needle output by position
needle_output = outputs[needle_pos] needle_output = outputs[needle_pos]
assert needle_output.prompt == needle_prompt assert needle_output.prompt == needle_prompt
...@@ -151,12 +140,9 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle( ...@@ -151,12 +140,9 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
finally: finally:
# Ensure engines are shutdown to free GPU/VRAM across test sessions # Ensure engines are shutdown to free GPU/VRAM across test sessions
if llm_bs1 is not None: if llm is not None:
with contextlib.suppress(Exception):
llm_bs1.shutdown()
if llm_bsN is not None:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
llm_bsN.shutdown() llm.shutdown()
@skip_unsupported @skip_unsupported
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment