Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31a4b3e6
Unverified
Commit
31a4b3e6
authored
Oct 08, 2025
by
Thomas Parnell
Committed by
GitHub
Oct 07, 2025
Browse files
Revert #24446 and #26168 (#26332)
Signed-off-by:
Thomas Parnell
<
tpa@zurich.ibm.com
>
parent
caf8b1c0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
10 additions
and
117 deletions
+10
-117
tests/entrypoints/llm/test_generate.py
tests/entrypoints/llm/test_generate.py
+2
-3
tests/v1/e2e/test_context_length.py
tests/v1/e2e/test_context_length.py
+0
-90
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+1
-1
vllm/v1/core/sched/utils.py
vllm/v1/core/sched/utils.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+6
-22
No files found.
tests/entrypoints/llm/test_generate.py
View file @
31a4b3e6
...
@@ -85,11 +85,10 @@ def test_max_model_len():
...
@@ -85,11 +85,10 @@ def test_max_model_len():
num_total_tokens
=
len
(
output
.
prompt_token_ids
)
+
len
(
num_total_tokens
=
len
(
output
.
prompt_token_ids
)
+
len
(
output
.
outputs
[
0
].
token_ids
output
.
outputs
[
0
].
token_ids
)
)
# Total tokens must not exceed max_model_len + 1 (the last token can be
# Total tokens must not exceed max_model_len.
# generated with the context length equal to the max model length)
# It can be less if generation finishes due to other reasons (e.g., EOS)
# It can be less if generation finishes due to other reasons (e.g., EOS)
# before reaching the absolute model length limit.
# before reaching the absolute model length limit.
assert
num_total_tokens
<=
max_model_len
+
1
assert
num_total_tokens
<=
max_model_len
def
test_log_stats
():
def
test_log_stats
():
...
...
tests/v1/e2e/test_context_length.py
deleted
100644 → 0
View file @
caf8b1c0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
end-to-end tests for context length corner cases of vLLM v1 model runner
versus HuggingFace's transformers.
This test verifies the following behavior: allow prefill and decodes on the
model's maximum context length ``max_model_len`` and get one more token.
Test strategy
- Build a prompt consisting of exactly ``prompt_len`` tokens.
- Run vLLM generation requesting ``max_tokens`` new tokens.
- Run HF generation on the same prompt requesting the same number of tokens.
- Assert both return the same number of generated tokens and the same ids.
Test cases
- Prefill a prompt of ``max_model_len`` (2048) and request a single token which
will be sampled after the prefill (context length ``max_model_len``).
- Prefill a prompt of ``max_model_len`` - 1 (2047) and request two tokens where
the 1st will be sampled after the prefill and the 2nd after the first decode
(context length ``max_model_len``).
"""
import
pytest
from
tests.conftest
import
HfRunner
,
VllmRunner
from
tests.models.utils
import
check_outputs_equal
from
tests.utils
import
create_new_process_for_each_test
@
create_new_process_for_each_test
()
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"JackFram/llama-160m"
])
@
pytest
.
mark
.
parametrize
(
"prompt_len, max_tokens"
,
[
(
2048
,
1
),
# prompt_len = max_model_len
(
2047
,
2
),
# prompt_len = max_model_len - 1
],
)
def
test_max_context_length
(
model
:
str
,
vllm_runner
:
type
[
VllmRunner
],
hf_runner
:
type
[
HfRunner
],
prompt_len
:
int
,
max_tokens
:
int
,
)
->
None
:
"""Compare vLLM and HuggingFace when the prompt already fills the
model's maximum context length and we request a single new token.
The test ensures vLLM does not raise the "Sampled token IDs exceed the
max model length" assertion and that both vLLM and HF produce the same
single token when given the same inputs.
"""
# Construct a prompt of size prompt_len
prompt_ids
=
[[
43
]
*
prompt_len
]
# --- vLLM generation ---
with
vllm_runner
(
model_name
=
model
,
tokenizer_name
=
model
,
max_model_len
=
2048
,
max_num_seqs
=
1
,
tensor_parallel_size
=
1
,
)
as
vllm_model
:
# Generate max_tokens new tokens deterministically.
vllm_outputs
=
vllm_model
.
generate_greedy
(
prompt_ids
,
max_tokens
)
# --- HuggingFace generation ---
with
hf_runner
(
model_name
=
model
,
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
prompt_ids
,
max_tokens
)
# vLLM and HF runners return prompt + generated tokens. Slice off the prompt.
vllm_output_ids
=
vllm_outputs
[
0
][
0
][
prompt_len
:]
hf_output_ids
=
hf_outputs
[
0
][
0
][
prompt_len
:]
# check that exactly max_tokens tokens were generated with vLLM and HF
assert
len
(
vllm_output_ids
)
==
len
(
hf_output_ids
)
==
max_tokens
# check that vLLM outputs (token ids) match HF outputs
# Note: for simplicity don't pass detokenized string
check_outputs_equal
(
outputs_0_lst
=
[(
hf_output_ids
,
""
)],
outputs_1_lst
=
[(
vllm_output_ids
,
""
)],
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
vllm/v1/core/sched/scheduler.py
View file @
31a4b3e6
...
@@ -223,7 +223,7 @@ class Scheduler(SchedulerInterface):
...
@@ -223,7 +223,7 @@ class Scheduler(SchedulerInterface):
# Make sure the input position does not exceed the max model len.
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
# This is necessary when using spec decoding.
num_new_tokens
=
min
(
num_new_tokens
=
min
(
num_new_tokens
,
self
.
max_model_len
-
request
.
num_computed_tokens
num_new_tokens
,
self
.
max_model_len
-
1
-
request
.
num_computed_tokens
)
)
# Schedule encoder inputs.
# Schedule encoder inputs.
...
...
vllm/v1/core/sched/utils.py
View file @
31a4b3e6
...
@@ -44,7 +44,7 @@ def check_stop(
...
@@ -44,7 +44,7 @@ def check_stop(
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
request
:
Request
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
)
->
bool
:
if
(
if
(
request
.
num_tokens
>
max_model_len
request
.
num_tokens
>
=
max_model_len
or
request
.
num_output_tokens
>=
request
.
max_tokens
or
request
.
num_output_tokens
>=
request
.
max_tokens
):
):
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
request
.
status
=
RequestStatus
.
FINISHED_LENGTH_CAPPED
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
31a4b3e6
...
@@ -2317,30 +2317,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
...
@@ -2317,30 +2317,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
end_idx
=
start_idx
+
len
(
sampled_ids
)
end_idx
=
start_idx
+
len
(
sampled_ids
)
assert
end_idx
<=
self
.
max_model_len
+
1
,
(
assert
end_idx
<=
self
.
max_model_len
,
(
"Sampled token IDs exceed the max model length
+ 1
. "
"Sampled token IDs exceed the max model length. "
f
"Total number of tokens:
{
end_idx
}
> max_model_len
+ 1
: "
f
"Total number of tokens:
{
end_idx
}
> max_model_len: "
f
"
{
self
.
max_model_len
+
1
}
"
f
"
{
self
.
max_model_len
}
"
)
)
n_tokens_cache
=
len
(
sampled_ids
)
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
self
.
input_batch
.
is_token_ids
[
req_idx
,
start_idx
:
end_idx
]
=
True
# Sampled token IDs exceed the max model length by 1. This is
# legitimate as we can still sample 1 last token when the context
# length equals the max model length. Note that we do not need to
# cache this token ID as the sequence finishes after this step.
# Additionally, the buffers token_ids_cpu and is_token_ids are of
# size max model length only.
if
end_idx
==
self
.
max_model_len
+
1
:
n_tokens_cache
-=
1
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
(
start_idx
+
n_tokens_cache
)
]
=
sampled_ids
[:
n_tokens_cache
]
self
.
input_batch
.
is_token_ids
[
req_idx
,
start_idx
:
(
start_idx
+
n_tokens_cache
)
]
=
True
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment