Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
5a7ead2b
Unverified
Commit
5a7ead2b
authored
Feb 25, 2026
by
Schwinn Saereesitthipitak
Committed by
GitHub
Feb 25, 2026
Browse files
feat(sglang): add checkpoint/restore support for chrek (#6594)
Co-authored-by:
Hannah Zhang
<
hannahz@nvidia.com
>
parent
49eca14b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
306 additions
and
26 deletions
+306
-26
components/src/dynamo/sglang/checkpoint_restore.py
components/src/dynamo/sglang/checkpoint_restore.py
+251
-0
components/src/dynamo/sglang/init_llm.py
components/src/dynamo/sglang/init_llm.py
+16
-5
components/src/dynamo/sglang/main.py
components/src/dynamo/sglang/main.py
+10
-0
components/src/dynamo/vllm/main.py
components/src/dynamo/vllm/main.py
+18
-12
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
+6
-4
components/src/dynamo/vllm/worker_factory.py
components/src/dynamo/vllm/worker_factory.py
+5
-5
No files found.
components/src/dynamo/sglang/checkpoint_restore.py
0 → 100644
View file @
5a7ead2b
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Checkpoint/restore (chrek) integration for SGLang workers.
Handles the checkpoint job pod lifecycle:
1. Early exit if a checkpoint already exists (idempotency)
2. Sleep model for CRIU-friendly GPU state
3. Signal readiness for DaemonSet to begin checkpoint
4. Wait for watcher signals from the DaemonSet
5. Wake model after restore
SGLang does not have a native sleep/wake API like vLLM. Instead we use
release_memory_occupation / resume_memory_occupation through the
SGLangCheckpointAdapter, which presents the same sleep()/wake_up()
interface that CheckpointConfig.run_lifecycle expects.
Environment variables:
- DYN_READY_FOR_CHECKPOINT_FILE: Path where this worker writes readiness marker
- DYN_CHECKPOINT_STORAGE_TYPE: Storage backend (pvc, s3, oci) (optional, defaults to pvc)
- DYN_CHECKPOINT_LOCATION: Full checkpoint path (optional when PATH+HASH are provided)
- DYN_CHECKPOINT_PATH + DYN_CHECKPOINT_HASH: PVC base path + hash (used to derive location)
Signals handled in checkpoint mode:
- SIGUSR1: Checkpoint completed, exit process
- SIGCONT: Restore completed, wake model and continue
- SIGKILL (from watcher on failure): Process is terminated immediately (unhandleable)
"""
import
asyncio
import
logging
import
os
import
signal
import
time
from
typing
import
Optional
import
sglang
as
sgl
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Sleep level handed to CheckpointConfig.run_lifecycle by handle_checkpoint_mode.
_SLEEP_MODE_LEVEL = 1

# Memory tags to release before a CRIU checkpoint and resume after restore.
# Every GPU resource has to be released so CRIU can snapshot the process
# cleanly.
_MEMORY_TAGS = [
    "kv_cache",
    "weights",
    "cuda_graph",
]
class SGLangCheckpointAdapter:
    """Expose an ``sgl.Engine`` through the ``sleep()`` / ``wake_up()`` pair
    that ``CheckpointConfig.run_lifecycle`` drives (the same shape as vLLM's
    AsyncLLM API).

    sleep():   pause generation, then release GPU memory
    wake_up(): resume GPU memory, then continue generation
    """

    def __init__(self, engine: sgl.Engine):
        self._engine = engine

    async def sleep(self, level: int = 1) -> None:
        """Pause generation and release every memory tag in ``_MEMORY_TAGS``.

        ``level`` is accepted for interface compatibility only; it is not
        read by this method.
        """
        # Imported lazily, as in the rest of this adapter, so the request
        # types are only resolved when a checkpoint actually happens.
        from sglang.srt.managers.io_struct import (
            PauseGenerationReqInput,
            ReleaseMemoryOccupationReqInput,
        )

        # Drain in-flight requests before touching GPU memory.
        pause_request = PauseGenerationReqInput()
        await self._engine.tokenizer_manager.pause_generation(pause_request)

        release_request = ReleaseMemoryOccupationReqInput(tags=_MEMORY_TAGS)
        await self._engine.tokenizer_manager.release_memory_occupation(
            release_request, None
        )

    async def wake_up(self) -> None:
        """Resume every memory tag in ``_MEMORY_TAGS`` and continue generation."""
        from sglang.srt.managers.io_struct import (
            ContinueGenerationReqInput,
            ResumeMemoryOccupationReqInput,
        )

        # Mirror image of sleep(): memory comes back first, generation second.
        resume_request = ResumeMemoryOccupationReqInput(tags=_MEMORY_TAGS)
        await self._engine.tokenizer_manager.resume_memory_occupation(
            resume_request, None
        )

        continue_request = ContinueGenerationReqInput()
        await self._engine.tokenizer_manager.continue_generation(continue_request)
class CheckpointConfig:
    """Parsed checkpoint configuration from environment variables, plus the
    signal-driven state used while waiting on the checkpoint watcher.

    Attributes:
        ready_file: Path of the readiness marker (DYN_READY_FOR_CHECKPOINT_FILE).
        storage_type: DYN_CHECKPOINT_STORAGE_TYPE, defaulting to "pvc".
        location: DYN_CHECKPOINT_LOCATION, or "<path>/<hash>" derived from
            DYN_CHECKPOINT_PATH and DYN_CHECKPOINT_HASH; "" when neither form
            is available.
        is_checkpoint_job: True when ``location`` is non-empty.
    """

    def __init__(self):
        """Read the configuration from ``os.environ``.

        Raises:
            KeyError: if DYN_READY_FOR_CHECKPOINT_FILE is not set.
        """
        self.ready_file = os.environ["DYN_READY_FOR_CHECKPOINT_FILE"]
        self.storage_type = os.environ.get("DYN_CHECKPOINT_STORAGE_TYPE", "pvc")
        self.location = os.environ.get("DYN_CHECKPOINT_LOCATION", "")
        if not self.location:
            checkpoint_path = os.environ.get("DYN_CHECKPOINT_PATH", "")
            checkpoint_hash = os.environ.get("DYN_CHECKPOINT_HASH", "")
            # Test the raw path, not the slash-stripped one: a base path of
            # "/" strips down to "" and would otherwise be treated as unset,
            # silently disabling checkpoint mode.
            if checkpoint_path and checkpoint_hash:
                base = checkpoint_path.rstrip("/")
                self.location = f"{base}/{checkpoint_hash}"
        self.is_checkpoint_job = bool(self.location)
        # Set from the signal handlers installed by _install_signal_handlers.
        self._checkpoint_done = asyncio.Event()
        self._restore_done = asyncio.Event()

    def checkpoint_exists(self) -> bool:
        """Check if a completed checkpoint already exists (idempotency).

        A checkpoint is complete when its directory exists at the base path
        root (not under the tmp/ staging area). Directory presence = done.
        Only the "pvc" storage type is inspected; every other storage type
        returns False.
        """
        if self.storage_type != "pvc":
            return False
        if os.path.isdir(self.location):
            logger.info("Existing checkpoint found at %s, skipping", self.location)
            return True
        logger.info("No checkpoint at %s, creating new one", self.location)
        return False

    async def run_lifecycle(self, engine_client, sleep_level: int) -> bool:
        """Run the full checkpoint lifecycle after the engine is loaded.

        1. Put model to sleep (CRIU-friendly GPU state)
        2. Write ready file (triggers DaemonSet checkpoint via readiness probe)
        3. Wait for watcher signal (checkpoint complete or restore complete)
        4. If restored: wake model and return True (caller proceeds with
           registration)
        5. If checkpoint done: return False (caller should exit)

        Args:
            engine_client: Object providing awaitable ``sleep(level=...)`` and
                ``wake_up()`` methods.
            sleep_level: Value forwarded to ``engine_client.sleep``.

        Returns:
            True after a restore signal (SIGCONT), False after a
            checkpoint-complete signal (SIGUSR1).

        Raises:
            OSError: if the ready file cannot be written. Signal handlers are
                removed before the error propagates.
        """
        logger.info("Putting model to sleep (level=%s)", sleep_level)
        await engine_client.sleep(level=sleep_level)

        # Install signal handlers before writing the ready file so there is no
        # window where the DaemonSet can send SIGUSR1/SIGCONT while the default
        # signal disposition (terminate) is still in effect.
        self._install_signal_handlers()
        try:
            # The write lives inside the try block so that a failure here
            # still removes the handlers (and any partial marker) below.
            with open(self.ready_file, "w") as f:
                f.write("ready")
            logger.info(
                "Ready for checkpoint. Waiting for watcher signal "
                "(SIGUSR1=checkpoint complete, SIGCONT=restore complete)"
            )

            event = await self._wait_for_watcher_signal()
            if event == "restore":
                logger.info("Restore signal detected (SIGCONT)")
                logger.info("Waking up model after restore")
                await engine_client.wake_up()
                return True

            # SIGUSR1: checkpoint complete
            logger.info("Checkpoint completion signal detected (SIGUSR1)")
            return False
        finally:
            self._remove_signal_handlers()
            # Remove the ready file so that a restarting pod does not leave a
            # stale marker that could trick the DaemonSet into acting on it.
            # Best effort: the file may never have been created.
            try:
                os.unlink(self.ready_file)
            except OSError:
                pass

    def _install_signal_handlers(self) -> None:
        """Register SIGUSR1/SIGCONT handlers on the running event loop."""
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(signal.SIGUSR1, self._checkpoint_done.set)
        # SIGCONT is used as the restore-complete signal. The chrek DaemonSet
        # watcher is the only sender, so there is no conflict with POSIX
        # job-control semantics in practice.
        loop.add_signal_handler(signal.SIGCONT, self._restore_done.set)
        # No handler for checkpoint failure: the watcher sends SIGKILL, which
        # terminates the process immediately (cannot be caught).

    def _remove_signal_handlers(self) -> None:
        """Unregister the handlers added by ``_install_signal_handlers``."""
        loop = asyncio.get_running_loop()
        loop.remove_signal_handler(signal.SIGUSR1)
        loop.remove_signal_handler(signal.SIGCONT)

    async def _wait_for_watcher_signal(self) -> str:
        """Block until one of the two events is set.

        Returns:
            "checkpoint" when ``_checkpoint_done`` fired, "restore" when
            ``_restore_done`` fired.
        """
        waiters = {
            asyncio.create_task(self._checkpoint_done.wait()): "checkpoint",
            asyncio.create_task(self._restore_done.wait()): "restore",
        }
        try:
            done, pending = await asyncio.wait(
                waiters.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in pending:
                task.cancel()
            winner = done.pop()
            await winner
            return waiters[winner]
        finally:
            # Covers cancellation of this coroutine itself: never leave a
            # waiter task running behind us.
            for task in waiters:
                if not task.done():
                    task.cancel()
async def handle_checkpoint_mode(server_args) -> tuple[bool, Optional[sgl.Engine]]:
    """Single entry point for checkpoint/restore integration.

    Must be called BEFORE runtime creation so the engine can be checkpointed
    without active NATS/etcd connections.

    Args:
        server_args: SGLang server arguments. In checkpoint mode
            ``enable_memory_saver`` and ``enable_weights_cpu_backup`` are set
            to True on this object before the engine is built.

    Returns:
        (should_exit, engine) where:
        - (True, None): caller should return immediately (checkpoint already
          exists, or checkpoint completed successfully).
        - (False, None): not in checkpoint mode — cold-start normally.
        - (False, engine): restore completed — caller should use this engine.

    Raises:
        EnvironmentError: checkpoint mode is requested but neither
            DYN_CHECKPOINT_LOCATION nor both DYN_CHECKPOINT_PATH and
            DYN_CHECKPOINT_HASH are set.
    """
    env = os.environ

    # The ready-file variable is the switch that turns checkpoint mode on.
    if "DYN_READY_FOR_CHECKPOINT_FILE" not in env:
        return False, None

    # Validate: either a full location or path + hash must be set.
    if not env.get("DYN_CHECKPOINT_LOCATION"):
        base_path = env.get("DYN_CHECKPOINT_PATH", "")
        digest = env.get("DYN_CHECKPOINT_HASH", "")
        if not (base_path and digest):
            raise EnvironmentError(
                "Checkpoint mode requires either DYN_CHECKPOINT_LOCATION or both "
                "DYN_CHECKPOINT_PATH and DYN_CHECKPOINT_HASH"
            )

    cfg = CheckpointConfig()
    already_done = cfg.checkpoint_exists()
    if cfg.is_checkpoint_job:
        if already_done:
            # Idempotency: nothing left to do for this job.
            return True, None
    elif not already_done:
        return False, None

    logger.info("Checkpoint mode enabled (watcher-driven signals)")

    # Enable memory_saver + weights CPU backup so weights survive CRIU
    # (mirrors vLLM's enable_sleep_mode = True)
    server_args.enable_memory_saver = True
    server_args.enable_weights_cpu_backup = True

    started_at = time.time()
    engine = sgl.Engine(server_args=server_args)
    elapsed = time.time() - started_at
    logger.info(f"SGLang engine loaded in {elapsed:.2f}s (checkpoint mode)")

    restored = await cfg.run_lifecycle(
        SGLangCheckpointAdapter(engine), _SLEEP_MODE_LEVEL
    )
    if restored:
        return False, engine
    return True, None
components/src/dynamo/sglang/init_llm.py
View file @
5a7ead2b
...
@@ -5,7 +5,7 @@ import asyncio
...
@@ -5,7 +5,7 @@ import asyncio
import
logging
import
logging
import
os
import
os
import
time
import
time
from
typing
import
Awaitable
,
Callable
from
typing
import
Awaitable
,
Callable
,
Optional
import
sglang
as
sgl
import
sglang
as
sgl
...
@@ -61,12 +61,18 @@ async def init_decode(
...
@@ -61,12 +61,18 @@ async def init_decode(
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
list
,
shutdown_endpoints
:
list
,
run_deferred_handlers
:
Callable
[[],
Awaitable
[
None
]]
|
None
=
None
,
run_deferred_handlers
:
Callable
[[],
Awaitable
[
None
]]
|
None
=
None
,
checkpoint_restore_engine
:
Optional
[
sgl
.
Engine
]
=
None
,
):
):
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
if
server_args
.
node_rank
>=
1
:
if
server_args
.
node_rank
>=
1
:
os
.
environ
[
"SGLANG_BLOCK_NONZERO_RANK_CHILDREN"
]
=
"0"
os
.
environ
[
"SGLANG_BLOCK_NONZERO_RANK_CHILDREN"
]
=
"0"
# Use pre-created engine if provided (checkpoint/restore mode)
if
checkpoint_restore_engine
is
not
None
:
engine
=
checkpoint_restore_engine
load_time
=
0.0
else
:
start_time
=
time
.
time
()
start_time
=
time
.
time
()
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
load_time
=
time
.
time
()
-
start_time
load_time
=
time
.
time
()
-
start_time
...
@@ -145,12 +151,17 @@ async def init_prefill(
...
@@ -145,12 +151,17 @@ async def init_prefill(
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
list
,
shutdown_endpoints
:
list
,
run_deferred_handlers
:
Callable
[[],
Awaitable
[
None
]]
|
None
=
None
,
run_deferred_handlers
:
Callable
[[],
Awaitable
[
None
]]
|
None
=
None
,
checkpoint_restore_engine
:
Optional
[
sgl
.
Engine
]
=
None
,
):
):
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
server_args
,
dynamo_args
=
config
.
server_args
,
config
.
dynamo_args
if
server_args
.
node_rank
>=
1
:
if
server_args
.
node_rank
>=
1
:
os
.
environ
[
"SGLANG_BLOCK_NONZERO_RANK_CHILDREN"
]
=
"0"
os
.
environ
[
"SGLANG_BLOCK_NONZERO_RANK_CHILDREN"
]
=
"0"
# Use pre-created engine if provided (checkpoint/restore mode)
if
checkpoint_restore_engine
is
not
None
:
engine
=
checkpoint_restore_engine
else
:
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
engine
=
sgl
.
Engine
(
server_args
=
server_args
)
generate_endpoint
=
runtime
.
endpoint
(
generate_endpoint
=
runtime
.
endpoint
(
...
...
components/src/dynamo/sglang/main.py
View file @
5a7ead2b
...
@@ -12,6 +12,7 @@ from dynamo.common.constants import DisaggregationMode
...
@@ -12,6 +12,7 @@ from dynamo.common.constants import DisaggregationMode
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.sglang.args
import
parse_args
from
dynamo.sglang.args
import
parse_args
from
dynamo.sglang.checkpoint_restore
import
handle_checkpoint_mode
from
dynamo.sglang.init_diffusion
import
(
from
dynamo.sglang.init_diffusion
import
(
init_image_diffusion
,
init_image_diffusion
,
init_llm_diffusion
,
init_llm_diffusion
,
...
@@ -39,6 +40,13 @@ async def worker():
...
@@ -39,6 +40,13 @@ async def worker():
config
.
server_args
.
load_format
=
setup_gms
(
config
.
server_args
)
config
.
server_args
.
load_format
=
setup_gms
(
config
.
server_args
)
# Checkpoint mode: engine must be created BEFORE runtime (no NATS/etcd during CRIU)
should_exit
,
checkpoint_restore_engine
=
await
handle_checkpoint_mode
(
config
.
server_args
)
if
should_exit
:
return
dynamo_args
=
config
.
dynamo_args
dynamo_args
=
config
.
dynamo_args
shutdown_event
=
asyncio
.
Event
()
shutdown_event
=
asyncio
.
Event
()
shutdown_endpoints
:
list
=
[]
shutdown_endpoints
:
list
=
[]
...
@@ -121,6 +129,7 @@ async def worker():
...
@@ -121,6 +129,7 @@ async def worker():
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
run_deferred_handlers
,
run_deferred_handlers
,
checkpoint_restore_engine
=
checkpoint_restore_engine
,
)
)
else
:
else
:
await
init_prefill
(
await
init_prefill
(
...
@@ -129,6 +138,7 @@ async def worker():
...
@@ -129,6 +138,7 @@ async def worker():
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
run_deferred_handlers
,
run_deferred_handlers
,
checkpoint_restore_engine
=
checkpoint_restore_engine
,
)
)
...
...
components/src/dynamo/vllm/main.py
View file @
5a7ead2b
...
@@ -147,15 +147,15 @@ async def worker():
...
@@ -147,15 +147,15 @@ async def worker():
# CHECKPOINT MODE: Load engine BEFORE runtime creation
# CHECKPOINT MODE: Load engine BEFORE runtime creation
# This allows checkpointing GPU state before runtime connections are established
# This allows checkpointing GPU state before runtime connections are established
pre_created
_engine
=
None
checkpoint_restore
_engine
=
None
if
checkpoint_cfg
is
not
None
:
if
checkpoint_cfg
is
not
None
:
logger
.
info
(
"Checkpoint mode enabled (watcher-driven signals)"
)
logger
.
info
(
"Checkpoint mode enabled (watcher-driven signals)"
)
# Checkpoint mode requires sleep mode — enable before engine init
# Checkpoint mode requires sleep mode — enable before engine init
config
.
engine_args
.
enable_sleep_mode
=
True
config
.
engine_args
.
enable_sleep_mode
=
True
pre_created
_engine
=
setup_vllm_engine
(
config
)
checkpoint_restore
_engine
=
setup_vllm_engine
(
config
)
engine_client
=
pre_created
_engine
[
0
]
engine_client
=
checkpoint_restore
_engine
[
0
]
if
not
await
checkpoint_cfg
.
run_lifecycle
(
if
not
await
checkpoint_cfg
.
run_lifecycle
(
engine_client
,
CHECKPOINT_SLEEP_MODE_LEVEL
engine_client
,
CHECKPOINT_SLEEP_MODE_LEVEL
...
@@ -185,7 +185,7 @@ async def worker():
...
@@ -185,7 +185,7 @@ async def worker():
config
,
config
,
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
pre_created_engine
=
pre_created
_engine
,
checkpoint_restore_engine
=
checkpoint_restore
_engine
,
)
)
logger
.
debug
(
"multimodal worker completed"
)
logger
.
debug
(
"multimodal worker completed"
)
elif
config
.
omni
:
elif
config
.
omni
:
...
@@ -193,12 +193,18 @@ async def worker():
...
@@ -193,12 +193,18 @@ async def worker():
logger
.
debug
(
"init_omni completed"
)
logger
.
debug
(
"init_omni completed"
)
elif
config
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
elif
config
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
await
init_prefill
(
await
init_prefill
(
runtime
,
config
,
shutdown_event
,
pre_created_engine
=
pre_created_engine
runtime
,
config
,
shutdown_event
,
checkpoint_restore_engine
=
checkpoint_restore_engine
,
)
)
logger
.
debug
(
"init_prefill completed"
)
logger
.
debug
(
"init_prefill completed"
)
else
:
else
:
await
init
(
await
init
(
runtime
,
config
,
shutdown_event
,
pre_created_engine
=
pre_created_engine
runtime
,
config
,
shutdown_event
,
checkpoint_restore_engine
=
checkpoint_restore_engine
,
)
)
logger
.
debug
(
"init completed"
)
logger
.
debug
(
"init completed"
)
...
@@ -592,7 +598,7 @@ async def init_prefill(
...
@@ -592,7 +598,7 @@ async def init_prefill(
runtime
:
DistributedRuntime
,
runtime
:
DistributedRuntime
,
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
pre_created
_engine
=
None
,
checkpoint_restore
_engine
=
None
,
):
):
"""
"""
Instantiate and serve
Instantiate and serve
...
@@ -605,14 +611,14 @@ async def init_prefill(
...
@@ -605,14 +611,14 @@ async def init_prefill(
)
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if
pre_created
_engine
is
not
None
:
if
checkpoint_restore
_engine
is
not
None
:
(
(
engine_client
,
engine_client
,
vllm_config
,
vllm_config
,
default_sampling_params
,
default_sampling_params
,
prometheus_temp_dir
,
prometheus_temp_dir
,
_component_gauges
,
_component_gauges
,
)
=
pre_created
_engine
)
=
checkpoint_restore
_engine
else
:
else
:
(
(
engine_client
,
engine_client
,
...
@@ -734,7 +740,7 @@ async def init(
...
@@ -734,7 +740,7 @@ async def init(
runtime
:
DistributedRuntime
,
runtime
:
DistributedRuntime
,
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
pre_created
_engine
=
None
,
checkpoint_restore
_engine
=
None
,
):
):
"""
"""
Instantiate and serve
Instantiate and serve
...
@@ -773,14 +779,14 @@ async def init(
...
@@ -773,14 +779,14 @@ async def init(
)
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if
pre_created
_engine
is
not
None
:
if
checkpoint_restore
_engine
is
not
None
:
(
(
engine_client
,
engine_client
,
vllm_config
,
vllm_config
,
default_sampling_params
,
default_sampling_params
,
prometheus_temp_dir
,
prometheus_temp_dir
,
component_gauges
,
component_gauges
,
)
=
pre_created
_engine
)
=
checkpoint_restore
_engine
# Factory is created after unpack so component_gauges is available
# Factory is created after unpack so component_gauges is available
factory
=
StatLoggerFactory
(
factory
=
StatLoggerFactory
(
endpoint
=
generate_endpoint
,
endpoint
=
generate_endpoint
,
...
...
components/src/dynamo/vllm/tests/test_vllm_worker_factory.py
View file @
5a7ead2b
...
@@ -103,12 +103,14 @@ class TestCreate:
...
@@ -103,12 +103,14 @@ class TestCreate:
factory
.
_create_multimodal_worker
.
assert_called_once
()
# type: ignore[union-attr]
factory
.
_create_multimodal_worker
.
assert_called_once
()
# type: ignore[union-attr]
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
async
def
test_passes_pre_created_engine
(
self
,
factory
:
WorkerFactory
)
->
None
:
async
def
test_passes_checkpoint_restore_engine
(
self
,
factory
:
WorkerFactory
)
->
None
:
config
=
_make_config
(
multimodal_worker
=
True
)
config
=
_make_config
(
multimodal_worker
=
True
)
runtime
=
Mock
()
runtime
=
Mock
()
shutdown_event
=
asyncio
.
Event
()
shutdown_event
=
asyncio
.
Event
()
shutdown_endpoints
:
list
=
[]
shutdown_endpoints
:
list
=
[]
pre_created
_engine
:
EngineSetupResult
=
(
checkpoint_restore
_engine
:
EngineSetupResult
=
(
Mock
(),
Mock
(),
Mock
(),
Mock
(),
Mock
(),
Mock
(),
...
@@ -121,7 +123,7 @@ class TestCreate:
...
@@ -121,7 +123,7 @@ class TestCreate:
config
,
config
,
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
pre_created_engine
=
pre_created
_engine
,
checkpoint_restore_engine
=
checkpoint_restore
_engine
,
)
)
factory
.
_create_multimodal_worker
.
assert_called_once_with
(
# type: ignore[union-attr]
factory
.
_create_multimodal_worker
.
assert_called_once_with
(
# type: ignore[union-attr]
...
@@ -129,7 +131,7 @@ class TestCreate:
...
@@ -129,7 +131,7 @@ class TestCreate:
config
,
config
,
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
pre_created_engine
=
pre_created
_engine
,
checkpoint_restore_engine
=
checkpoint_restore
_engine
,
)
)
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
asyncio
...
...
components/src/dynamo/vllm/worker_factory.py
View file @
5a7ead2b
...
@@ -58,7 +58,7 @@ class WorkerFactory:
...
@@ -58,7 +58,7 @@ class WorkerFactory:
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
list
,
shutdown_endpoints
:
list
,
pre_created
_engine
:
Optional
[
EngineSetupResult
]
=
None
,
checkpoint_restore
_engine
:
Optional
[
EngineSetupResult
]
=
None
,
)
->
None
:
)
->
None
:
"""Create the appropriate multimodal worker based on config flags."""
"""Create the appropriate multimodal worker based on config flags."""
...
@@ -72,7 +72,7 @@ class WorkerFactory:
...
@@ -72,7 +72,7 @@ class WorkerFactory:
config
,
config
,
shutdown_event
,
shutdown_event
,
shutdown_endpoints
,
shutdown_endpoints
,
pre_created_engine
=
pre_created
_engine
,
checkpoint_restore_engine
=
checkpoint_restore
_engine
,
)
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
...
@@ -85,7 +85,7 @@ class WorkerFactory:
...
@@ -85,7 +85,7 @@ class WorkerFactory:
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
list
,
# mutated in place
shutdown_endpoints
:
list
,
# mutated in place
pre_created
_engine
:
Optional
[
EngineSetupResult
]
=
None
,
checkpoint_restore
_engine
:
Optional
[
EngineSetupResult
]
=
None
,
)
->
None
:
)
->
None
:
"""
"""
Initialize multimodal worker component.
Initialize multimodal worker component.
...
@@ -121,14 +121,14 @@ class WorkerFactory:
...
@@ -121,14 +121,14 @@ class WorkerFactory:
[
load_lora_endpoint
,
unload_lora_endpoint
,
list_loras_endpoint
]
[
load_lora_endpoint
,
unload_lora_endpoint
,
list_loras_endpoint
]
)
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if
pre_created
_engine
is
not
None
:
if
checkpoint_restore
_engine
is
not
None
:
(
(
engine_client
,
engine_client
,
vllm_config
,
vllm_config
,
_default_sampling_params
,
_default_sampling_params
,
prometheus_temp_dir
,
prometheus_temp_dir
,
_component_gauges
,
_component_gauges
,
)
=
pre_created
_engine
)
=
checkpoint_restore
_engine
else
:
else
:
(
(
engine_client
,
engine_client
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment