Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8ee846c2
Unverified
Commit
8ee846c2
authored
Oct 03, 2025
by
Yannick Schnider
Committed by
GitHub
Oct 03, 2025
Browse files
[Bugfix] Re-enable prefill of max model length (#24446)
Signed-off-by:
Yannick Schnider
<
yannick.schnider1@ibm.com
>
parent
812b7f54
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
113 additions
and
8 deletions
+113
-8
tests/v1/e2e/test_context_length.py
tests/v1/e2e/test_context_length.py
+91
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+22
-8
No files found.
tests/v1/e2e/test_context_length.py
0 → 100644
View file @
8ee846c2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
end-to-end tests for context length corner cases of vLLM v1 model runner
versus HuggingFace's transformers.
This test verifies the following behavior: allow a prefill that fills the
model's maximum context length and then request a single new token.
Test strategy
- Build a textual prompt that tokenizes to exactly ``max_model_len`` tokens.
- Run vLLM generation requesting a single new token (max_tokens=1).
- Run HF generation on the same prompt requesting a single token too.
- Assert both return the same number of generated tokens and the same ids.
"""
import
pytest
import
torch
from
transformers
import
AutoModelForCausalLM
from
tests.models.utils
import
check_outputs_equal
from
tests.utils
import
create_new_process_for_each_test
from
vllm
import
LLM
,
SamplingParams
from
vllm.inputs
import
TokensPrompt
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"JackFram/llama-160m"
])
@
pytest
.
mark
.
parametrize
(
"max_model_len"
,
[
2048
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
1
])
def
test_prefill_max_context_length
(
model
:
str
,
max_model_len
:
int
,
max_tokens
:
int
,
)
->
None
:
"""Compare vLLM and HuggingFace when the prompt already fills the
model's maximum context length and we request a single new token.
The test ensures vLLM does not raise the "Sampled token IDs exceed the
max model length" assertion and that both vLLM and HF produce the same
single token when given the same inputs.
"""
# Construct a prompt of size max_model_len
prompt_ids
=
[[
43
]
*
max_model_len
]
# Generate max_tokens new tokens deterministically.
sampling_params
=
[
SamplingParams
(
max_tokens
=
max_tokens
,
temperature
=
0.0
,
ignore_eos
=
True
)
]
# --- vLLM generation ---
llm
=
LLM
(
model
=
model
,
tokenizer
=
model
,
max_num_seqs
=
1
,
tensor_parallel_size
=
1
,
)
vllm_token_prompts
=
[
TokensPrompt
(
prompt_token_ids
=
prompt_ids
[
0
])]
vllm_results
=
llm
.
generate
(
vllm_token_prompts
,
sampling_params
)
vllm_output_ids
=
vllm_results
[
0
].
outputs
[
0
].
token_ids
# --- HuggingFace generation ---
with
torch
.
no_grad
():
hf_model
=
AutoModelForCausalLM
.
from_pretrained
(
model
)
# HF expects a tensor of input ids shaped (batch, seq_len).
hf_input_tokens
=
torch
.
tensor
(
prompt_ids
[
0
]).
unsqueeze
(
0
)
# Generate max_tokens new tokens deterministically.
hf_generated
=
hf_model
.
generate
(
hf_input_tokens
,
do_sample
=
False
,
min_new_tokens
=
max_tokens
,
max_new_tokens
=
max_tokens
,
)
# HF returns the prompt + generated tokens. Slice off the prompt.
hf_output_ids
=
hf_generated
.
cpu
().
tolist
()[
0
][
len
(
prompt_ids
[
0
]):]
# check that vLLM outputs (token ids) match HF outputs
# Note: for simplicity don't pass detokenized string
check_outputs_equal
(
outputs_0_lst
=
[(
hf_output_ids
,
""
)],
outputs_1_lst
=
[(
vllm_output_ids
,
""
)],
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
vllm/v1/worker/gpu_model_runner.py
View file @
8ee846c2
...
...
@@ -2247,14 +2247,28 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
,
(
"Sampled token IDs exceed the max model length. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len: "
f
"
{
self
.
max_model_len
}
"
)
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
input_batch
.
is_token_ids
[
req_idx
,
start_idx
:
end_idx
]
=
True
assert
end_idx
<=
self
.
max_model_len
+
1
,
(
"Sampled token IDs exceed the max model length + 1. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len + 1: "
f
"
{
self
.
max_model_len
+
1
}
"
)
n_tokens_cache
=
len
(
sampled_ids
)
# Sampled token IDs exceed the max model length by 1. This is
# legitimate as we can still sample 1 last token when the context
# length equals the max model length. Note that we do not need to
# cache this token ID as the sequence finishes after this step.
# Additionally, the buffers token_ids_cpu and is_token_ids are of
# size max model length only.
if
end_idx
==
self
.
max_model_len
+
1
:
n_tokens_cache
-=
1
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:(
start_idx
+
n_tokens_cache
)]
=
sampled_ids
[:
n_tokens_cache
]
self
.
input_batch
.
is_token_ids
[
req_idx
,
start_idx
:(
start_idx
+
n_tokens_cache
)]
=
True
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment