Unverified commit ca15a79d, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

fix(snapshot): replace PID-1 SIGUSR1/SIGCONT contract with file sentinels (#8403)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
parent 8fba4f56
......@@ -6,8 +6,8 @@
import asyncio
import logging
import os
import signal
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar
from dynamo.common.utils.namespace import get_worker_namespace
......@@ -25,23 +25,31 @@ KUBERNETES_OPTIONAL_PODINFO_FILES = {
}
EngineT = TypeVar("EngineT")
# Must match snapshotprotocol.{SnapshotCompleteFile,RestoreCompleteFile,ReadyForCheckpointFile}.
SNAPSHOT_COMPLETE_FILE = "snapshot-complete"
RESTORE_COMPLETE_FILE = "restore-complete"
READY_FOR_CHECKPOINT_FILE = "ready-for-checkpoint"
# Poll interval for the snapshot-control directory. Checkpoint and restore
# latencies are seconds, so 100ms is negligible overhead.
_SENTINEL_POLL_INTERVAL_SEC = 0.1
class CheckpointConfig:
"""Parsed checkpoint configuration plus the watcher-driven lifecycle."""
"""Parsed checkpoint configuration plus the sentinel-driven lifecycle."""
def __init__(self, ready_file: str):
self.ready_file = ready_file
self._checkpoint_done = asyncio.Event()
self._restore_done = asyncio.Event()
def __init__(self, control_dir: str):
self.control_dir = control_dir
self.ready_file = os.path.join(control_dir, READY_FOR_CHECKPOINT_FILE)
@classmethod
def from_env(cls) -> "CheckpointConfig | None":
ready_file = os.environ.get("DYN_READY_FOR_CHECKPOINT_FILE")
if not ready_file:
control_dir = os.environ.get("DYN_SNAPSHOT_CONTROL_DIR")
if not control_dir:
return None
configure_checkpoint_transport_env()
return cls(ready_file=ready_file)
return cls(control_dir=control_dir)
async def run_lifecycle(
self,
......@@ -51,65 +59,53 @@ class CheckpointConfig:
logger.info("Quiescing model")
await quiesce_controller.quiesce(*quiesce_args)
self._install_signal_handlers()
try:
with open(self.ready_file, "w", encoding="utf-8") as ready_file:
ready_file.write("ready")
except Exception:
self._remove_signal_handlers()
raise
logger.info(
"Ready for checkpoint. Waiting for watcher signal "
"(SIGUSR1=checkpoint complete, SIGCONT=restore complete)"
"Ready for checkpoint. Polling for sentinel in %s "
"(snapshot-complete or restore-complete)",
self.control_dir,
)
try:
event = await self._wait_for_watcher_signal()
event = await self._wait_for_sentinel()
finally:
self._cleanup_ready_and_sentinels()
if event == "restore":
logger.info("Restore signal detected (SIGCONT)")
logger.info("Restore sentinel detected")
logger.info("Resuming model after restore")
await quiesce_controller.resume()
quiesce_controller.mark_resumed()
return True
logger.info("Checkpoint completion signal detected (SIGUSR1)")
logger.info("Snapshot completion sentinel detected")
return False
finally:
self._remove_signal_handlers()
try:
os.unlink(self.ready_file)
except OSError:
pass
def _install_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGUSR1, self._checkpoint_done.set)
loop.add_signal_handler(signal.SIGCONT, self._restore_done.set)
def _remove_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
loop.remove_signal_handler(signal.SIGUSR1)
loop.remove_signal_handler(signal.SIGCONT)
async def _wait_for_watcher_signal(self) -> str:
waiters = {
asyncio.create_task(self._checkpoint_done.wait()): "checkpoint",
asyncio.create_task(self._restore_done.wait()): "restore",
}
async def _wait_for_sentinel(self) -> str:
    """Block until a completion sentinel shows up in the control directory.

    Returns:
        "checkpoint" if the snapshot-complete file exists, or "restore" if
        the restore-complete file exists. When both are present on the same
        pass, the snapshot sentinel takes precedence.

    The directory is re-checked every ``_SENTINEL_POLL_INTERVAL_SEC`` seconds
    and there is no timeout: the coroutine only finishes once one of the two
    files is found (or it is cancelled).
    """
    control_root = Path(self.control_dir)
    # Ordered table: earlier entries win when several sentinels exist.
    outcomes = (
        (control_root / SNAPSHOT_COMPLETE_FILE, "checkpoint"),
        (control_root / RESTORE_COMPLETE_FILE, "restore"),
    )
    while True:
        for sentinel, outcome in outcomes:
            if sentinel.exists():
                return outcome
        # Yield to the event loop between polls instead of spinning.
        await asyncio.sleep(_SENTINEL_POLL_INTERVAL_SEC)
def _cleanup_ready_and_sentinels(self) -> None:
for name in (
READY_FOR_CHECKPOINT_FILE,
SNAPSHOT_COMPLETE_FILE,
RESTORE_COMPLETE_FILE,
):
path = os.path.join(self.control_dir, name)
try:
done, pending = await asyncio.wait(
waiters.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
winner = done.pop()
await winner
return waiters[winner]
finally:
for task in waiters:
if not task.done():
task.cancel()
os.unlink(path)
except FileNotFoundError:
pass
except OSError:
logger.exception("Failed to clean up %s at %s", name, path)
def configure_checkpoint_transport_env() -> None:
......
......@@ -164,7 +164,6 @@ The chart includes built-in validation to prevent all operator conflicts:
| dynamo-operator.webhook.certManager.certificate.rootCA.duration | string | `"87600h"` | Duration for the root CA certificate (e.g., "87600h" for 10 years). The root CA typically has a much longer lifetime than the leaf certificates it signs. |
| dynamo-operator.webhook.certManager.certificate.rootCA.renewBefore | string | `"720h"` | Time before root CA expiration to trigger renewal (e.g., "720h" for 30 days). Renewing a CA can be disruptive as all signed certificates must be reissued. |
| dynamo-operator.checkpoint.enabled | bool | `false` | Whether to enable checkpoint/restore functionality |
| dynamo-operator.checkpoint.readyForCheckpointFilePath | string | `"/tmp/ready-for-checkpoint"` | Path written by worker when model is loaded and ready for checkpointing |
| grove.tolerations | list | `[]` | Node tolerations for Grove pods |
| grove.affinity | object | `{}` | Affinity for Grove pods |
| kai-scheduler.global.tolerations | list | `[]` | Node tolerations for kai-scheduler pods |
......
......@@ -132,9 +132,6 @@ data:
{{- if .Values.checkpoint.enabled }}
checkpoint:
enabled: true
{{- if ne (.Values.checkpoint.readyForCheckpointFilePath | toString) "/tmp/ready-for-checkpoint" }}
readyForCheckpointFilePath: {{ .Values.checkpoint.readyForCheckpointFilePath | quote }}
{{- end }}
{{- end }}
{{- if and .Values.discoveryBackend (ne (.Values.discoveryBackend | toString) "kubernetes") }}
discovery:
......
......@@ -139,10 +139,6 @@ checkpoint:
# Enable checkpoint/restore functionality
enabled: false
# Path written by worker when model is loaded and ready for checkpointing
# Must match the path expected by checkpoint-enabled runtime images
readyForCheckpointFilePath: "/tmp/ready-for-checkpoint"
# Webhook configuration
webhook:
# Certificate configuration
......
......@@ -228,9 +228,6 @@ dynamo-operator:
# -- Whether to enable checkpoint/restore functionality
enabled: false
# -- Path written by worker when model is loaded and ready for checkpointing
readyForCheckpointFilePath: "/tmp/ready-for-checkpoint"
# Grove component - distributed inference orchestration
# Installation is controlled by global.grove.install above.
grove:
......
......@@ -85,11 +85,6 @@ func SetDefaultsOperatorConfiguration(obj *OperatorConfiguration) {
obj.GPU.DiscoveryEnabled = ptr.To(true)
}
// Checkpoint defaults
if obj.Checkpoint.ReadyForCheckpointFilePath == "" {
obj.Checkpoint.ReadyForCheckpointFilePath = "/tmp/ready-for-checkpoint"
}
// Logging defaults
if obj.Logging.Level == "" {
obj.Logging.Level = "info"
......
......@@ -245,9 +245,6 @@ type MPIConfiguration struct {
type CheckpointConfiguration struct {
// Enabled indicates if checkpoint functionality is enabled
Enabled bool `json:"enabled"`
// ReadyForCheckpointFilePath signals model readiness for checkpoint jobs
// +kubebuilder:default="/tmp/ready-for-checkpoint"
ReadyForCheckpointFilePath string `json:"readyForCheckpointFilePath"`
// Deprecated: Storage is retained for compatibility and ignored by the
// current snapshot flow. Snapshot storage is discovered from the
// snapshot-agent DaemonSet instead.
......
......@@ -68,7 +68,7 @@ func testScheme() *runtime.Scheme {
}
func testInfo() *CheckpointInfo {
return &CheckpointInfo{Enabled: true, Hash: testHash}
return &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash}
}
func testSnapshotAgentDaemonSet() *appsv1.DaemonSet {
......@@ -184,6 +184,27 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
// --- InjectCheckpointIntoPodSpec tests ---
func TestInjectCheckpointIntoPodSpec(t *testing.T) {
t.Run("not ready checkpoint leaves pod spec untouched", func(t *testing.T) {
podSpec := testPodSpec()
originalCmd := append([]string(nil), podSpec.Containers[0].Command...)
originalArgs := append([]string(nil), podSpec.Containers[0].Args...)
info := &CheckpointInfo{Enabled: true, Ready: false, Hash: testHash}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
assert.Equal(t, originalCmd, podSpec.Containers[0].Command)
assert.Equal(t, originalArgs, podSpec.Containers[0].Args)
for _, volume := range podSpec.Volumes {
assert.NotEqual(t, snapshotprotocol.SnapshotControlVolumeName, volume.Name)
assert.NotEqual(t, snapshotprotocol.CheckpointVolumeName, volume.Name)
assert.NotEqual(t, consts.PodInfoVolumeName, volume.Name)
}
for _, env := range podSpec.Containers[0].Env {
assert.NotEqual(t, snapshotprotocol.SnapshotControlDirEnv, env.Name)
}
})
t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
podSpec := testPodSpec()
info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
......@@ -279,7 +300,7 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
reader client.Reader
errMsg string
}{
{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true, Ready: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
{"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container named"},
{"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"},
} {
......
......@@ -58,7 +58,15 @@ func InjectCheckpointIntoPodSpec(
podSpec *corev1.PodSpec,
checkpointInfo *CheckpointInfo,
) error {
if checkpointInfo == nil || !checkpointInfo.Enabled {
// Only mutate the worker pod spec once the checkpoint is Ready. Before
// the checkpoint exists, the worker must cold-start normally without
// the snapshot-control volume, DYN_SNAPSHOT_CONTROL_DIR, checkpoint PVC
// mount, or localhost seccomp profile — otherwise the Python worker
// enters checkpoint mode on env-var presence and sits quiesced waiting
// for a sentinel that only the checkpoint Job and restore-target path
// produce. The checkpoint Job itself is built separately through
// buildCheckpointJob + NewCheckpointJob and does get these.
if checkpointInfo == nil || !checkpointInfo.Enabled || !checkpointInfo.Ready {
return nil
}
......
......@@ -140,9 +140,6 @@ const (
ResourceStateNotReady = "not_ready"
ResourceStateUnknown = "unknown"
// Environment variables injected into pods
EnvReadyForCheckpointFile = "DYN_READY_FOR_CHECKPOINT_FILE" // Ready-for-checkpoint file path — checkpoint job pods
// Pod identity (Downward API) ---
// After CRIU restore, env vars contain stale values from the checkpoint pod.
// The Downward API files at /etc/podinfo always reflect the current pod.
......
......@@ -6,7 +6,6 @@ package controller
import (
"context"
"fmt"
"strings"
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
......@@ -91,25 +90,11 @@ func buildCheckpointJob(
mainContainer.Env,
)
dynamo.AddStandardEnvVars(mainContainer, config)
mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{
Name: consts.EnvReadyForCheckpointFile,
Value: config.Checkpoint.ReadyForCheckpointFilePath,
})
mainContainer.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"cat", config.Checkpoint.ReadyForCheckpointFilePath},
},
},
InitialDelaySeconds: 15,
PeriodSeconds: 2,
}
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
// The snapshot agent sends SIGUSR1 to PID 1 of the main container after
checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
// NewCheckpointJob handles control volume + readiness probe from the
// snapshot contract.
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
claimTemplateName := dra.ResourceClaimTemplateName("checkpoint-"+hash, "worker")
......@@ -129,9 +114,6 @@ func buildCheckpointJob(
if err := checkpoint.EnsureGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage); err != nil {
return nil, err
}
// Re-acquire pointer: append in EnsureGMSCheckpointJobSidecars may
// have reallocated the Containers slice.
mainContainer = &podTemplate.Spec.Containers[0]
}
activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds
......@@ -153,16 +135,6 @@ func buildCheckpointJob(
}
wrapLaunchJob := tp*pp > 1
// For single-GPU jobs (no cuda-checkpoint wrapper), unwrap /bin/sh -c so
// the actual process is PID 1 and receives SIGUSR1 from the snapshot agent.
if !wrapLaunchJob && len(mainContainer.Command) >= 2 &&
mainContainer.Command[len(mainContainer.Command)-1] == "-c" &&
len(mainContainer.Args) == 1 {
parts := strings.Fields(mainContainer.Args[0])
mainContainer.Command = parts[:1]
mainContainer.Args = parts[1:]
}
ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds
return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{
......
......@@ -79,7 +79,6 @@ func checkpointTestConfig() *configv1alpha1.OperatorConfiguration {
return &configv1alpha1.OperatorConfiguration{
Checkpoint: configv1alpha1.CheckpointConfiguration{
Enabled: true,
ReadyForCheckpointFilePath: "/tmp/ready-for-checkpoint",
},
}
}
......@@ -168,7 +167,7 @@ func TestBuildCheckpointJob(t *testing.T) {
for _, e := range main.Env {
envMap[e.Name] = e.Value
}
assert.Equal(t, "/tmp/ready-for-checkpoint", envMap[consts.EnvReadyForCheckpointFile])
assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, envMap[snapshotprotocol.SnapshotControlDirEnv])
assert.Equal(t, "manual-checkpoint", envMap[consts.DynamoNamespaceEnvVar])
assert.Equal(t, consts.ComponentTypeWorker, envMap[consts.DynamoComponentEnvVar])
assert.Equal(t, "worker-1234", envMap[consts.DynamoNamespaceWorkerSuffixEnvVar])
......@@ -201,7 +200,7 @@ func TestBuildCheckpointJob(t *testing.T) {
// Probes: readiness set, liveness/startup cleared
require.NotNil(t, main.ReadinessProbe)
assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Equal(t, []string{"cat", "/snapshot-control/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Nil(t, main.LivenessProbe)
assert.Nil(t, main.StartupProbe)
......@@ -212,6 +211,7 @@ func TestBuildCheckpointJob(t *testing.T) {
}
assert.False(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[consts.PodInfoVolumeName])
assert.True(t, volNames[snapshotprotocol.SnapshotControlVolumeName])
mountPaths := make(map[string]string)
for _, m := range main.VolumeMounts {
......@@ -221,6 +221,7 @@ func TestBuildCheckpointJob(t *testing.T) {
assert.False(t, hasCheckpointMount)
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory])
assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, mountPaths[snapshotprotocol.SnapshotControlVolumeName])
foundSharedMemoryVolume := false
for _, v := range podSpec.Volumes {
......@@ -301,7 +302,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
assert.Equal(t, []string{"cuda-checkpoint"}, main.Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args)
require.NotNil(t, main.ReadinessProbe)
assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Equal(t, []string{"cat", "/snapshot-control/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Nil(t, main.LivenessProbe)
assert.Nil(t, main.StartupProbe)
......@@ -309,7 +310,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
for _, env := range main.Env {
mainEnv[env.Name] = env.Value
}
assert.Equal(t, "/tmp/ready-for-checkpoint", mainEnv[consts.EnvReadyForCheckpointFile])
assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, mainEnv[snapshotprotocol.SnapshotControlDirEnv])
assert.Equal(t, "secret", mainEnv["HF_TOKEN"])
sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar")
......@@ -319,7 +320,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
assert.Nil(t, sidecar.LivenessProbe)
assert.Nil(t, sidecar.StartupProbe)
for _, env := range sidecar.Env {
assert.NotEqual(t, consts.EnvReadyForCheckpointFile, env.Name)
assert.NotEqual(t, snapshotprotocol.SnapshotControlDirEnv, env.Name)
}
}
......@@ -374,7 +375,7 @@ func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) {
}
assert.True(t, volNames[gms.SharedVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[snapshotprotocol.SnapshotControlVolumeName])
mainMounts := map[string]string{}
for _, m := range main.VolumeMounts {
......
......@@ -20,11 +20,13 @@ const (
ServerContainerName = "gms-server"
// SharedVolumeName is the emptyDir volume shared between the GMS server
// sidecar and the main workload container for UDS sockets.
SharedVolumeName = "gms-shared"
// sidecar and the main workload container for UDS sockets. The name
// disambiguates it from the snapshot-control volume, which carries
// checkpoint/restore lifecycle sentinels written by the snapshot agent.
SharedVolumeName = "gms-intrapod-control"
// SharedMountPath is the mount path for the shared GMS socket directory.
SharedMountPath = "/shared"
// SharedMountPath is the mount path for the GMS intra-pod IPC directory.
SharedMountPath = "/gms-intrapod-control"
// EnvSocketDir is the environment variable name for the GMS UDS socket directory.
EnvSocketDir = "GMS_SOCKET_DIR"
......
......@@ -317,7 +317,9 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
// 1. Hold and renew the checkpoint lease
// 2. Resolve the container ID and host PID
// 3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
// 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
// 4. Write a snapshot-complete sentinel into the pod's snapshot-control
// volume on success (observed by the workload by polling that directory),
// or SIGKILL on failure (unrecoverable CUDA-locked process)
// 5. Mark job as completed or failed
func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointID, checkpointLocation, podKey string, startedAt time.Time) error {
releasePodOnExit := true
......@@ -438,16 +440,21 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
return nil
}
// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID))
if err := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
log.Error(err, "Failed to signal checkpoint completion to runtime process")
// Step 2: Sentinel on success. Workload observes via polling on the
// snapshot-control volume; containerPID is a PID inside the container's
// mount namespace, which is all the /host/proc/<pid>/root write path
// requires. The Succeeded event is emitted only after the sentinel has
// been written so a sentinel-write failure doesn't produce conflicting
// Succeeded+Failed events for the same operation.
if err := snapshotruntime.WriteControlSentinel(containerPID, snapshotprotocol.SnapshotCompleteFile); err != nil {
log.Error(err, "Failed to write snapshot-complete sentinel")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr
}
return nil
}
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID))
if err := setCheckpointStatus(snapshotprotocol.CheckpointStatusCompleted); err != nil {
return err
......@@ -458,7 +465,8 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
// runRestore runs the full restore workflow for a pod:
// 1. Mark the current container instance as in_progress
// 2. Call executor.Restore (inspect placeholder → nsrestore inside namespace)
// 3. SIGCONT the restored process to wake it up
// 3. Write a restore-complete sentinel into the pod's snapshot-control
// volume to wake the workload (observed by polling that directory)
// 4. Wait for the pod to become Ready
// 5. Mark the container instance as completed
func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey string, startedAt time.Time) error {
......@@ -506,13 +514,14 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
ContainerName: containerName,
Clientset: w.clientset,
}
restoredPID, err := executor.Restore(restoreCtx, w.containerd, log, req)
placeholderHostPID, err := executor.Restore(restoreCtx, w.containerd, log, req)
if err != nil {
log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
// Re-resolve: executor.Restore may have failed before resolving the placeholder.
placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if pidErr != nil {
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
......@@ -523,31 +532,24 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
return nil
}
// Step 2: SIGCONT the restored process via PID namespace
placeholderHostPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
}
if err := snapshotruntime.SendSignalViaPIDNamespace(restoreCtx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process")
// Step 2: Write restore-complete sentinel. placeholderHostPID came back
// from executor.Restore — any PID inside the container's mount namespace
// reaches /snapshot-control via /host/proc/<pid>/root.
if err := snapshotruntime.WriteControlSentinel(placeholderHostPID, snapshotprotocol.RestoreCompleteFile); err != nil {
log.Error(err, "Failed to write restore-complete sentinel")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore sentinel failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore sentinel failure")
}
return fmt.Errorf("failed to signal restored runtime process: %w", err)
return fmt.Errorf("failed to write restore-complete sentinel: %w", err)
}
// Step 3: Wait for the pod to become Ready
if err := waitForPodReady(restoreCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed")
log.Error(err, "Restore post-sentinel readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
......@@ -555,7 +557,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
}
return fmt.Errorf("restore post-signal readiness check failed: %w", err)
return fmt.Errorf("restore post-sentinel readiness check failed: %w", err)
}
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointID))
......
......@@ -39,6 +39,10 @@ type RestoreRequest struct {
// The restored process's namespace-relative PID is only logged; it is not the return value.
// The DaemonSet side inspects the placeholder and launches nsrestore,
// which handles rootfs application, CRIU restore, and CUDA restore inside the namespace.
//
// Returns the placeholder container's host PID so callers can reach into the
// container's mount namespace (e.g. to write sentinels under /snapshot-control)
// without re-resolving via containerd.
func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) {
restoreStart := time.Now()
log.Info("=== Starting external restore ===",
......@@ -90,11 +94,12 @@ func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req
log.Info("=== External restore completed ===",
"restored_pid", result.RestoredPID,
"placeholder_host_pid", snap.PlaceholderPID,
"validation_duration", time.Since(validationStart),
"total_duration", time.Since(restoreStart),
)
return result.RestoredPID, nil
return snap.PlaceholderPID, nil
}
func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package runtime
import (
"fmt"
"os"
"path/filepath"
"strconv"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
// WriteControlSentinel drops the sentinel file <name> into the workload
// container's snapshot-control volume.
//
// The agent does not share the container's mount namespace, so the target
// directory is addressed as HostProcPath/<hostPID>/root followed by
// SnapshotControlMountPath. hostPID therefore has to identify a process that
// lives inside the container's mount namespace (the container task PID is the
// natural choice); values below 1 are rejected with an error before any
// filesystem access.
//
// The workload notices the sentinel by polling the control directory. This
// file-based handshake replaces the SIGUSR1/SIGCONT signals the agent used to
// send, which only worked when the workload was PID 1.
//
// The file itself is published by writeSentinelInDir using a write-to-temp
// then rename sequence, so the workload never sees a half-written sentinel.
func WriteControlSentinel(hostPID int, name string) error {
	if hostPID < 1 {
		return fmt.Errorf("invalid host PID %d for control sentinel %q", hostPID, name)
	}

	// Root of the container's filesystem as seen from the agent.
	containerRoot := filepath.Join(HostProcPath, strconv.Itoa(hostPID), "root")
	controlDir := filepath.Join(containerRoot, snapshotprotocol.SnapshotControlMountPath)
	return writeSentinelInDir(controlDir, name)
}
func writeSentinelInDir(dir, name string) error {
tmpPath := filepath.Join(dir, "."+name+".tmp")
finalPath := filepath.Join(dir, name)
if err := os.WriteFile(tmpPath, []byte("done\n"), 0o644); err != nil {
return fmt.Errorf("write temp sentinel %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, finalPath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("rename sentinel %s -> %s: %w", tmpPath, finalPath, err)
}
return nil
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package runtime
import (
"os"
"path/filepath"
"testing"
)
// TestWriteSentinelInDir_CreatesFileAtomically checks that a successful write
// leaves exactly one entry in the control directory: the sentinel itself,
// carrying the expected payload, with no temporary file left next to it.
func TestWriteSentinelInDir_CreatesFileAtomically(t *testing.T) {
	const sentinel = "snapshot-complete"
	controlDir := t.TempDir()

	if err := writeSentinelInDir(controlDir, sentinel); err != nil {
		t.Fatalf("writeSentinelInDir failed: %v", err)
	}

	contents, readErr := os.ReadFile(filepath.Join(controlDir, sentinel))
	if readErr != nil {
		t.Fatalf("sentinel not found: %v", readErr)
	}
	if got := string(contents); got != "done\n" {
		t.Errorf("unexpected sentinel contents: %q", contents)
	}

	// Anything other than the sentinel means the temp file was not renamed
	// away or cleaned up.
	listing, listErr := os.ReadDir(controlDir)
	if listErr != nil {
		t.Fatalf("failed to read dir: %v", listErr)
	}
	for _, entry := range listing {
		if entryName := entry.Name(); entryName != sentinel {
			t.Errorf("unexpected leftover file %q in control dir", entryName)
		}
	}
}
// TestWriteSentinelInDir_Overwrites checks that writing the same sentinel
// twice succeeds both times and leaves the expected payload in place.
func TestWriteSentinelInDir_Overwrites(t *testing.T) {
	const sentinel = "restore-complete"
	controlDir := t.TempDir()

	firstErr := writeSentinelInDir(controlDir, sentinel)
	if firstErr != nil {
		t.Fatalf("first write failed: %v", firstErr)
	}
	// The second call has to replace the file created by the first one.
	secondErr := writeSentinelInDir(controlDir, sentinel)
	if secondErr != nil {
		t.Fatalf("second write failed: %v", secondErr)
	}

	contents, readErr := os.ReadFile(filepath.Join(controlDir, sentinel))
	if readErr != nil {
		t.Fatalf("sentinel not found: %v", readErr)
	}
	if got := string(contents); got != "done\n" {
		t.Errorf("unexpected sentinel contents: %q", contents)
	}
}
// TestWriteSentinelInDir_DirMissing checks that writing into a directory that
// does not exist is reported as an error.
func TestWriteSentinelInDir_DirMissing(t *testing.T) {
	// A child path under a fresh temp dir that is never created.
	absentDir := filepath.Join(t.TempDir(), "does-not-exist")

	err := writeSentinelInDir(absentDir, "snapshot-complete")
	if err != nil {
		return
	}
	t.Fatal("expected error writing into missing directory")
}
// TestWriteControlSentinel_RejectsInvalidPID checks that non-positive host
// PIDs are refused with an error.
func TestWriteControlSentinel_RejectsInvalidPID(t *testing.T) {
	invalidPIDs := []struct {
		pid     int
		failure string
	}{
		{pid: 0, failure: "expected error for PID 0"},
		{pid: -1, failure: "expected error for negative PID"},
	}

	// Cases run in order and the first missing error aborts the test.
	for _, tc := range invalidPIDs {
		if err := WriteControlSentinel(tc.pid, "snapshot-complete"); err == nil {
			t.Fatal(tc.failure)
		}
	}
}
package runtime
import (
"context"
"fmt"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/go-logr/logr"
"golang.org/x/sys/unix"
)
......@@ -21,40 +15,3 @@ func GetNetNSInode(pid int) (uint64, error) {
}
return stat.Ino, nil
}
// SendSignalViaPIDNamespace delivers sig to targetNamespacePID, a PID that is
// only meaningful inside the PID namespace of the process referenceHostPID.
//
// It shells out to `nsenter -t <referenceHostPID> -p -- kill -<sig> <pid>`,
// bound to ctx. Both PIDs must be positive; otherwise an error is returned
// without running anything. When the command fails, the returned error wraps
// the exec error and includes the trimmed combined stdout/stderr. reason is
// free-form text that is only used in the error and log output.
func SendSignalViaPIDNamespace(ctx context.Context, log logr.Logger, referenceHostPID, targetNamespacePID int, sig syscall.Signal, reason string) error {
	signum := int(sig)

	switch {
	case referenceHostPID <= 0:
		return fmt.Errorf("invalid reference host PID %d for signal %d", referenceHostPID, signum)
	case targetNamespacePID <= 0:
		return fmt.Errorf("invalid namespace PID %d for signal %d", targetNamespacePID, signum)
	}

	// Enter only the PID namespace of the reference process, then run kill
	// there so the namespace-relative PID resolves correctly.
	nsenterArgs := []string{
		"-t", strconv.Itoa(referenceHostPID),
		"-p",
		"--",
		"kill",
		fmt.Sprintf("-%d", signum),
		strconv.Itoa(targetNamespacePID),
	}
	combined, err := exec.CommandContext(ctx, "nsenter", nsenterArgs...).CombinedOutput()
	if err != nil {
		return fmt.Errorf(
			"failed to signal namespace PID %d via reference host PID %d with signal %d (%s): %w (output: %s)",
			targetNamespacePID, referenceHostPID, signum, reason, err, strings.TrimSpace(string(combined)),
		)
	}

	log.Info("Signaled runtime process in PID namespace",
		"reference_host_pid", referenceHostPID,
		"namespace_pid", targetNamespacePID,
		"signal", signum,
		"reason", reason,
	)
	return nil
}
......@@ -5,7 +5,7 @@ package protocol
import (
"fmt"
"strings"
"path/filepath"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
......@@ -56,17 +56,36 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
if opts.SeccompProfile != "" {
EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile)
}
if opts.WrapLaunchJob {
if len(podTemplate.Spec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
container := &podTemplate.Spec.Containers[0]
if len(container.Command) == 0 {
mainContainer := &podTemplate.Spec.Containers[0]
// Snapshot contract: control volume + ready-file readiness probe. The
// agent reads the pod's Ready condition before starting CRIU dump, so
// the workload signals "model loaded, safe to checkpoint" by writing
// $DYN_SNAPSHOT_CONTROL_DIR/ready-for-checkpoint. Any per-container
// liveness/startup probes are cleared — a checkpoint job runs to a
// quiesce-and-sit state, not a long-lived serving state.
EnsureControlVolume(&podTemplate.Spec, mainContainer)
mainContainer.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"cat", filepath.Join(SnapshotControlMountPath, ReadyForCheckpointFile)},
},
},
PeriodSeconds: 1,
}
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
if opts.WrapLaunchJob {
if len(mainContainer.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled")
}
container.Command, container.Args = wrapWithCudaCheckpointLaunchJob(
container.Command,
container.Args,
mainContainer.Command, mainContainer.Args = wrapWithCudaCheckpointLaunchJob(
mainContainer.Command,
mainContainer.Args,
)
}
......@@ -157,18 +176,13 @@ func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) {
}
}
// wrapWithCudaCheckpointLaunchJob rewrites the container's entrypoint so the
// workload is launched under `cuda-checkpoint --launch-job`, required for
// multi-GPU checkpoints. The original command and args are preserved as-is
// (including shell-form entrypoints): workload-to-agent signaling now uses
// file sentinels in the snapshot-control volume, so an intervening shell at
// PID 1 is no longer an issue.
func wrapWithCudaCheckpointLaunchJob(command []string, args []string) ([]string, []string) {
// Unwrap "/bin/sh -c <single-string>" so cuda-checkpoint launches the
// actual process directly. Otherwise sh sits between cuda-checkpoint and
// the real process and swallows SIGUSR1.
if len(command) >= 2 && command[len(command)-1] == "-c" && len(args) == 1 {
shell := command[:len(command)-1] // e.g. ["/bin/sh"] — discarded
_ = shell
parts := strings.Fields(args[0])
command = parts[:1] // e.g. ["python3"]
args = parts[1:] // e.g. ["-m", "dynamo.vllm", "--model", ...]
}
wrappedArgs := make([]string, 0, len(command)+len(args)+1)
wrappedArgs = append(wrappedArgs, "--launch-job")
wrappedArgs = append(wrappedArgs, command...)
......
......@@ -63,11 +63,27 @@ func TestNewCheckpointJob(t *testing.T) {
if job.Spec.Template.Annotations[CheckpointArtifactVersionAnnotation] != "2" {
t.Fatalf("expected checkpoint artifact version annotation on template: %#v", job.Spec.Template.Annotations)
}
if len(job.Spec.Template.Spec.Volumes) != 0 {
t.Fatalf("expected no checkpoint volume, got %#v", job.Spec.Template.Spec.Volumes)
if len(job.Spec.Template.Spec.Volumes) != 1 || job.Spec.Template.Spec.Volumes[0].Name != SnapshotControlVolumeName {
t.Fatalf("expected only %s volume, got %#v", SnapshotControlVolumeName, job.Spec.Template.Spec.Volumes)
}
if len(job.Spec.Template.Spec.Containers[0].VolumeMounts) != 0 {
t.Fatalf("expected no checkpoint volume mount, got %#v", job.Spec.Template.Spec.Containers[0].VolumeMounts)
main := &job.Spec.Template.Spec.Containers[0]
if len(main.VolumeMounts) != 1 || main.VolumeMounts[0].MountPath != SnapshotControlMountPath {
t.Fatalf("expected only %s mount at %s, got %#v", SnapshotControlVolumeName, SnapshotControlMountPath, main.VolumeMounts)
}
if main.ReadinessProbe == nil || main.ReadinessProbe.Exec == nil {
t.Fatalf("expected ready-file readiness probe, got %#v", main.ReadinessProbe)
}
expectedProbe := []string{"cat", SnapshotControlMountPath + "/" + ReadyForCheckpointFile}
if len(main.ReadinessProbe.Exec.Command) != len(expectedProbe) {
t.Fatalf("expected readiness probe %#v, got %#v", expectedProbe, main.ReadinessProbe.Exec.Command)
}
for i := range expectedProbe {
if main.ReadinessProbe.Exec.Command[i] != expectedProbe[i] {
t.Fatalf("expected readiness probe %#v, got %#v", expectedProbe, main.ReadinessProbe.Exec.Command)
}
}
if main.LivenessProbe != nil || main.StartupProbe != nil {
t.Fatalf("expected liveness and startup probes cleared, got liveness=%#v startup=%#v", main.LivenessProbe, main.StartupProbe)
}
if job.Spec.Template.Spec.RestartPolicy != corev1.RestartPolicyNever {
t.Fatalf("expected restartPolicy Never, got %#v", job.Spec.Template.Spec.RestartPolicy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment