Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
704c1dad
Unverified
Commit
704c1dad
authored
Jan 30, 2026
by
jh-nv
Committed by
GitHub
Jan 30, 2026
Browse files
fix: fix vllm graceful shutdown (#5818)
parent
284f772b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
230 additions
and
115 deletions
+230
-115
components/src/dynamo/vllm/engine_monitor.py
components/src/dynamo/vllm/engine_monitor.py
+41
-3
components/src/dynamo/vllm/handlers.py
components/src/dynamo/vllm/handlers.py
+142
-94
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+43
-18
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
...nts/src/dynamo/vllm/multimodal_handlers/worker_handler.py
+4
-0
No files found.
components/src/dynamo/vllm/engine_monitor.py
View file @
704c1dad
...
@@ -25,7 +25,12 @@ class VllmEngineMonitor:
...
@@ -25,7 +25,12 @@ class VllmEngineMonitor:
Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.
Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.
"""
"""
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
engine_client
:
AsyncLLM
):
def
__init__
(
self
,
runtime
:
DistributedRuntime
,
engine_client
:
AsyncLLM
,
shutdown_event
:
asyncio
.
Event
=
None
,
):
if
not
isinstance
(
runtime
,
DistributedRuntime
):
if
not
isinstance
(
runtime
,
DistributedRuntime
):
raise
ValueError
(
raise
ValueError
(
f
"
{
self
.
__class__
.
__name__
}
requires an instance of DistributedRuntime."
f
"
{
self
.
__class__
.
__name__
}
requires an instance of DistributedRuntime."
...
@@ -37,6 +42,7 @@ class VllmEngineMonitor:
...
@@ -37,6 +42,7 @@ class VllmEngineMonitor:
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
engine_client
=
engine_client
self
.
engine_client
=
engine_client
self
.
shutdown_event
=
shutdown_event
self
.
_monitor_task
=
asyncio
.
create_task
(
self
.
_check_engine_health
())
self
.
_monitor_task
=
asyncio
.
create_task
(
self
.
_check_engine_health
())
logger
.
info
(
logger
.
info
(
...
@@ -66,10 +72,41 @@ class VllmEngineMonitor:
...
@@ -66,10 +72,41 @@ class VllmEngineMonitor:
signal
.
alarm
(
0
)
signal
.
alarm
(
0
)
async
def
_check_engine_health
(
self
):
async
def
_check_engine_health
(
self
):
"""
Continuously check engine health until:
1. Engine dies (EngineDeadError) - initiate shutdown
2. Shutdown event is triggered - stop monitoring gracefully
3. Task is cancelled - cleanup
"""
while
True
:
while
True
:
try
:
try
:
# Check if shutdown event was triggered - stop monitoring
if
self
.
shutdown_event
and
self
.
shutdown_event
.
is_set
():
logger
.
info
(
f
"
{
self
.
__class__
.
__name__
}
: Shutdown event detected, stopping engine health monitoring."
)
break
await
self
.
engine_client
.
check_health
()
await
self
.
engine_client
.
check_health
()
await
asyncio
.
sleep
(
HEALTH_CHECK_INTERVAL
)
# Sleep with shutdown event awareness for faster response
if
self
.
shutdown_event
:
try
:
await
asyncio
.
wait_for
(
self
.
shutdown_event
.
wait
(),
timeout
=
HEALTH_CHECK_INTERVAL
)
# Shutdown event was set during sleep
logger
.
info
(
f
"
{
self
.
__class__
.
__name__
}
: Shutdown event detected, stopping engine health monitoring."
)
break
except
asyncio
.
TimeoutError
:
# Normal timeout, continue monitoring
pass
else
:
# No shutdown event, just sleep normally
await
asyncio
.
sleep
(
HEALTH_CHECK_INTERVAL
)
except
EngineDeadError
as
e
:
except
EngineDeadError
as
e
:
logger
.
error
(
f
"Traceback:
{
traceback
.
format_exc
()
}
"
)
logger
.
error
(
f
"Traceback:
{
traceback
.
format_exc
()
}
"
)
logger
.
error
(
f
"vLLM AsyncLLM health check failed:
{
e
}
"
)
logger
.
error
(
f
"vLLM AsyncLLM health check failed:
{
e
}
"
)
...
@@ -78,4 +115,5 @@ class VllmEngineMonitor:
...
@@ -78,4 +115,5 @@ class VllmEngineMonitor:
self
.
runtime
.
shutdown
()
self
.
runtime
.
shutdown
()
os
.
_exit
(
1
)
os
.
_exit
(
1
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
pass
logger
.
debug
(
f
"
{
self
.
__class__
.
__name__
}
: Health check task cancelled."
)
break
components/src/dynamo/vllm/handlers.py
View file @
704c1dad
...
@@ -243,6 +243,7 @@ class BaseWorkerHandler(ABC):
...
@@ -243,6 +243,7 @@ class BaseWorkerHandler(ABC):
generate_endpoint
=
None
,
generate_endpoint
=
None
,
config
=
None
,
config
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
use_vllm_tokenizer
:
bool
=
False
,
shutdown_event
:
asyncio
.
Event
|
None
=
None
,
):
):
self
.
runtime
=
runtime
self
.
runtime
=
runtime
self
.
component
=
component
self
.
component
=
component
...
@@ -251,7 +252,7 @@ class BaseWorkerHandler(ABC):
...
@@ -251,7 +252,7 @@ class BaseWorkerHandler(ABC):
self
.
kv_publishers
:
list
[
ZmqKvEventPublisher
]
|
None
=
None
self
.
kv_publishers
:
list
[
ZmqKvEventPublisher
]
|
None
=
None
self
.
generate_endpoint
=
generate_endpoint
self
.
generate_endpoint
=
generate_endpoint
self
.
config
=
config
self
.
config
=
config
self
.
engine_monitor
=
VllmEngineMonitor
(
runtime
,
engine
)
self
.
engine_monitor
=
VllmEngineMonitor
(
runtime
,
engine
,
shutdown_event
)
self
.
image_loader
=
ImageLoader
()
self
.
image_loader
=
ImageLoader
()
self
.
temp_dirs
:
list
[
tempfile
.
TemporaryDirectory
]
=
[]
self
.
temp_dirs
:
list
[
tempfile
.
TemporaryDirectory
]
=
[]
self
.
model_max_len
=
model_max_len
self
.
model_max_len
=
model_max_len
...
@@ -272,6 +273,9 @@ class BaseWorkerHandler(ABC):
...
@@ -272,6 +273,9 @@ class BaseWorkerHandler(ABC):
tokenizer
=
engine
.
tokenizer
tokenizer
=
engine
.
tokenizer
self
.
input_param_manager
=
InputParamManager
(
tokenizer
)
self
.
input_param_manager
=
InputParamManager
(
tokenizer
)
# Store shutdown event for graceful shutdown monitoring
self
.
shutdown_event
=
shutdown_event
async
def
sleep
(
self
,
body
:
dict
)
->
dict
:
async
def
sleep
(
self
,
body
:
dict
)
->
dict
:
"""Sleep the engine to release GPU memory and unregister from discovery.
"""Sleep the engine to release GPU memory and unregister from discovery.
...
@@ -339,14 +343,44 @@ class BaseWorkerHandler(ABC):
...
@@ -339,14 +343,44 @@ class BaseWorkerHandler(ABC):
raise
NotImplementedError
raise
NotImplementedError
async
def
_monitor_abort
(
self
,
context
,
request_id
,
is_prefill
):
async
def
_monitor_abort
(
self
,
context
,
request_id
,
is_prefill
):
"""Background task that monitors for context cancellation and aborts the request."""
"""
Background task that monitors for context cancellation and shutdown.
Aborts the request if either occurs. Raises GeneratorExit if shutdown was triggered.
"""
try
:
try
:
await
context
.
async_killed_or_stopped
()
# Build list of futures/tasks to wait for
# If we reach here, the context was stopped or killed
wait_for
=
[
context
.
async_killed_or_stopped
()]
shutdown_task
=
None
if
self
.
shutdown_event
:
# Create task for shutdown monitoring and add to wait list
shutdown_task
=
asyncio
.
create_task
(
self
.
shutdown_event
.
wait
())
wait_for
.
append
(
shutdown_task
)
# Wait for whichever happens first
done
,
pending
=
await
asyncio
.
wait
(
wait_for
,
return_when
=
asyncio
.
FIRST_COMPLETED
,
)
# Cancel the pending task/future
for
task
in
pending
:
task
.
cancel
()
try
:
await
task
except
asyncio
.
CancelledError
:
pass
# Abort the request
await
self
.
engine_client
.
abort
(
request_id
)
await
self
.
engine_client
.
abort
(
request_id
)
logger
.
debug
(
logger
.
debug
(
f
"Aborted
{
'Prefill '
if
is_prefill
else
''
}
Request ID:
{
request_id
}
"
f
"Aborted
{
'Prefill '
if
is_prefill
else
''
}
Request ID:
{
request_id
}
"
)
)
# Check which event triggered and raise GeneratorExit if shutdown
if
shutdown_task
and
shutdown_task
in
done
:
raise
GeneratorExit
(
"Engine was shut down during generation."
)
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
# Task was cancelled, normal cleanup if not aborted
# Task was cancelled, normal cleanup if not aborted
pass
pass
...
@@ -355,18 +389,24 @@ class BaseWorkerHandler(ABC):
...
@@ -355,18 +389,24 @@ class BaseWorkerHandler(ABC):
@
asynccontextmanager
@
asynccontextmanager
async
def
_abort_monitor
(
self
,
context
,
request_id
,
is_prefill
=
False
):
async
def
_abort_monitor
(
self
,
context
,
request_id
,
is_prefill
=
False
):
"""Context manager that creates and automatically cleans up an abort monitoring task."""
"""
Context manager that creates and automatically cleans up an abort monitoring task.
If shutdown event was triggered, raises GeneratorExit on exit.
"""
task
=
asyncio
.
create_task
(
self
.
_monitor_abort
(
context
,
request_id
,
is_prefill
))
task
=
asyncio
.
create_task
(
self
.
_monitor_abort
(
context
,
request_id
,
is_prefill
))
try
:
try
:
yield
task
yield
task
finally
:
finally
:
# Cancel the abort monitoring task when exiting the context
# Clean up the abort monitoring task
if
not
task
.
done
():
if
not
task
.
done
():
task
.
cancel
()
task
.
cancel
()
try
:
try
:
await
task
await
task
except
asyncio
.
CancelledError
:
except
asyncio
.
CancelledError
:
pass
pass
else
:
# If the task completed, check if it raised GeneratorExit
task
.
result
()
async
def
clear_kv_blocks
(
self
,
request
=
None
):
async
def
clear_kv_blocks
(
self
,
request
=
None
):
try
:
try
:
...
@@ -389,6 +429,20 @@ class BaseWorkerHandler(ABC):
...
@@ -389,6 +429,20 @@ class BaseWorkerHandler(ABC):
self
.
_lora_load_locks
[
lora_name
]
=
lock
self
.
_lora_load_locks
[
lora_name
]
=
lock
return
lock
return
lock
def
_normalize_finish_reason
(
self
,
finish_reason
:
str
)
->
str
:
"""
Normalize vLLM finish reasons to Dynamo-compatible values.
vLLM may return finish reasons that aren't recognized by Dynamo's Rust layer.
This method maps them to compatible values.
[TODO]: Remove this method and add the right code in the Rust layer.
"""
# Map vLLM's "abort" to Dynamo's "cancelled"
if
finish_reason
.
startswith
(
"abort"
):
logging
.
debug
(
f
"Normalizing finish reason:
{
finish_reason
}
to cancelled"
)
return
"cancelled"
return
finish_reason
async
def
load_lora
(
self
,
request
=
None
):
async
def
load_lora
(
self
,
request
=
None
):
"""
"""
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
Load a LoRA adapter dynamically into the vLLM's AsyncLLM engine.
...
@@ -1112,63 +1166,57 @@ class BaseWorkerHandler(ABC):
...
@@ -1112,63 +1166,57 @@ class BaseWorkerHandler(ABC):
)
)
num_output_tokens_so_far
=
0
num_output_tokens_so_far
=
0
try
:
async
for
res
in
gen
:
async
for
res
in
gen
:
# res is vllm's RequestOutput
# res is vllm's RequestOutput
if
not
res
.
outputs
:
if
not
res
.
outputs
:
self
.
_log_with_lora_context
(
self
.
_log_with_lora_context
(
"Request {request_id}{lora_info} returned no outputs"
,
"Request {request_id}{lora_info} returned no outputs"
,
request_id
,
request_id
,
lora_request
,
lora_request
,
)
)
# Use string format "error: message" for consistency with vLLM's string-based finish_reason
# Use string format "error: message" for consistency with vLLM's string-based finish_reason
# Rust will parse this into FinishReason::Error(message)
# Rust will parse this into FinishReason::Error(message)
yield
{
yield
{
"finish_reason"
:
"error: No outputs from vLLM engine"
,
"finish_reason"
:
"error: No outputs from vLLM engine"
,
"token_ids"
:
[],
"token_ids"
:
[],
}
}
break
break
output
=
res
.
outputs
[
0
]
output
=
res
.
outputs
[
0
]
next_total_toks
=
len
(
output
.
token_ids
)
next_total_toks
=
len
(
output
.
token_ids
)
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
out
=
{
"token_ids"
:
output
.
token_ids
[
num_output_tokens_so_far
:]}
# Extract logprobs for new tokens if available
# Extract logprobs for new tokens if available
log_probs
,
top_logprobs
=
self
.
_extract_logprobs
(
log_probs
,
top_logprobs
=
self
.
_extract_logprobs
(
output
,
num_output_tokens_so_far
output
,
num_output_tokens_so_far
)
if
log_probs
is
not
None
:
out
[
"log_probs"
]
=
log_probs
if
top_logprobs
is
not
None
:
out
[
"top_logprobs"
]
=
top_logprobs
if
output
.
finish_reason
:
out
[
"finish_reason"
]
=
self
.
_normalize_finish_reason
(
output
.
finish_reason
)
)
if
log_probs
is
not
None
:
out
[
"completion_usage"
]
=
BaseWorkerHandler
.
_build_completion_usage
(
out
[
"log_probs"
]
=
log_probs
request_output
=
res
,
if
top_logprobs
is
not
None
:
embedding_sequence_length
=
embedding_sequence_length
,
out
[
"top_logprobs"
]
=
top_logprobs
)
# Log completion with LoRA info (debug level to avoid log spam)
if
output
.
finish_reason
:
self
.
_log_with_lora_context
(
out
[
"finish_reason"
]
=
output
.
finish_reason
"Completed token generation for request {request_id}{lora_info}: "
out
[
"{output_tokens} output tokens, finish_reason={finish_reason}"
,
"completion_usage"
request_id
,
]
=
BaseWorkerHandler
.
_build_completion_usage
(
lora_request
,
request_output
=
res
,
output_tokens
=
next_total_toks
,
embedding_sequence_length
=
embedding_sequence_length
,
finish_reason
=
output
.
finish_reason
,
)
)
# Log completion with LoRA info (debug level to avoid log spam)
if
output
.
stop_reason
:
self
.
_log_with_lora_context
(
out
[
"stop_reason"
]
=
output
.
stop_reason
"Completed token generation for request {request_id}{lora_info}: "
yield
out
"{output_tokens} output tokens, finish_reason={finish_reason}"
,
num_output_tokens_so_far
=
next_total_toks
request_id
,
lora_request
,
output_tokens
=
next_total_toks
,
finish_reason
=
output
.
finish_reason
,
)
if
output
.
stop_reason
:
out
[
"stop_reason"
]
=
output
.
stop_reason
yield
out
num_output_tokens_so_far
=
next_total_toks
except
asyncio
.
CancelledError
:
# Raise GeneratorExit when the engine exits so that the frontend can migrate the request
raise
GeneratorExit
(
"Decode engine was shut down during token generation"
)
from
None
except
EngineDeadError
as
e
:
except
EngineDeadError
as
e
:
logger
.
error
(
f
"vLLM EngineDeadError:
{
e
}
"
)
logger
.
error
(
f
"vLLM EngineDeadError:
{
e
}
"
)
...
@@ -1189,6 +1237,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -1189,6 +1237,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint
=
None
,
generate_endpoint
=
None
,
config
=
None
,
config
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
use_vllm_tokenizer
:
bool
=
False
,
shutdown_event
:
asyncio
.
Event
|
None
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
runtime
,
runtime
,
...
@@ -1200,6 +1249,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -1200,6 +1249,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
generate_endpoint
,
generate_endpoint
,
config
,
config
,
use_vllm_tokenizer
,
use_vllm_tokenizer
,
shutdown_event
,
)
)
async
def
generate
(
self
,
request
,
context
):
async
def
generate
(
self
,
request
,
context
):
...
@@ -1361,7 +1411,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
...
@@ -1361,7 +1411,9 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"role"
:
"assistant"
,
"role"
:
"assistant"
,
"content"
:
delta_text
,
"content"
:
delta_text
,
},
},
"finish_reason"
:
output
.
finish_reason
,
"finish_reason"
:
self
.
_normalize_finish_reason
(
output
.
finish_reason
),
}
}
chunk
=
{
chunk
=
{
...
@@ -1398,6 +1450,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -1398,6 +1450,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint
=
None
,
generate_endpoint
=
None
,
config
=
None
,
config
=
None
,
use_vllm_tokenizer
:
bool
=
False
,
use_vllm_tokenizer
:
bool
=
False
,
shutdown_event
:
asyncio
.
Event
|
None
=
None
,
):
):
super
().
__init__
(
super
().
__init__
(
runtime
,
runtime
,
...
@@ -1409,6 +1462,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -1409,6 +1462,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
generate_endpoint
,
generate_endpoint
,
config
,
config
,
use_vllm_tokenizer
,
use_vllm_tokenizer
,
shutdown_event
,
)
)
async
def
generate
(
self
,
request
,
context
):
async
def
generate
(
self
,
request
,
context
):
...
@@ -1501,39 +1555,33 @@ class PrefillWorkerHandler(BaseWorkerHandler):
...
@@ -1501,39 +1555,33 @@ class PrefillWorkerHandler(BaseWorkerHandler):
self
.
runtime
.
shutdown
()
self
.
runtime
.
shutdown
()
os
.
_exit
(
1
)
os
.
_exit
(
1
)
try
:
async
for
res
in
gen
:
async
for
res
in
gen
:
logger
.
debug
(
f
"kv transfer params:
{
res
.
kv_transfer_params
}
"
)
logger
.
debug
(
f
"kv transfer params:
{
res
.
kv_transfer_params
}
"
)
token_ids
=
res
.
outputs
[
0
].
token_ids
if
res
.
outputs
else
[]
token_ids
=
res
.
outputs
[
0
].
token_ids
if
res
.
outputs
else
[]
output
:
Dict
[
str
,
Any
]
=
{
output
:
Dict
[
str
,
Any
]
=
{
"token_ids"
:
list
(
token_ids
),
"token_ids"
:
list
(
token_ids
),
"disaggregated_params"
:
(
"disaggregated_params"
:
(
{
"kv_transfer_params"
:
res
.
kv_transfer_params
}
{
"kv_transfer_params"
:
res
.
kv_transfer_params
}
if
res
.
kv_transfer_params
if
res
.
kv_transfer_params
else
None
else
None
),
),
"completion_usage"
:
BaseWorkerHandler
.
_build_completion_usage
(
"completion_usage"
:
BaseWorkerHandler
.
_build_completion_usage
(
request_output
=
res
,
request_output
=
res
,
embedding_sequence_length
=
embedding_sequence_length
,
embedding_sequence_length
=
embedding_sequence_length
,
),
),
}
}
# Log prefill completion with LoRA info
# Log prefill completion with LoRA info
self
.
_log_with_lora_context
(
self
.
_log_with_lora_context
(
"Prefill completed for request {request_id}{lora_info}: "
"Prefill completed for request {request_id}{lora_info}: "
"generated {token_count} token(s), has_kv_params={has_kv_params}"
,
"generated {token_count} token(s), has_kv_params={has_kv_params}"
,
request_id
,
request_id
,
lora_request
,
lora_request
,
level
=
"info"
if
lora_request
else
"debug"
,
level
=
"info"
if
lora_request
else
"debug"
,
token_count
=
len
(
token_ids
),
token_count
=
len
(
token_ids
),
has_kv_params
=
res
.
kv_transfer_params
is
not
None
,
has_kv_params
=
res
.
kv_transfer_params
is
not
None
,
)
)
yield
output
yield
output
except
asyncio
.
CancelledError
:
# raise the error because we cannot migrate prefill requests
raise
GeneratorExit
(
"Prefill engine was shut down during token generation"
)
from
None
components/src/dynamo/vllm/main.py
View file @
704c1dad
...
@@ -61,7 +61,7 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
...
@@ -61,7 +61,7 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await
asyncio
.
Event
().
wait
()
await
asyncio
.
Event
().
wait
()
async
def
graceful_shutdown
(
runtime
):
async
def
graceful_shutdown
(
runtime
,
shutdown_event
):
"""
"""
Shutdown dynamo distributed runtime.
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
The endpoints will be immediately invalidated so no new requests will be accepted.
...
@@ -69,6 +69,7 @@ async def graceful_shutdown(runtime):
...
@@ -69,6 +69,7 @@ async def graceful_shutdown(runtime):
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
"""
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
logging
.
info
(
"Received shutdown signal, shutting down DistributedRuntime"
)
shutdown_event
.
set
()
runtime
.
shutdown
()
runtime
.
shutdown
()
logging
.
info
(
"DistributedRuntime shutdown complete"
)
logging
.
info
(
"DistributedRuntime shutdown complete"
)
...
@@ -79,6 +80,9 @@ async def worker():
...
@@ -79,6 +80,9 @@ async def worker():
loop
=
asyncio
.
get_running_loop
()
loop
=
asyncio
.
get_running_loop
()
overwrite_args
(
config
)
overwrite_args
(
config
)
# Create shutdown event
shutdown_event
=
asyncio
.
Event
()
# Set DYN_EVENT_PLANE environment variable based on config
# Set DYN_EVENT_PLANE environment variable based on config
os
.
environ
[
"DYN_EVENT_PLANE"
]
=
config
.
event_plane
os
.
environ
[
"DYN_EVENT_PLANE"
]
=
config
.
event_plane
...
@@ -95,7 +99,7 @@ async def worker():
...
@@ -95,7 +99,7 @@ async def worker():
# Set up signal handler for graceful shutdown
# Set up signal handler for graceful shutdown
def
signal_handler
():
def
signal_handler
():
asyncio
.
create_task
(
graceful_shutdown
(
runtime
))
asyncio
.
create_task
(
graceful_shutdown
(
runtime
,
shutdown_event
))
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
for
sig
in
(
signal
.
SIGTERM
,
signal
.
SIGINT
):
loop
.
add_signal_handler
(
sig
,
signal_handler
)
loop
.
add_signal_handler
(
sig
,
signal_handler
)
...
@@ -123,29 +127,29 @@ async def worker():
...
@@ -123,29 +127,29 @@ async def worker():
# Route to appropriate initialization based on config flags
# Route to appropriate initialization based on config flags
if
config
.
vllm_native_encoder_worker
:
if
config
.
vllm_native_encoder_worker
:
await
init_vllm_native_encoder
(
runtime
,
config
)
await
init_vllm_native_encoder
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_vllm_native_encoder completed"
)
logger
.
debug
(
"init_vllm_native_encoder completed"
)
elif
config
.
ec_processor
:
elif
config
.
ec_processor
:
await
init_ec_processor
(
runtime
,
config
)
await
init_ec_processor
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_ec_processor completed"
)
logger
.
debug
(
"init_ec_processor completed"
)
elif
config
.
multimodal_processor
:
elif
config
.
multimodal_processor
:
await
init_multimodal_processor
(
runtime
,
config
)
await
init_multimodal_processor
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_multimodal_processor completed"
)
logger
.
debug
(
"init_multimodal_processor completed"
)
elif
config
.
multimodal_encode_worker
:
elif
config
.
multimodal_encode_worker
:
await
init_multimodal_encode_worker
(
runtime
,
config
)
await
init_multimodal_encode_worker
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_multimodal_encode_worker completed"
)
logger
.
debug
(
"init_multimodal_encode_worker completed"
)
elif
(
elif
(
config
.
multimodal_worker
config
.
multimodal_worker
or
config
.
multimodal_decode_worker
or
config
.
multimodal_decode_worker
or
config
.
multimodal_encode_prefill_worker
or
config
.
multimodal_encode_prefill_worker
):
):
await
init_multimodal_worker
(
runtime
,
config
)
await
init_multimodal_worker
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_multimodal_worker completed"
)
logger
.
debug
(
"init_multimodal_worker completed"
)
elif
config
.
is_prefill_worker
:
elif
config
.
is_prefill_worker
:
await
init_prefill
(
runtime
,
config
)
await
init_prefill
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init_prefill completed"
)
logger
.
debug
(
"init_prefill completed"
)
else
:
else
:
await
init
(
runtime
,
config
)
await
init
(
runtime
,
config
,
shutdown_event
)
logger
.
debug
(
"init completed"
)
logger
.
debug
(
"init completed"
)
logger
.
debug
(
"Worker function completed, exiting..."
)
logger
.
debug
(
"Worker function completed, exiting..."
)
...
@@ -415,7 +419,9 @@ async def register_vllm_model(
...
@@ -415,7 +419,9 @@ async def register_vllm_model(
)
)
async
def
init_prefill
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_prefill
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Instantiate and serve
Instantiate and serve
"""
"""
...
@@ -441,6 +447,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
...
@@ -441,6 +447,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint
=
generate_endpoint
,
generate_endpoint
=
generate_endpoint
,
config
=
config
,
config
=
config
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
shutdown_event
=
shutdown_event
,
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
@@ -527,7 +534,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
...
@@ -527,7 +534,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler
.
cleanup
()
handler
.
cleanup
()
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Instantiate and serve
Instantiate and serve
"""
"""
...
@@ -566,6 +575,7 @@ async def init(runtime: DistributedRuntime, config: Config):
...
@@ -566,6 +575,7 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint
=
generate_endpoint
,
generate_endpoint
=
generate_endpoint
,
config
=
config
,
config
=
config
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
use_vllm_tokenizer
=
config
.
use_vllm_tokenizer
,
shutdown_event
=
shutdown_event
,
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
@@ -699,7 +709,9 @@ def get_engine_cache_info(engine: AsyncLLM):
...
@@ -699,7 +709,9 @@ def get_engine_cache_info(engine: AsyncLLM):
raise
raise
async
def
init_multimodal_processor
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_multimodal_processor
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""Initialize multimodal processor component"""
"""Initialize multimodal processor component"""
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
...
@@ -754,7 +766,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
...
@@ -754,7 +766,9 @@ async def init_multimodal_processor(runtime: DistributedRuntime, config: Config)
handler
.
cleanup
()
handler
.
cleanup
()
async
def
init_multimodal_encode_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_multimodal_encode_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""Initialize multimodal encode worker component"""
"""Initialize multimodal encode worker component"""
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
component
=
runtime
.
namespace
(
config
.
namespace
).
component
(
config
.
component
)
...
@@ -792,7 +806,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
...
@@ -792,7 +806,9 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con
handler
.
cleanup
()
handler
.
cleanup
()
async
def
init_vllm_native_encoder
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_vllm_native_encoder
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Initialize vLLM-native encoder worker component (ECConnector mode).
Initialize vLLM-native encoder worker component (ECConnector mode).
In this mode, vLLM handles encoder execution, caching, and storage automatically.
In this mode, vLLM handles encoder execution, caching, and storage automatically.
...
@@ -853,7 +869,9 @@ async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
...
@@ -853,7 +869,9 @@ async def init_vllm_native_encoder(runtime: DistributedRuntime, config: Config):
handler
.
cleanup
()
handler
.
cleanup
()
async
def
init_ec_processor
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_ec_processor
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Initialize ECConnector processor component.
Initialize ECConnector processor component.
...
@@ -923,7 +941,9 @@ async def init_ec_processor(runtime: DistributedRuntime, config: Config):
...
@@ -923,7 +941,9 @@ async def init_ec_processor(runtime: DistributedRuntime, config: Config):
handler
.
cleanup
()
handler
.
cleanup
()
async
def
init_multimodal_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
):
async
def
init_multimodal_worker
(
runtime
:
DistributedRuntime
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
):
"""
"""
Initialize multimodal worker component.
Initialize multimodal worker component.
...
@@ -983,11 +1003,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
...
@@ -983,11 +1003,16 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config):
# Choose handler based on worker type
# Choose handler based on worker type
if
config
.
multimodal_decode_worker
:
if
config
.
multimodal_decode_worker
:
handler
=
MultimodalDecodeWorkerHandler
(
handler
=
MultimodalDecodeWorkerHandler
(
runtime
,
component
,
engine_client
,
config
runtime
,
component
,
engine_client
,
config
,
shutdown_event
)
)
else
:
else
:
handler
=
MultimodalPDWorkerHandler
(
handler
=
MultimodalPDWorkerHandler
(
runtime
,
component
,
engine_client
,
config
,
decode_worker_client
runtime
,
component
,
engine_client
,
config
,
decode_worker_client
,
shutdown_event
,
)
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
handler
.
add_temp_dir
(
prometheus_temp_dir
)
...
...
components/src/dynamo/vllm/multimodal_handlers/worker_handler.py
View file @
704c1dad
...
@@ -37,6 +37,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -37,6 +37,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
component
,
component
,
engine_client
,
engine_client
,
config
,
config
,
shutdown_event
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
default_sampling_params
=
(
default_sampling_params
=
(
...
@@ -50,6 +51,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
...
@@ -50,6 +51,7 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
engine_client
,
engine_client
,
default_sampling_params
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
enable_multimodal
=
config
.
enable_multimodal
,
shutdown_event
=
shutdown_event
,
)
)
self
.
config
=
config
self
.
config
=
config
...
@@ -117,6 +119,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -117,6 +119,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client
:
AsyncLLM
,
engine_client
:
AsyncLLM
,
config
,
config
,
decode_worker_client
:
Client
=
None
,
decode_worker_client
:
Client
=
None
,
shutdown_event
=
None
,
):
):
# Get default_sampling_params from config
# Get default_sampling_params from config
default_sampling_params
=
(
default_sampling_params
=
(
...
@@ -130,6 +133,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
...
@@ -130,6 +133,7 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
engine_client
,
engine_client
,
default_sampling_params
,
default_sampling_params
,
enable_multimodal
=
config
.
enable_multimodal
,
enable_multimodal
=
config
.
enable_multimodal
,
shutdown_event
=
shutdown_event
,
)
)
self
.
config
=
config
self
.
config
=
config
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment