Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45c0526a
Unverified
Commit
45c0526a
authored
Dec 18, 2025
by
Nick Hill
Committed by
GitHub
Dec 19, 2025
Browse files
[BugFix] Handle errors when preprocessing added requests (#30895)
Signed-off-by:
Nick Hill
<
nhill@redhat.com
>
parent
d6b3d39b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
93 additions
and
3 deletions
+93
-3
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+4
-1
tests/v1/engine/test_preprocess_error_handling.py
tests/v1/engine/test_preprocess_error_handling.py
+56
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+33
-2
No files found.
.buildkite/test-pipeline.yaml
View file @
45c0526a
...
@@ -319,7 +319,10 @@ steps:
...
@@ -319,7 +319,10 @@ steps:
# TODO: accuracy does not match, whether setting
# TODO: accuracy does not match, whether setting
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
-
pytest -v -s v1/e2e
-
pytest -v -s v1/e2e
-
pytest -v -s v1/engine
# Run this test standalone for now;
# need to untangle the (implicit) use of spawn/fork across the tests.
-
pytest -v -s v1/engine/test_preprocess_error_handling.py
-
pytest -v -s v1/engine --ignore v1/engine/test_preprocess_error_handling.py
-
label
:
V1 Test entrypoints
# 35min
-
label
:
V1 Test entrypoints
# 35min
timeout_in_minutes
:
50
timeout_in_minutes
:
50
...
...
tests/v1/engine/test_preprocess_error_handling.py
0 → 100644
View file @
45c0526a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch.cuda
from
vllm
import
LLM
,
SamplingParams
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core
import
EngineCore
# Hugging Face model id passed to ``LLM(model=...)`` by the test below.
# NOTE(review): presumably a tiny randomly-initialised Llama checkpoint chosen
# to keep engine start-up cheap -- confirm against the model card.
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def test_preprocess_error_handling(monkeypatch: pytest.MonkeyPatch):
    """Check that a preprocessing failure only fails the offending request.

    ``EngineCore.preprocess_add_request`` is patched to raise for any request
    whose first prompt token id is 333. Such a request is expected to come
    back finished with ``finish_reason == "error"`` and no generated tokens,
    and a subsequent ordinary request must still complete normally.
    """
    # The patch is applied in this process before the engine is created, so
    # the engine core process has to be forked to see it.
    assert not torch.cuda.is_initialized(), (
        "fork needs to be used for the engine "
        "core process and this isn't possible if cuda is already initialized"
    )

    # Keep a handle on the genuine implementation for requests that should
    # pass through untouched.
    real_preprocess = EngineCore.preprocess_add_request

    def preprocess_or_fail(self, request: EngineCoreRequest):
        """Raise for the sentinel prompt, otherwise defer to the real method."""
        token_ids = request.prompt_token_ids
        if not token_ids or token_ids[0] != 333:
            return real_preprocess(self, request)
        raise ValueError("Simulated preprocessing error!")

    monkeypatch.setattr(EngineCore, "preprocess_add_request", preprocess_or_fail)

    llm = LLM(model=MODEL_NAME)

    # Pass pre-tokenized input so the first token id is exactly the sentinel;
    # a text prompt would be tokenized by LLM.generate instead.
    from vllm.inputs import TokensPrompt

    sentinel_prompt = TokensPrompt(prompt_token_ids=[333])
    failed_params = SamplingParams(max_tokens=10)
    failed = llm.generate(sentinel_prompt, failed_params)  # type: ignore

    # The failure is reported as an error-finished output with no tokens,
    # not as an exception raised from generate().
    assert len(failed) == 1
    assert len(failed[0].outputs[0].token_ids) == 0
    assert failed[0].finished
    assert failed[0].outputs[0].finish_reason == "error"

    # The engine must still serve a normal request afterwards.
    healthy_params = SamplingParams(max_tokens=10)
    healthy = llm.generate("Hello, my name is", healthy_params)
    assert len(healthy) == 1
    assert len(healthy[0].outputs[0].token_ids) > 0
    assert healthy[0].outputs[0].finish_reason in ("stop", "length")
vllm/v1/engine/core.py
View file @
45c0526a
...
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
...
@@ -43,9 +43,11 @@ from vllm.v1.core.kv_cache_utils import (
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.engine
import
(
from
vllm.v1.engine
import
(
EngineCoreOutput
,
EngineCoreOutputs
,
EngineCoreOutputs
,
EngineCoreRequest
,
EngineCoreRequest
,
EngineCoreRequestType
,
EngineCoreRequestType
,
FinishReason
,
ReconfigureDistributedRequest
,
ReconfigureDistributedRequest
,
ReconfigureRankType
,
ReconfigureRankType
,
UtilityOutput
,
UtilityOutput
,
...
@@ -1055,9 +1057,14 @@ class EngineCoreProc(EngineCore):
...
@@ -1055,9 +1057,14 @@ class EngineCoreProc(EngineCore):
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
request_type
=
EngineCoreRequestType
(
bytes
(
type_frame
.
buffer
))
# Deserialize the request data.
# Deserialize the request data.
request
:
Any
if
request_type
==
EngineCoreRequestType
.
ADD
:
if
request_type
==
EngineCoreRequestType
.
ADD
:
request
=
add_request_decoder
.
decode
(
data_frames
)
req
:
EngineCoreRequest
=
add_request_decoder
.
decode
(
data_frames
)
request
=
self
.
preprocess_add_request
(
request
)
try
:
request
=
self
.
preprocess_add_request
(
req
)
except
Exception
:
self
.
_handle_request_preproc_error
(
req
)
continue
else
:
else
:
request
=
generic_decoder
.
decode
(
data_frames
)
request
=
generic_decoder
.
decode
(
data_frames
)
...
@@ -1141,6 +1148,30 @@ class EngineCoreProc(EngineCore):
...
@@ -1141,6 +1148,30 @@ class EngineCoreProc(EngineCore):
# Limit the number of buffers to reuse.
# Limit the number of buffers to reuse.
reuse_buffers
.
append
(
buffer
)
reuse_buffers
.
append
(
buffer
)
def _handle_request_preproc_error(self, request: EngineCoreRequest) -> None:
    """Report a failed add-request preprocessing step back to the client.

    Logs the active exception and enqueues a request-scoped error response
    for exceptions raised from the add request preprocessing in the input
    socket processing thread.

    Args:
        request: The decoded add request whose preprocessing raised.
    """
    req_id = request.request_id

    # Must be called while the exception is being handled so that
    # logger.exception can attach the traceback.
    logger.exception("Unexpected error pre-processing request %s", req_id)

    # A single output with no tokens, finished with the ERROR reason.
    error_output = EngineCoreOutput(
        request_id=req_id,
        new_token_ids=[],
        finish_reason=FinishReason.ERROR,
    )
    response = EngineCoreOutputs(
        engine_index=self.engine_index,
        finished_requests={req_id},
        outputs=[error_output],
    )

    # Queue entries are (client_index, outputs) pairs.
    self.output_queue.put_nowait((request.client_index, response))
class
DPEngineCoreProc
(
EngineCoreProc
):
class
DPEngineCoreProc
(
EngineCoreProc
):
"""ZMQ-wrapper for running EngineCore in background process
"""ZMQ-wrapper for running EngineCore in background process
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment