Unverified Commit 5a7ead2b authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(sglang): add checkpoint/restore support for chrek (#6594)


Co-authored-by: default avatarHannah Zhang <hannahz@nvidia.com>
parent 49eca14b
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Checkpoint/restore (chrek) integration for SGLang workers.
Handles the checkpoint job pod lifecycle:
1. Early exit if a checkpoint already exists (idempotency)
2. Sleep model for CRIU-friendly GPU state
3. Signal readiness for DaemonSet to begin checkpoint
4. Wait for watcher signals from the DaemonSet
5. Wake model after restore
SGLang does not have a native sleep/wake API like vLLM. Instead we use
release_memory_occupation / resume_memory_occupation through the
SGLangCheckpointAdapter, which presents the same sleep()/wake_up()
interface that CheckpointConfig.run_lifecycle expects.
Environment variables:
- DYN_READY_FOR_CHECKPOINT_FILE: Path where this worker writes readiness marker
- DYN_CHECKPOINT_STORAGE_TYPE: Storage backend (pvc, s3, oci) (optional, defaults to pvc)
- DYN_CHECKPOINT_LOCATION: Full checkpoint path (optional when PATH+HASH are provided)
- DYN_CHECKPOINT_PATH + DYN_CHECKPOINT_HASH: PVC base path + hash (used to derive location)
Signals handled in checkpoint mode:
- SIGUSR1: Checkpoint completed, exit process
- SIGCONT: Restore completed, wake model and continue
- SIGKILL (from watcher on failure): Process is terminated immediately (unhandleable)
"""
import asyncio
import logging
import os
import signal
import time
from typing import Optional
import sglang as sgl
logger = logging.getLogger(__name__)
_SLEEP_MODE_LEVEL = 1
# Memory tags to release/resume for CRIU checkpoint/restore.
# All GPU resources must be released so CRIU can snapshot the process cleanly.
_MEMORY_TAGS = ["kv_cache", "weights", "cuda_graph"]
class SGLangCheckpointAdapter:
"""Adapts an sgl.Engine to the sleep/wake_up interface expected by
CheckpointConfig.run_lifecycle (matching vLLM's AsyncLLM API).
sleep(): pause generation -> release GPU memory
wake_up(): resume GPU memory -> continue generation
"""
def __init__(self, engine: sgl.Engine):
self._engine = engine
async def sleep(self, level: int = 1) -> None:
from sglang.srt.managers.io_struct import (
PauseGenerationReqInput,
ReleaseMemoryOccupationReqInput,
)
# Drain in-flight requests before touching GPU memory
await self._engine.tokenizer_manager.pause_generation(PauseGenerationReqInput())
await self._engine.tokenizer_manager.release_memory_occupation(
ReleaseMemoryOccupationReqInput(tags=_MEMORY_TAGS), None
)
async def wake_up(self) -> None:
from sglang.srt.managers.io_struct import (
ContinueGenerationReqInput,
ResumeMemoryOccupationReqInput,
)
await self._engine.tokenizer_manager.resume_memory_occupation(
ResumeMemoryOccupationReqInput(tags=_MEMORY_TAGS), None
)
await self._engine.tokenizer_manager.continue_generation(
ContinueGenerationReqInput()
)
class CheckpointConfig:
"""Parsed and validated checkpoint configuration from environment variables."""
def __init__(self):
self.ready_file = os.environ["DYN_READY_FOR_CHECKPOINT_FILE"]
self.storage_type = os.environ.get("DYN_CHECKPOINT_STORAGE_TYPE", "pvc")
self.location = os.environ.get("DYN_CHECKPOINT_LOCATION", "")
if not self.location:
checkpoint_path = os.environ.get("DYN_CHECKPOINT_PATH", "").rstrip("/")
checkpoint_hash = os.environ.get("DYN_CHECKPOINT_HASH", "")
if checkpoint_path and checkpoint_hash:
self.location = f"{checkpoint_path}/{checkpoint_hash}"
self.is_checkpoint_job = bool(self.location)
self._checkpoint_done = asyncio.Event()
self._restore_done = asyncio.Event()
def checkpoint_exists(self) -> bool:
"""Check if a completed checkpoint already exists (idempotency).
A checkpoint is complete when its directory exists at the base path root
(not under the tmp/ staging area). Directory presence = done.
"""
if self.storage_type != "pvc":
return False
if os.path.isdir(self.location):
logger.info(f"Existing checkpoint found at {self.location}, skipping")
return True
logger.info(f"No checkpoint at {self.location}, creating new one")
return False
async def run_lifecycle(self, engine_client, sleep_level: int) -> bool:
"""Run the full checkpoint lifecycle after the engine is loaded.
1. Put model to sleep (CRIU-friendly GPU state)
2. Write ready file (triggers DaemonSet checkpoint via readiness probe)
3. Wait for watcher signal (checkpoint complete, restore complete, or failure)
4. If restored: wake model and return True (caller proceeds with registration)
5. If checkpoint done: return False (caller should exit)
"""
# Sleep model for checkpoint
logger.info(f"Putting model to sleep (level={sleep_level})")
await engine_client.sleep(level=sleep_level)
# Install signal handlers before writing the ready file so there is no
# window where the DaemonSet can send SIGUSR1/SIGCONT while the default
# signal disposition (terminate) is still in effect.
self._install_signal_handlers()
# Signal readiness
with open(self.ready_file, "w") as f:
f.write("ready")
logger.info(
"Ready for checkpoint. Waiting for watcher signal "
"(SIGUSR1=checkpoint complete, SIGCONT=restore complete)"
)
try:
event = await self._wait_for_watcher_signal()
if event == "restore":
logger.info("Restore signal detected (SIGCONT)")
logger.info("Waking up model after restore")
await engine_client.wake_up()
return True
# SIGUSR1: checkpoint complete
logger.info("Checkpoint completion signal detected (SIGUSR1)")
return False
finally:
self._remove_signal_handlers()
# Remove the ready file so that a restarting pod does not leave a
# stale marker that could trick the DaemonSet into acting on it.
try:
os.unlink(self.ready_file)
except OSError:
pass
def _install_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGUSR1, self._checkpoint_done.set)
# SIGCONT is used as the restore-complete signal. The chrek DaemonSet
# watcher is the only sender, so there is no conflict with POSIX
# job-control semantics in practice.
loop.add_signal_handler(signal.SIGCONT, self._restore_done.set)
# No handler for checkpoint failure: the watcher sends SIGKILL, which
# terminates the process immediately (cannot be caught).
def _remove_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
loop.remove_signal_handler(signal.SIGUSR1)
loop.remove_signal_handler(signal.SIGCONT)
async def _wait_for_watcher_signal(self) -> str:
waiters = {
asyncio.create_task(self._checkpoint_done.wait()): "checkpoint",
asyncio.create_task(self._restore_done.wait()): "restore",
}
try:
done, pending = await asyncio.wait(
waiters.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
winner = done.pop()
await winner
return waiters[winner]
finally:
for task in waiters:
if not task.done():
task.cancel()
async def handle_checkpoint_mode(server_args) -> tuple[bool, Optional[sgl.Engine]]:
"""Single entry point for checkpoint/restore integration.
Must be called BEFORE runtime creation so the engine can be checkpointed
without active NATS/etcd connections.
Returns:
(should_exit, engine) where:
- (True, None): caller should return immediately (checkpoint already
exists, or checkpoint completed successfully).
- (False, None): not in checkpoint mode — cold-start normally.
- (False, engine): restore completed — caller should use this engine.
"""
if "DYN_READY_FOR_CHECKPOINT_FILE" not in os.environ:
return False, None
# Validate: either a full location or path + hash must be set.
if not os.environ.get("DYN_CHECKPOINT_LOCATION"):
path = os.environ.get("DYN_CHECKPOINT_PATH", "")
hash_ = os.environ.get("DYN_CHECKPOINT_HASH", "")
if not path or not hash_:
raise EnvironmentError(
"Checkpoint mode requires either DYN_CHECKPOINT_LOCATION or both "
"DYN_CHECKPOINT_PATH and DYN_CHECKPOINT_HASH"
)
cfg = CheckpointConfig()
checkpoint_exists = cfg.checkpoint_exists()
if cfg.is_checkpoint_job and checkpoint_exists:
return True, None
if not cfg.is_checkpoint_job and not checkpoint_exists:
return False, None
logger.info("Checkpoint mode enabled (watcher-driven signals)")
# Enable memory_saver + weights CPU backup so weights survive CRIU
# (mirrors vLLM's enable_sleep_mode = True)
server_args.enable_memory_saver = True
server_args.enable_weights_cpu_backup = True
start_time = time.time()
engine = sgl.Engine(server_args=server_args)
logger.info(
f"SGLang engine loaded in {time.time() - start_time:.2f}s (checkpoint mode)"
)
adapter = SGLangCheckpointAdapter(engine)
if not await cfg.run_lifecycle(adapter, _SLEEP_MODE_LEVEL):
return True, None
return False, engine
......@@ -5,7 +5,7 @@ import asyncio
import logging
import os
import time
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Optional
import sglang as sgl
......@@ -61,15 +61,21 @@ async def init_decode(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None,
):
server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
start_time = time.time()
engine = sgl.Engine(server_args=server_args)
load_time = time.time() - start_time
# Use pre-created engine if provided (checkpoint/restore mode)
if checkpoint_restore_engine is not None:
engine = checkpoint_restore_engine
load_time = 0.0
else:
start_time = time.time()
engine = sgl.Engine(server_args=server_args)
load_time = time.time() - start_time
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
......@@ -145,13 +151,18 @@ async def init_prefill(
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
run_deferred_handlers: Callable[[], Awaitable[None]] | None = None,
checkpoint_restore_engine: Optional[sgl.Engine] = None,
):
server_args, dynamo_args = config.server_args, config.dynamo_args
if server_args.node_rank >= 1:
os.environ["SGLANG_BLOCK_NONZERO_RANK_CHILDREN"] = "0"
engine = sgl.Engine(server_args=server_args)
# Use pre-created engine if provided (checkpoint/restore mode)
if checkpoint_restore_engine is not None:
engine = checkpoint_restore_engine
else:
engine = sgl.Engine(server_args=server_args)
generate_endpoint = runtime.endpoint(
f"{dynamo_args.namespace}.{dynamo_args.component}.{dynamo_args.endpoint}"
......
......@@ -12,6 +12,7 @@ from dynamo.common.constants import DisaggregationMode
from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.sglang.args import parse_args
from dynamo.sglang.checkpoint_restore import handle_checkpoint_mode
from dynamo.sglang.init_diffusion import (
init_image_diffusion,
init_llm_diffusion,
......@@ -39,6 +40,13 @@ async def worker():
config.server_args.load_format = setup_gms(config.server_args)
# Checkpoint mode: engine must be created BEFORE runtime (no NATS/etcd during CRIU)
should_exit, checkpoint_restore_engine = await handle_checkpoint_mode(
config.server_args
)
if should_exit:
return
dynamo_args = config.dynamo_args
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
......@@ -121,6 +129,7 @@ async def worker():
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
checkpoint_restore_engine=checkpoint_restore_engine,
)
else:
await init_prefill(
......@@ -129,6 +138,7 @@ async def worker():
shutdown_event,
shutdown_endpoints,
run_deferred_handlers,
checkpoint_restore_engine=checkpoint_restore_engine,
)
......
......@@ -147,15 +147,15 @@ async def worker():
# CHECKPOINT MODE: Load engine BEFORE runtime creation
# This allows checkpointing GPU state before runtime connections are established
pre_created_engine = None
checkpoint_restore_engine = None
if checkpoint_cfg is not None:
logger.info("Checkpoint mode enabled (watcher-driven signals)")
# Checkpoint mode requires sleep mode — enable before engine init
config.engine_args.enable_sleep_mode = True
pre_created_engine = setup_vllm_engine(config)
engine_client = pre_created_engine[0]
checkpoint_restore_engine = setup_vllm_engine(config)
engine_client = checkpoint_restore_engine[0]
if not await checkpoint_cfg.run_lifecycle(
engine_client, CHECKPOINT_SLEEP_MODE_LEVEL
......@@ -185,7 +185,7 @@ async def worker():
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
checkpoint_restore_engine=checkpoint_restore_engine,
)
logger.debug("multimodal worker completed")
elif config.omni:
......@@ -193,12 +193,18 @@ async def worker():
logger.debug("init_omni completed")
elif config.disaggregation_mode == DisaggregationMode.PREFILL:
await init_prefill(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
checkpoint_restore_engine=checkpoint_restore_engine,
)
logger.debug("init_prefill completed")
else:
await init(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
runtime,
config,
shutdown_event,
checkpoint_restore_engine=checkpoint_restore_engine,
)
logger.debug("init completed")
......@@ -592,7 +598,7 @@ async def init_prefill(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
pre_created_engine=None,
checkpoint_restore_engine=None,
):
"""
Instantiate and serve
......@@ -605,14 +611,14 @@ async def init_prefill(
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
if checkpoint_restore_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = pre_created_engine
) = checkpoint_restore_engine
else:
(
engine_client,
......@@ -734,7 +740,7 @@ async def init(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
pre_created_engine=None,
checkpoint_restore_engine=None,
):
"""
Instantiate and serve
......@@ -773,14 +779,14 @@ async def init(
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
if checkpoint_restore_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
component_gauges,
) = pre_created_engine
) = checkpoint_restore_engine
# Factory is created after unpack so component_gauges is available
factory = StatLoggerFactory(
endpoint=generate_endpoint,
......
......@@ -103,12 +103,14 @@ class TestCreate:
factory._create_multimodal_worker.assert_called_once() # type: ignore[union-attr]
@pytest.mark.asyncio
async def test_passes_pre_created_engine(self, factory: WorkerFactory) -> None:
async def test_passes_checkpoint_restore_engine(
self, factory: WorkerFactory
) -> None:
config = _make_config(multimodal_worker=True)
runtime = Mock()
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
pre_created_engine: EngineSetupResult = (
checkpoint_restore_engine: EngineSetupResult = (
Mock(),
Mock(),
Mock(),
......@@ -121,7 +123,7 @@ class TestCreate:
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
checkpoint_restore_engine=checkpoint_restore_engine,
)
factory._create_multimodal_worker.assert_called_once_with( # type: ignore[union-attr]
......@@ -129,7 +131,7 @@ class TestCreate:
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
checkpoint_restore_engine=checkpoint_restore_engine,
)
@pytest.mark.asyncio
......
......@@ -58,7 +58,7 @@ class WorkerFactory:
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list,
pre_created_engine: Optional[EngineSetupResult] = None,
checkpoint_restore_engine: Optional[EngineSetupResult] = None,
) -> None:
"""Create the appropriate multimodal worker based on config flags."""
......@@ -72,7 +72,7 @@ class WorkerFactory:
config,
shutdown_event,
shutdown_endpoints,
pre_created_engine=pre_created_engine,
checkpoint_restore_engine=checkpoint_restore_engine,
)
else:
raise ValueError(
......@@ -85,7 +85,7 @@ class WorkerFactory:
config: Config,
shutdown_event: asyncio.Event,
shutdown_endpoints: list, # mutated in place
pre_created_engine: Optional[EngineSetupResult] = None,
checkpoint_restore_engine: Optional[EngineSetupResult] = None,
) -> None:
"""
Initialize multimodal worker component.
......@@ -121,14 +121,14 @@ class WorkerFactory:
[load_lora_endpoint, unload_lora_endpoint, list_loras_endpoint]
)
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
if checkpoint_restore_engine is not None:
(
engine_client,
vllm_config,
_default_sampling_params,
prometheus_temp_dir,
_component_gauges,
) = pre_created_engine
) = checkpoint_restore_engine
else:
(
engine_client,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment