Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f808cc0
Unverified
Commit
3f808cc0
authored
Feb 26, 2025
by
Joe Runde
Committed by
GitHub
Feb 26, 2025
Browse files
[Bugfix] Do not crash V0 engine on input errors (#13101)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
ec8a5e53
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
172 additions
and
6 deletions
+172
-6
tests/mq_llm_engine/test_error_handling.py
tests/mq_llm_engine/test_error_handling.py
+78
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+59
-3
vllm/engine/multiprocessing/engine.py
vllm/engine/multiprocessing/engine.py
+9
-0
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-3
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+18
-0
No files found.
tests/mq_llm_engine/test_error_handling.py
View file @
3f808cc0
...
...
@@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
from
vllm.entrypoints.openai.api_server
import
build_async_engine_client
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SequenceGroupMetadata
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
await
client
.
check_health
()
client
.
close
()
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
                                   ipc_path: str):
    """Run an MQLLMEngine whose input preparation fails for "evil" requests.

    Simulates an exception while preparing inputs for the model. In the wild,
    this could be something like a multimodal input processor failing on
    invalid image data.

    A hook is appended to the model runner's input builder; it raises
    ``RAISED_ERROR(RAISED_VALUE)`` for every sequence group whose request id
    starts with ``"evil"`` and does nothing for all other requests.
    """
    mq_engine = MQLLMEngine.from_engine_args(
        engine_args=engine_args,
        usage_context=UsageContext.UNKNOWN_CONTEXT,
        ipc_path=ipc_path)

    def fail_on_evil_request(_, seq_group_metadata: SequenceGroupMetadata):
        # Requests that are not marked as "evil" pass through untouched.
        if not seq_group_metadata.request_id.startswith("evil"):
            return
        raise RAISED_ERROR(RAISED_VALUE)

    # Register the hook so it runs when a sequence group is added to the
    # model input. See class ModelInputForGPUBuilder
    driver = mq_engine.engine.model_executor.driver_worker
    input_builder = driver.worker.model_runner.builder
    input_builder.per_seq_group_compute_fns.append(fail_on_evil_request)

    # Run engine.
    mq_engine.start()
@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
    """Requests whose inputs cannot be processed must fail on their own.

    Ten good and ten "evil" requests are submitted concurrently. Every evil
    request must raise RAISED_ERROR, every good request must complete, and
    the engine must still report healthy afterwards.
    """
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket,
                           run_fn=run_with_evil_input_processing) as engine:

        client = await engine.make_client()
        assert client.is_running

        # Engine should be healthy
        await client.check_health()

        async def consume(request_id_prefix: str):
            # Drain a single generation stream, discarding its outputs. The
            # request id is a fresh UUID with the given prefix; the "evil"
            # prefix is what triggers the injected input-processing error.
            stream = client.generate(
                prompt="Hello my name is",
                sampling_params=SamplingParams(max_tokens=10),
                request_id=request_id_prefix + str(uuid.uuid4()))
            async for _ in stream:
                pass

        num_requests = 10
        passing_tasks = [
            asyncio.create_task(consume("")) for _ in range(num_requests)
        ]
        failing_tasks = [
            asyncio.create_task(consume("evil")) for _ in range(num_requests)
        ]

        # Collect the failures without raising, then wait for the good ones.
        await asyncio.gather(*failing_tasks, return_exceptions=True)
        await asyncio.gather(*passing_tasks)

        # All the bad inputs should have raised
        for failed in failing_tasks:
            with pytest.raises(RAISED_ERROR):
                failed.result()

        # But all good inputs should have still succeeded
        for succeeded in passing_tasks:
            succeeded.result()

        # And the engine should remain healthy
        assert not client.errored
        await client.check_health()

        client.close()
vllm/engine/llm_engine.py
View file @
3f808cc0
...
...
@@ -60,6 +60,7 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
from
vllm.utils
import
(
Counter
,
Device
,
deprecate_kwargs
,
resolve_obj_by_qualname
,
weak_bind
)
from
vllm.version
import
__version__
as
VLLM_VERSION
from
vllm.worker.model_runner_base
import
InputProcessingError
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
...
@@ -410,6 +411,10 @@ class LLMEngine:
self
.
seq_id_to_seq_group
:
Dict
[
str
,
SequenceGroupBase
]
=
{}
# Flag to set when an input fails to process and the engine should run
# the next step without re-scheduling.
self
.
_skip_scheduling_next_step
=
False
def
_initialize_kv_caches
(
self
)
->
None
:
"""Initialize the KV cache in the worker(s).
...
...
@@ -1334,7 +1339,11 @@ class LLMEngine:
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
):
# The scheduler is also skipped if a single request caused the last
# engine step to fail, and the previous schedule needs to be rerun.
if
not
self
.
_has_remaining_steps
(
seq_group_metadata_list
)
and
not
self
.
_skip_scheduling_next_step
:
# Schedule iteration
(
seq_group_metadata_list
,
scheduler_outputs
,
allow_async_output_proc
...
...
@@ -1388,8 +1397,23 @@ class LLMEngine:
execute_model_req
.
async_callback
=
self
.
async_callbacks
[
virtual_engine
]
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
try
:
outputs
=
self
.
model_executor
.
execute_model
(
execute_model_req
=
execute_model_req
)
self
.
_skip_scheduling_next_step
=
False
except
InputProcessingError
as
e
:
# The input for this request cannot be processed, so we must
# abort it. If there are remaining requests in the batch that
# have been scheduled, they will be retried on the next step.
invalid_request_id
=
e
.
request_id
self
.
_abort_and_cache_schedule
(
request_id
=
invalid_request_id
,
virtual_engine
=
virtual_engine
,
seq_group_metadata_list
=
seq_group_metadata_list
,
scheduler_outputs
=
scheduler_outputs
,
allow_async_output_proc
=
allow_async_output_proc
)
# Raise so the caller is notified that this request failed
raise
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
...
...
@@ -1464,6 +1488,38 @@ class LLMEngine:
return
ctx
.
request_outputs
def _abort_and_cache_schedule(
        self, request_id: str, virtual_engine: int,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        scheduler_outputs: SchedulerOutputs,
        allow_async_output_proc: bool) -> None:
    """Aborts a single request, and caches the scheduler outputs minus that
    request. This allows the next step to continue processing the remaining
    requests without having to re-run the scheduler.

    Args:
        request_id: id of the request whose inputs failed to process.
        virtual_engine: index of the virtual engine the schedule belongs to.
        seq_group_metadata_list: metadata of the current schedule; the entry
            for ``request_id`` is removed in place.
        scheduler_outputs: scheduler outputs of the current schedule; the
            scheduled group for ``request_id`` is removed in place.
        allow_async_output_proc: forwarded unchanged to the schedule cache.
    """
    # Abort the request and remove its sequence group from the current
    # schedule. Both containers are edited in place because the same objects
    # are handed to the schedule cache below.
    self.abort_request(request_id)
    for i, metadata in enumerate(seq_group_metadata_list):
        if metadata.request_id == request_id:
            del seq_group_metadata_list[i]
            break
    for i, group in enumerate(scheduler_outputs.scheduled_seq_groups):
        if group.seq_group.request_id == request_id:
            del scheduler_outputs.scheduled_seq_groups[i]
            break

    if not seq_group_metadata_list:
        # Nothing is left to rerun. Clear the flag explicitly: if this step
        # was itself replaying a cached schedule the flag is still True, and
        # leaving it set would make later steps skip the scheduler.
        self._skip_scheduling_next_step = False
        return

    # There are still other sequence groups left in the schedule: cache them
    # and flag the engine to reuse the schedule.
    self._skip_scheduling_next_step = True
    # Reuse multi-step caching logic
    self._cache_scheduler_outputs_for_multi_step(
        virtual_engine=virtual_engine,
        scheduler_outputs=scheduler_outputs,
        seq_group_metadata_list=seq_group_metadata_list,
        allow_async_output_proc=allow_async_output_proc)
def
_has_remaining_steps
(
self
,
seq_group_metadata_list
:
Optional
[
List
[
SequenceGroupMetadata
]]
)
->
bool
:
...
...
vllm/engine/multiprocessing/engine.py
View file @
3f808cc0
...
...
@@ -27,6 +27,7 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
from
vllm.logger
import
init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.worker.model_runner_base
import
InputProcessingError
logger
=
init_logger
(
__name__
)
...
...
@@ -210,6 +211,14 @@ class MQLLMEngine:
return
self
.
engine
.
step
()
except
SystemExit
:
raise
except
InputProcessingError
as
e
:
# Special case where we handle an error preparing the inputs for
# a single request in the batch
rpc_err
=
RPCError
(
request_id
=
e
.
request_id
,
is_engine_errored
=
False
,
exception
=
e
.
__cause__
)
self
.
_send_outputs
(
rpc_err
)
return
[]
except
BaseException
as
e
:
self
.
_set_errored
(
e
)
rpc_err
=
RPCError
(
request_id
=
None
,
...
...
vllm/worker/model_runner.py
View file @
3f808cc0
...
...
@@ -53,8 +53,8 @@ from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
is_pin_memory_available
,
supports_dynamo
,
weak_ref_tensor
)
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
InputProcessingError
,
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerInputBuilderBase
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
...
...
@@ -1216,7 +1216,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"""
self
.
builder
.
prepare
(
finished_requests_ids
)
for
seq_group_metadata
in
seq_group_metadata_list
:
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
try
:
self
.
builder
.
add_seq_group
(
seq_group_metadata
)
except
Exception
as
e
:
# Raise an exception that tracks the ID of the bad request
raise
InputProcessingError
(
seq_group_metadata
.
request_id
,
str
(
e
))
from
e
self
.
builder
.
reset_cached_inter_data
()
...
...
vllm/worker/model_runner_base.py
View file @
3f808cc0
...
...
@@ -261,3 +261,21 @@ class ModelRunnerWrapperBase:
def
__getattr__
(
self
,
attr
):
return
getattr
(
self
.
model_runner
,
attr
)
class InputProcessingError(Exception):
    """This exception is raised when an error occurs preparing the inputs for
    a single sequence group.
    This allows the engine to gracefully handle errors with a single sequence
    group without having to fail the entire batch.

    Attributes are ``request_id`` (id of the offending sequence group) and
    ``message`` (description of the underlying failure).
    """

    def __init__(self, request_id, message):
        """request_id is the id of the offending sequence group"""
        self.request_id = request_id
        self.message = message
        super().__init__(self.message)

    def __reduce__(self):
        # Exception's default reduction rebuilds the object by calling the
        # class with ``self.args``, which only holds ``message``. That call
        # fails with TypeError because the constructor needs two arguments,
        # so supply both explicitly to keep pickle and copy working.
        return (type(self), (self.request_id, self.message))

    def __str__(self):
        return ("Failed to prepare inputs for sequence group with request id: "
                f"{self.request_id}, Error: {self.message}")
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment