Unverified Commit ca15a79d authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix(snapshot): replace PID-1 SIGUSR1/SIGCONT contract with file sentinels (#8403)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 8fba4f56
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
import asyncio import asyncio
import logging import logging
import os import os
import signal
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
from dynamo.common.utils.namespace import get_worker_namespace from dynamo.common.utils.namespace import get_worker_namespace
...@@ -25,23 +25,31 @@ KUBERNETES_OPTIONAL_PODINFO_FILES = { ...@@ -25,23 +25,31 @@ KUBERNETES_OPTIONAL_PODINFO_FILES = {
} }
EngineT = TypeVar("EngineT") EngineT = TypeVar("EngineT")
# Must match snapshotprotocol.{SnapshotCompleteFile,RestoreCompleteFile,ReadyForCheckpointFile}.
SNAPSHOT_COMPLETE_FILE = "snapshot-complete"
RESTORE_COMPLETE_FILE = "restore-complete"
READY_FOR_CHECKPOINT_FILE = "ready-for-checkpoint"
# Poll interval for the snapshot-control directory. Checkpoint and restore
# latencies are seconds, so 100ms is negligible overhead.
_SENTINEL_POLL_INTERVAL_SEC = 0.1
class CheckpointConfig: class CheckpointConfig:
"""Parsed checkpoint configuration plus the watcher-driven lifecycle.""" """Parsed checkpoint configuration plus the sentinel-driven lifecycle."""
def __init__(self, ready_file: str): def __init__(self, control_dir: str):
self.ready_file = ready_file self.control_dir = control_dir
self._checkpoint_done = asyncio.Event() self.ready_file = os.path.join(control_dir, READY_FOR_CHECKPOINT_FILE)
self._restore_done = asyncio.Event()
@classmethod @classmethod
def from_env(cls) -> "CheckpointConfig | None": def from_env(cls) -> "CheckpointConfig | None":
ready_file = os.environ.get("DYN_READY_FOR_CHECKPOINT_FILE") control_dir = os.environ.get("DYN_SNAPSHOT_CONTROL_DIR")
if not ready_file: if not control_dir:
return None return None
configure_checkpoint_transport_env() configure_checkpoint_transport_env()
return cls(ready_file=ready_file) return cls(control_dir=control_dir)
async def run_lifecycle( async def run_lifecycle(
self, self,
...@@ -51,65 +59,53 @@ class CheckpointConfig: ...@@ -51,65 +59,53 @@ class CheckpointConfig:
logger.info("Quiescing model") logger.info("Quiescing model")
await quiesce_controller.quiesce(*quiesce_args) await quiesce_controller.quiesce(*quiesce_args)
self._install_signal_handlers()
try: try:
with open(self.ready_file, "w", encoding="utf-8") as ready_file: with open(self.ready_file, "w", encoding="utf-8") as ready_file:
ready_file.write("ready") ready_file.write("ready")
except Exception:
self._remove_signal_handlers()
raise
logger.info( logger.info(
"Ready for checkpoint. Waiting for watcher signal " "Ready for checkpoint. Polling for sentinel in %s "
"(SIGUSR1=checkpoint complete, SIGCONT=restore complete)" "(snapshot-complete or restore-complete)",
) self.control_dir,
)
try: event = await self._wait_for_sentinel()
event = await self._wait_for_watcher_signal()
if event == "restore":
logger.info("Restore signal detected (SIGCONT)")
logger.info("Resuming model after restore")
await quiesce_controller.resume()
quiesce_controller.mark_resumed()
return True
logger.info("Checkpoint completion signal detected (SIGUSR1)")
return False
finally: finally:
self._remove_signal_handlers() self._cleanup_ready_and_sentinels()
if event == "restore":
logger.info("Restore sentinel detected")
logger.info("Resuming model after restore")
await quiesce_controller.resume()
quiesce_controller.mark_resumed()
return True
logger.info("Snapshot completion sentinel detected")
return False
async def _wait_for_sentinel(self) -> str:
snapshot_path = Path(self.control_dir) / SNAPSHOT_COMPLETE_FILE
restore_path = Path(self.control_dir) / RESTORE_COMPLETE_FILE
while True:
if snapshot_path.exists():
return "checkpoint"
if restore_path.exists():
return "restore"
await asyncio.sleep(_SENTINEL_POLL_INTERVAL_SEC)
def _cleanup_ready_and_sentinels(self) -> None:
for name in (
READY_FOR_CHECKPOINT_FILE,
SNAPSHOT_COMPLETE_FILE,
RESTORE_COMPLETE_FILE,
):
path = os.path.join(self.control_dir, name)
try: try:
os.unlink(self.ready_file) os.unlink(path)
except OSError: except FileNotFoundError:
pass pass
except OSError:
def _install_signal_handlers(self) -> None: logger.exception("Failed to clean up %s at %s", name, path)
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGUSR1, self._checkpoint_done.set)
loop.add_signal_handler(signal.SIGCONT, self._restore_done.set)
def _remove_signal_handlers(self) -> None:
loop = asyncio.get_running_loop()
loop.remove_signal_handler(signal.SIGUSR1)
loop.remove_signal_handler(signal.SIGCONT)
async def _wait_for_watcher_signal(self) -> str:
waiters = {
asyncio.create_task(self._checkpoint_done.wait()): "checkpoint",
asyncio.create_task(self._restore_done.wait()): "restore",
}
try:
done, pending = await asyncio.wait(
waiters.keys(), return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
winner = done.pop()
await winner
return waiters[winner]
finally:
for task in waiters:
if not task.done():
task.cancel()
def configure_checkpoint_transport_env() -> None: def configure_checkpoint_transport_env() -> None:
......
...@@ -164,7 +164,6 @@ The chart includes built-in validation to prevent all operator conflicts: ...@@ -164,7 +164,6 @@ The chart includes built-in validation to prevent all operator conflicts:
| dynamo-operator.webhook.certManager.certificate.rootCA.duration | string | `"87600h"` | Duration for the root CA certificate (e.g., "87600h" for 10 years). The root CA typically has a much longer lifetime than the leaf certificates it signs. | | dynamo-operator.webhook.certManager.certificate.rootCA.duration | string | `"87600h"` | Duration for the root CA certificate (e.g., "87600h" for 10 years). The root CA typically has a much longer lifetime than the leaf certificates it signs. |
| dynamo-operator.webhook.certManager.certificate.rootCA.renewBefore | string | `"720h"` | Time before root CA expiration to trigger renewal (e.g., "720h" for 30 days). Renewing a CA can be disruptive as all signed certificates must be reissued. | | dynamo-operator.webhook.certManager.certificate.rootCA.renewBefore | string | `"720h"` | Time before root CA expiration to trigger renewal (e.g., "720h" for 30 days). Renewing a CA can be disruptive as all signed certificates must be reissued. |
| dynamo-operator.checkpoint.enabled | bool | `false` | Whether to enable checkpoint/restore functionality | | dynamo-operator.checkpoint.enabled | bool | `false` | Whether to enable checkpoint/restore functionality |
| dynamo-operator.checkpoint.readyForCheckpointFilePath | string | `"/tmp/ready-for-checkpoint"` | Path written by worker when model is loaded and ready for checkpointing |
| grove.tolerations | list | `[]` | Node tolerations for Grove pods | | grove.tolerations | list | `[]` | Node tolerations for Grove pods |
| grove.affinity | object | `{}` | Affinity for Grove pods | | grove.affinity | object | `{}` | Affinity for Grove pods |
| kai-scheduler.global.tolerations | list | `[]` | Node tolerations for kai-scheduler pods | | kai-scheduler.global.tolerations | list | `[]` | Node tolerations for kai-scheduler pods |
......
...@@ -132,9 +132,6 @@ data: ...@@ -132,9 +132,6 @@ data:
{{- if .Values.checkpoint.enabled }} {{- if .Values.checkpoint.enabled }}
checkpoint: checkpoint:
enabled: true enabled: true
{{- if ne (.Values.checkpoint.readyForCheckpointFilePath | toString) "/tmp/ready-for-checkpoint" }}
readyForCheckpointFilePath: {{ .Values.checkpoint.readyForCheckpointFilePath | quote }}
{{- end }}
{{- end }} {{- end }}
{{- if and .Values.discoveryBackend (ne (.Values.discoveryBackend | toString) "kubernetes") }} {{- if and .Values.discoveryBackend (ne (.Values.discoveryBackend | toString) "kubernetes") }}
discovery: discovery:
......
...@@ -139,10 +139,6 @@ checkpoint: ...@@ -139,10 +139,6 @@ checkpoint:
# Enable checkpoint/restore functionality # Enable checkpoint/restore functionality
enabled: false enabled: false
# Path written by worker when model is loaded and ready for checkpointing
# Must match the path expected by checkpoint-enabled runtime images
readyForCheckpointFilePath: "/tmp/ready-for-checkpoint"
# Webhook configuration # Webhook configuration
webhook: webhook:
# Certificate configuration # Certificate configuration
......
...@@ -228,9 +228,6 @@ dynamo-operator: ...@@ -228,9 +228,6 @@ dynamo-operator:
# -- Whether to enable checkpoint/restore functionality # -- Whether to enable checkpoint/restore functionality
enabled: false enabled: false
# -- Path written by worker when model is loaded and ready for checkpointing
readyForCheckpointFilePath: "/tmp/ready-for-checkpoint"
# Grove component - distributed inference orchestration # Grove component - distributed inference orchestration
# Installation is controlled by global.grove.install above. # Installation is controlled by global.grove.install above.
grove: grove:
......
...@@ -85,11 +85,6 @@ func SetDefaultsOperatorConfiguration(obj *OperatorConfiguration) { ...@@ -85,11 +85,6 @@ func SetDefaultsOperatorConfiguration(obj *OperatorConfiguration) {
obj.GPU.DiscoveryEnabled = ptr.To(true) obj.GPU.DiscoveryEnabled = ptr.To(true)
} }
// Checkpoint defaults
if obj.Checkpoint.ReadyForCheckpointFilePath == "" {
obj.Checkpoint.ReadyForCheckpointFilePath = "/tmp/ready-for-checkpoint"
}
// Logging defaults // Logging defaults
if obj.Logging.Level == "" { if obj.Logging.Level == "" {
obj.Logging.Level = "info" obj.Logging.Level = "info"
......
...@@ -245,9 +245,6 @@ type MPIConfiguration struct { ...@@ -245,9 +245,6 @@ type MPIConfiguration struct {
type CheckpointConfiguration struct { type CheckpointConfiguration struct {
// Enabled indicates if checkpoint functionality is enabled // Enabled indicates if checkpoint functionality is enabled
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
// ReadyForCheckpointFilePath signals model readiness for checkpoint jobs
// +kubebuilder:default="/tmp/ready-for-checkpoint"
ReadyForCheckpointFilePath string `json:"readyForCheckpointFilePath"`
// Deprecated: Storage is retained for compatibility and ignored by the // Deprecated: Storage is retained for compatibility and ignored by the
// current snapshot flow. Snapshot storage is discovered from the // current snapshot flow. Snapshot storage is discovered from the
// snapshot-agent DaemonSet instead. // snapshot-agent DaemonSet instead.
......
...@@ -68,7 +68,7 @@ func testScheme() *runtime.Scheme { ...@@ -68,7 +68,7 @@ func testScheme() *runtime.Scheme {
} }
func testInfo() *CheckpointInfo { func testInfo() *CheckpointInfo {
return &CheckpointInfo{Enabled: true, Hash: testHash} return &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash}
} }
func testSnapshotAgentDaemonSet() *appsv1.DaemonSet { func testSnapshotAgentDaemonSet() *appsv1.DaemonSet {
...@@ -184,6 +184,27 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) { ...@@ -184,6 +184,27 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
// --- InjectCheckpointIntoPodSpec tests --- // --- InjectCheckpointIntoPodSpec tests ---
func TestInjectCheckpointIntoPodSpec(t *testing.T) { func TestInjectCheckpointIntoPodSpec(t *testing.T) {
t.Run("not ready checkpoint leaves pod spec untouched", func(t *testing.T) {
podSpec := testPodSpec()
originalCmd := append([]string(nil), podSpec.Containers[0].Command...)
originalArgs := append([]string(nil), podSpec.Containers[0].Args...)
info := &CheckpointInfo{Enabled: true, Ready: false, Hash: testHash}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
assert.Equal(t, originalCmd, podSpec.Containers[0].Command)
assert.Equal(t, originalArgs, podSpec.Containers[0].Args)
for _, volume := range podSpec.Volumes {
assert.NotEqual(t, snapshotprotocol.SnapshotControlVolumeName, volume.Name)
assert.NotEqual(t, snapshotprotocol.CheckpointVolumeName, volume.Name)
assert.NotEqual(t, consts.PodInfoVolumeName, volume.Name)
}
for _, env := range podSpec.Containers[0].Env {
assert.NotEqual(t, snapshotprotocol.SnapshotControlDirEnv, env.Name)
}
})
t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) { t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
podSpec := testPodSpec() podSpec := testPodSpec()
info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())} info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
...@@ -279,7 +300,7 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -279,7 +300,7 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
reader client.Reader reader client.Reader
errMsg string errMsg string
}{ }{
{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"}, {"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true, Ready: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
{"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container named"}, {"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container named"},
{"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"}, {"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"},
} { } {
......
...@@ -58,7 +58,15 @@ func InjectCheckpointIntoPodSpec( ...@@ -58,7 +58,15 @@ func InjectCheckpointIntoPodSpec(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
checkpointInfo *CheckpointInfo, checkpointInfo *CheckpointInfo,
) error { ) error {
if checkpointInfo == nil || !checkpointInfo.Enabled { // Only mutate the worker pod spec once the checkpoint is Ready. Before
// the checkpoint exists, the worker must cold-start normally without
// the snapshot-control volume, DYN_SNAPSHOT_CONTROL_DIR, checkpoint PVC
// mount, or localhost seccomp profile — otherwise the Python worker
// enters checkpoint mode on env-var presence and sits quiesced waiting
// for a sentinel that only the checkpoint Job and restore-target path
// produce. The checkpoint Job itself is built separately through
// buildCheckpointJob + NewCheckpointJob and does get these.
if checkpointInfo == nil || !checkpointInfo.Enabled || !checkpointInfo.Ready {
return nil return nil
} }
......
...@@ -140,9 +140,6 @@ const ( ...@@ -140,9 +140,6 @@ const (
ResourceStateNotReady = "not_ready" ResourceStateNotReady = "not_ready"
ResourceStateUnknown = "unknown" ResourceStateUnknown = "unknown"
// Environment variables injected into pods
EnvReadyForCheckpointFile = "DYN_READY_FOR_CHECKPOINT_FILE" // Ready-for-checkpoint file path — checkpoint job pods
// Pod identity (Downward API) --- // Pod identity (Downward API) ---
// After CRIU restore, env vars contain stale values from the checkpoint pod. // After CRIU restore, env vars contain stale values from the checkpoint pod.
// The Downward API files at /etc/podinfo always reflect the current pod. // The Downward API files at /etc/podinfo always reflect the current pod.
......
...@@ -6,7 +6,6 @@ package controller ...@@ -6,7 +6,6 @@ package controller
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
...@@ -91,25 +90,11 @@ func buildCheckpointJob( ...@@ -91,25 +90,11 @@ func buildCheckpointJob(
mainContainer.Env, mainContainer.Env,
) )
dynamo.AddStandardEnvVars(mainContainer, config) dynamo.AddStandardEnvVars(mainContainer, config)
mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{
Name: consts.EnvReadyForCheckpointFile,
Value: config.Checkpoint.ReadyForCheckpointFilePath,
})
mainContainer.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"cat", config.Checkpoint.ReadyForCheckpointFilePath},
},
},
InitialDelaySeconds: 15,
PeriodSeconds: 2,
}
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
// The snapshot agent sends SIGUSR1 to PID 1 of the main container after
checkpoint.EnsurePodInfoMount(mainContainer) checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory) dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
// NewCheckpointJob handles control volume + readiness probe from the
// snapshot contract.
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled { if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
claimTemplateName := dra.ResourceClaimTemplateName("checkpoint-"+hash, "worker") claimTemplateName := dra.ResourceClaimTemplateName("checkpoint-"+hash, "worker")
...@@ -129,9 +114,6 @@ func buildCheckpointJob( ...@@ -129,9 +114,6 @@ func buildCheckpointJob(
if err := checkpoint.EnsureGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage); err != nil { if err := checkpoint.EnsureGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage); err != nil {
return nil, err return nil, err
} }
// Re-acquire pointer: append in EnsureGMSCheckpointJobSidecars may
// have reallocated the Containers slice.
mainContainer = &podTemplate.Spec.Containers[0]
} }
activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds
...@@ -153,16 +135,6 @@ func buildCheckpointJob( ...@@ -153,16 +135,6 @@ func buildCheckpointJob(
} }
wrapLaunchJob := tp*pp > 1 wrapLaunchJob := tp*pp > 1
// For single-GPU jobs (no cuda-checkpoint wrapper), unwrap /bin/sh -c so
// the actual process is PID 1 and receives SIGUSR1 from the snapshot agent.
if !wrapLaunchJob && len(mainContainer.Command) >= 2 &&
mainContainer.Command[len(mainContainer.Command)-1] == "-c" &&
len(mainContainer.Args) == 1 {
parts := strings.Fields(mainContainer.Args[0])
mainContainer.Command = parts[:1]
mainContainer.Args = parts[1:]
}
ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds
return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{ return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{
......
...@@ -78,8 +78,7 @@ func checkpointTestScheme() *runtime.Scheme { ...@@ -78,8 +78,7 @@ func checkpointTestScheme() *runtime.Scheme {
func checkpointTestConfig() *configv1alpha1.OperatorConfiguration { func checkpointTestConfig() *configv1alpha1.OperatorConfiguration {
return &configv1alpha1.OperatorConfiguration{ return &configv1alpha1.OperatorConfiguration{
Checkpoint: configv1alpha1.CheckpointConfiguration{ Checkpoint: configv1alpha1.CheckpointConfiguration{
Enabled: true, Enabled: true,
ReadyForCheckpointFilePath: "/tmp/ready-for-checkpoint",
}, },
} }
} }
...@@ -168,7 +167,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -168,7 +167,7 @@ func TestBuildCheckpointJob(t *testing.T) {
for _, e := range main.Env { for _, e := range main.Env {
envMap[e.Name] = e.Value envMap[e.Name] = e.Value
} }
assert.Equal(t, "/tmp/ready-for-checkpoint", envMap[consts.EnvReadyForCheckpointFile]) assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, envMap[snapshotprotocol.SnapshotControlDirEnv])
assert.Equal(t, "manual-checkpoint", envMap[consts.DynamoNamespaceEnvVar]) assert.Equal(t, "manual-checkpoint", envMap[consts.DynamoNamespaceEnvVar])
assert.Equal(t, consts.ComponentTypeWorker, envMap[consts.DynamoComponentEnvVar]) assert.Equal(t, consts.ComponentTypeWorker, envMap[consts.DynamoComponentEnvVar])
assert.Equal(t, "worker-1234", envMap[consts.DynamoNamespaceWorkerSuffixEnvVar]) assert.Equal(t, "worker-1234", envMap[consts.DynamoNamespaceWorkerSuffixEnvVar])
...@@ -201,7 +200,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -201,7 +200,7 @@ func TestBuildCheckpointJob(t *testing.T) {
// Probes: readiness set, liveness/startup cleared // Probes: readiness set, liveness/startup cleared
require.NotNil(t, main.ReadinessProbe) require.NotNil(t, main.ReadinessProbe)
assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command) assert.Equal(t, []string{"cat", "/snapshot-control/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Nil(t, main.LivenessProbe) assert.Nil(t, main.LivenessProbe)
assert.Nil(t, main.StartupProbe) assert.Nil(t, main.StartupProbe)
...@@ -212,6 +211,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -212,6 +211,7 @@ func TestBuildCheckpointJob(t *testing.T) {
} }
assert.False(t, volNames[snapshotprotocol.CheckpointVolumeName]) assert.False(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[consts.PodInfoVolumeName]) assert.True(t, volNames[consts.PodInfoVolumeName])
assert.True(t, volNames[snapshotprotocol.SnapshotControlVolumeName])
mountPaths := make(map[string]string) mountPaths := make(map[string]string)
for _, m := range main.VolumeMounts { for _, m := range main.VolumeMounts {
...@@ -221,6 +221,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -221,6 +221,7 @@ func TestBuildCheckpointJob(t *testing.T) {
assert.False(t, hasCheckpointMount) assert.False(t, hasCheckpointMount)
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName]) assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory]) assert.Equal(t, consts.DefaultSharedMemoryMountPath, mountPaths[consts.KubeValueNameSharedMemory])
assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, mountPaths[snapshotprotocol.SnapshotControlVolumeName])
foundSharedMemoryVolume := false foundSharedMemoryVolume := false
for _, v := range podSpec.Volumes { for _, v := range podSpec.Volumes {
...@@ -301,7 +302,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) { ...@@ -301,7 +302,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
assert.Equal(t, []string{"cuda-checkpoint"}, main.Command) assert.Equal(t, []string{"cuda-checkpoint"}, main.Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args) assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args)
require.NotNil(t, main.ReadinessProbe) require.NotNil(t, main.ReadinessProbe)
assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command) assert.Equal(t, []string{"cat", "/snapshot-control/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Nil(t, main.LivenessProbe) assert.Nil(t, main.LivenessProbe)
assert.Nil(t, main.StartupProbe) assert.Nil(t, main.StartupProbe)
...@@ -309,7 +310,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) { ...@@ -309,7 +310,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
for _, env := range main.Env { for _, env := range main.Env {
mainEnv[env.Name] = env.Value mainEnv[env.Name] = env.Value
} }
assert.Equal(t, "/tmp/ready-for-checkpoint", mainEnv[consts.EnvReadyForCheckpointFile]) assert.Equal(t, snapshotprotocol.SnapshotControlMountPath, mainEnv[snapshotprotocol.SnapshotControlDirEnv])
assert.Equal(t, "secret", mainEnv["HF_TOKEN"]) assert.Equal(t, "secret", mainEnv["HF_TOKEN"])
sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar") sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar")
...@@ -319,7 +320,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) { ...@@ -319,7 +320,7 @@ func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
assert.Nil(t, sidecar.LivenessProbe) assert.Nil(t, sidecar.LivenessProbe)
assert.Nil(t, sidecar.StartupProbe) assert.Nil(t, sidecar.StartupProbe)
for _, env := range sidecar.Env { for _, env := range sidecar.Env {
assert.NotEqual(t, consts.EnvReadyForCheckpointFile, env.Name) assert.NotEqual(t, snapshotprotocol.SnapshotControlDirEnv, env.Name)
} }
} }
...@@ -374,7 +375,7 @@ func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) { ...@@ -374,7 +375,7 @@ func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) {
} }
assert.True(t, volNames[gms.SharedVolumeName]) assert.True(t, volNames[gms.SharedVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName]) assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName]) assert.True(t, volNames[snapshotprotocol.SnapshotControlVolumeName])
mainMounts := map[string]string{} mainMounts := map[string]string{}
for _, m := range main.VolumeMounts { for _, m := range main.VolumeMounts {
......
...@@ -20,11 +20,13 @@ const ( ...@@ -20,11 +20,13 @@ const (
ServerContainerName = "gms-server" ServerContainerName = "gms-server"
// SharedVolumeName is the emptyDir volume shared between the GMS server // SharedVolumeName is the emptyDir volume shared between the GMS server
// sidecar and the main workload container for UDS sockets. // sidecar and the main workload container for UDS sockets. The name
SharedVolumeName = "gms-shared" // disambiguates it from the snapshot-control volume, which carries
// checkpoint/restore lifecycle sentinels written by the snapshot agent.
SharedVolumeName = "gms-intrapod-control"
// SharedMountPath is the mount path for the shared GMS socket directory. // SharedMountPath is the mount path for the GMS intra-pod IPC directory.
SharedMountPath = "/shared" SharedMountPath = "/gms-intrapod-control"
// EnvSocketDir is the environment variable name for the GMS UDS socket directory. // EnvSocketDir is the environment variable name for the GMS UDS socket directory.
EnvSocketDir = "GMS_SOCKET_DIR" EnvSocketDir = "GMS_SOCKET_DIR"
......
...@@ -317,7 +317,9 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po ...@@ -317,7 +317,9 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
// 1. Hold and renew the checkpoint lease // 1. Hold and renew the checkpoint lease
// 2. Resolve the container ID and host PID // 2. Resolve the container ID and host PID
// 3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff) // 3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
// 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately) // 4. Write a snapshot-complete sentinel into the pod's snapshot-control
// volume on success (observed by the workload via inotify), or SIGKILL
// on failure (unrecoverable CUDA-locked process)
// 5. Mark job as completed or failed // 5. Mark job as completed or failed
func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointID, checkpointLocation, podKey string, startedAt time.Time) error { func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointID, checkpointLocation, podKey string, startedAt time.Time) error {
releasePodOnExit := true releasePodOnExit := true
...@@ -438,16 +440,21 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job ...@@ -438,16 +440,21 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
return nil return nil
} }
// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed // Step 2: Sentinel on success. Workload observes via polling on the
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID)) // snapshot-control volume; containerPID is a PID inside the container's
if err := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil { // mount namespace, which is all the /host/proc/<pid>/root write path
log.Error(err, "Failed to signal checkpoint completion to runtime process") // requires. The Succeeded event is emitted only after the sentinel has
// been written so a sentinel-write failure doesn't produce conflicting
// Succeeded+Failed events for the same operation.
if err := snapshotruntime.WriteControlSentinel(containerPID, snapshotprotocol.SnapshotCompleteFile); err != nil {
log.Error(err, "Failed to write snapshot-complete sentinel")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil { if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
return statusErr return statusErr
} }
return nil return nil
} }
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID))
if err := setCheckpointStatus(snapshotprotocol.CheckpointStatusCompleted); err != nil { if err := setCheckpointStatus(snapshotprotocol.CheckpointStatusCompleted); err != nil {
return err return err
...@@ -458,7 +465,8 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job ...@@ -458,7 +465,8 @@ func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job
// runRestore runs the full restore workflow for a pod: // runRestore runs the full restore workflow for a pod:
// 1. Mark the current container instance as in_progress // 1. Mark the current container instance as in_progress
// 2. Call executor.Restore (inspect placeholder → nsrestore inside namespace) // 2. Call executor.Restore (inspect placeholder → nsrestore inside namespace)
// 3. SIGCONT the restored process to wake it up // 3. Write a restore-complete sentinel into the pod's snapshot-control
// volume to wake the workload (observed via inotify)
// 4. Wait for the pod to become Ready // 4. Wait for the pod to become Ready
// 5. Mark the container instance as completed // 5. Mark the container instance as completed
func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey string, startedAt time.Time) error { func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey string, startedAt time.Time) error {
...@@ -506,13 +514,14 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -506,13 +514,14 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
ContainerName: containerName, ContainerName: containerName,
Clientset: w.clientset, Clientset: w.clientset,
} }
restoredPID, err := executor.Restore(restoreCtx, w.containerd, log, req) placeholderHostPID, err := executor.Restore(restoreCtx, w.containerd, log, req)
if err != nil { if err != nil {
log.Error(err, "External restore failed") log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil { if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr return statusErr
} }
// Re-resolve: executor.Restore may have failed before resolving the placeholder.
placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName) placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if pidErr != nil { if pidErr != nil {
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr) return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
...@@ -523,31 +532,24 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -523,31 +532,24 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
return nil return nil
} }
// Step 2: SIGCONT the restored process via PID namespace // Step 2: Write restore-complete sentinel. placeholderHostPID came back
placeholderHostPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName) // from executor.Restore — any PID inside the container's mount namespace
if err != nil { // reaches /snapshot-control via /host/proc/<pid>/root.
log.Error(err, "Failed to resolve placeholder host PID for signaling") if err := snapshotruntime.WriteControlSentinel(placeholderHostPID, snapshotprotocol.RestoreCompleteFile); err != nil {
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) log.Error(err, "Failed to write restore-complete sentinel")
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
}
if err := snapshotruntime.SendSignalViaPIDNamespace(restoreCtx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil { if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr return statusErr
} }
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil { if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore sentinel failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore signaling failure") log.Error(killErr, "Failed to kill placeholder after restore sentinel failure")
} }
return fmt.Errorf("failed to signal restored runtime process: %w", err) return fmt.Errorf("failed to write restore-complete sentinel: %w", err)
} }
// Step 3: Wait for the pod to become Ready // Step 3: Wait for the pod to become Ready
if err := waitForPodReady(restoreCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil { if err := waitForPodReady(restoreCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed") log.Error(err, "Restore post-sentinel readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil { if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr return statusErr
...@@ -555,7 +557,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -555,7 +557,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil { if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore readiness failure") log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
} }
return fmt.Errorf("restore post-signal readiness check failed: %w", err) return fmt.Errorf("restore post-sentinel readiness check failed: %w", err)
} }
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointID)) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointID))
......
...@@ -39,6 +39,10 @@ type RestoreRequest struct { ...@@ -39,6 +39,10 @@ type RestoreRequest struct {
// Returns the namespace-relative PID of the restored process. // Returns the namespace-relative PID of the restored process.
// The DaemonSet side inspects the placeholder and launches nsrestore, // The DaemonSet side inspects the placeholder and launches nsrestore,
// which handles rootfs application, CRIU restore, and CUDA restore inside the namespace. // which handles rootfs application, CRIU restore, and CUDA restore inside the namespace.
//
// Returns the placeholder container's host PID so callers can reach into the
// container's mount namespace (e.g. to write sentinels under /snapshot-control)
// without re-resolving via containerd.
func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) { func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) {
restoreStart := time.Now() restoreStart := time.Now()
log.Info("=== Starting external restore ===", log.Info("=== Starting external restore ===",
...@@ -90,11 +94,12 @@ func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req ...@@ -90,11 +94,12 @@ func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req
log.Info("=== External restore completed ===", log.Info("=== External restore completed ===",
"restored_pid", result.RestoredPID, "restored_pid", result.RestoredPID,
"placeholder_host_pid", snap.PlaceholderPID,
"validation_duration", time.Since(validationStart), "validation_duration", time.Since(validationStart),
"total_duration", time.Since(restoreStart), "total_duration", time.Since(restoreStart),
) )
return result.RestoredPID, nil return snap.PlaceholderPID, nil
} }
func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) { func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package runtime
import (
"fmt"
"os"
"path/filepath"
"strconv"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
// WriteControlSentinel writes a sentinel file into the workload container's
// snapshot-control volume at SnapshotControlMountPath/<name>, accessed through
// the agent's /host/proc/<pid>/root view of the container's mount namespace.
//
// hostPID must be a PID inside the container's mount namespace (the container
// task PID is the canonical choice). The sentinel is observed by the workload
// via inotify on the control directory; it replaces the SIGUSR1/SIGCONT
// agent-to-workload signals that previously required the workload to run as
// PID 1.
//
// The write uses create-then-rename so the workload never observes a partial
// file.
func WriteControlSentinel(hostPID int, name string) error {
if hostPID <= 0 {
return fmt.Errorf("invalid host PID %d for control sentinel %q", hostPID, name)
}
dir := filepath.Join(HostProcPath, strconv.Itoa(hostPID), "root", snapshotprotocol.SnapshotControlMountPath)
return writeSentinelInDir(dir, name)
}
func writeSentinelInDir(dir, name string) error {
tmpPath := filepath.Join(dir, "."+name+".tmp")
finalPath := filepath.Join(dir, name)
if err := os.WriteFile(tmpPath, []byte("done\n"), 0o644); err != nil {
return fmt.Errorf("write temp sentinel %s: %w", tmpPath, err)
}
if err := os.Rename(tmpPath, finalPath); err != nil {
_ = os.Remove(tmpPath)
return fmt.Errorf("rename sentinel %s -> %s: %w", tmpPath, finalPath, err)
}
return nil
}
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package runtime
import (
"os"
"path/filepath"
"testing"
)
func TestWriteSentinelInDir_CreatesFileAtomically(t *testing.T) {
dir := t.TempDir()
if err := writeSentinelInDir(dir, "snapshot-complete"); err != nil {
t.Fatalf("writeSentinelInDir failed: %v", err)
}
data, err := os.ReadFile(filepath.Join(dir, "snapshot-complete"))
if err != nil {
t.Fatalf("sentinel not found: %v", err)
}
if string(data) != "done\n" {
t.Errorf("unexpected sentinel contents: %q", data)
}
entries, err := os.ReadDir(dir)
if err != nil {
t.Fatalf("failed to read dir: %v", err)
}
for _, e := range entries {
if e.Name() != "snapshot-complete" {
t.Errorf("unexpected leftover file %q in control dir", e.Name())
}
}
}
func TestWriteSentinelInDir_Overwrites(t *testing.T) {
dir := t.TempDir()
if err := writeSentinelInDir(dir, "restore-complete"); err != nil {
t.Fatalf("first write failed: %v", err)
}
if err := writeSentinelInDir(dir, "restore-complete"); err != nil {
t.Fatalf("second write failed: %v", err)
}
data, err := os.ReadFile(filepath.Join(dir, "restore-complete"))
if err != nil {
t.Fatalf("sentinel not found: %v", err)
}
if string(data) != "done\n" {
t.Errorf("unexpected sentinel contents: %q", data)
}
}
func TestWriteSentinelInDir_DirMissing(t *testing.T) {
missing := filepath.Join(t.TempDir(), "does-not-exist")
if err := writeSentinelInDir(missing, "snapshot-complete"); err == nil {
t.Fatal("expected error writing into missing directory")
}
}
func TestWriteControlSentinel_RejectsInvalidPID(t *testing.T) {
if err := WriteControlSentinel(0, "snapshot-complete"); err == nil {
t.Fatal("expected error for PID 0")
}
if err := WriteControlSentinel(-1, "snapshot-complete"); err == nil {
t.Fatal("expected error for negative PID")
}
}
package runtime package runtime
import ( import (
"context"
"fmt" "fmt"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/go-logr/logr"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
...@@ -21,40 +15,3 @@ func GetNetNSInode(pid int) (uint64, error) { ...@@ -21,40 +15,3 @@ func GetNetNSInode(pid int) (uint64, error) {
} }
return stat.Ino, nil return stat.Ino, nil
} }
// SendSignalViaPIDNamespace sends a signal to a namespace-relative PID by entering the
// PID namespace of referenceHostPID via nsenter.
func SendSignalViaPIDNamespace(ctx context.Context, log logr.Logger, referenceHostPID, targetNamespacePID int, sig syscall.Signal, reason string) error {
if referenceHostPID <= 0 {
return fmt.Errorf("invalid reference host PID %d for signal %d", referenceHostPID, int(sig))
}
if targetNamespacePID <= 0 {
return fmt.Errorf("invalid namespace PID %d for signal %d", targetNamespacePID, int(sig))
}
cmd := exec.CommandContext(
ctx,
"nsenter",
"-t", strconv.Itoa(referenceHostPID),
"-p",
"--",
"kill",
fmt.Sprintf("-%d", int(sig)),
strconv.Itoa(targetNamespacePID),
)
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf(
"failed to signal namespace PID %d via reference host PID %d with signal %d (%s): %w (output: %s)",
targetNamespacePID, referenceHostPID, int(sig), reason, err, strings.TrimSpace(string(output)),
)
}
log.Info("Signaled runtime process in PID namespace",
"reference_host_pid", referenceHostPID,
"namespace_pid", targetNamespacePID,
"signal", int(sig),
"reason", reason,
)
return nil
}
...@@ -5,7 +5,7 @@ package protocol ...@@ -5,7 +5,7 @@ package protocol
import ( import (
"fmt" "fmt"
"strings" "path/filepath"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -56,17 +56,36 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt ...@@ -56,17 +56,36 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
if opts.SeccompProfile != "" { if opts.SeccompProfile != "" {
EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile) EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile)
} }
if len(podTemplate.Spec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
mainContainer := &podTemplate.Spec.Containers[0]
// Snapshot contract: control volume + ready-file readiness probe. The
// agent reads the pod's Ready condition before starting CRIU dump, so
// the workload signals "model loaded, safe to checkpoint" by writing
// $DYN_SNAPSHOT_CONTROL_DIR/ready-for-checkpoint. Any per-container
// liveness/startup probes are cleared — a checkpoint job runs to a
// quiesce-and-sit state, not a long-lived serving state.
EnsureControlVolume(&podTemplate.Spec, mainContainer)
mainContainer.ReadinessProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"cat", filepath.Join(SnapshotControlMountPath, ReadyForCheckpointFile)},
},
},
PeriodSeconds: 1,
}
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
if opts.WrapLaunchJob { if opts.WrapLaunchJob {
if len(podTemplate.Spec.Containers) == 0 { if len(mainContainer.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
container := &podTemplate.Spec.Containers[0]
if len(container.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled") return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled")
} }
container.Command, container.Args = wrapWithCudaCheckpointLaunchJob( mainContainer.Command, mainContainer.Args = wrapWithCudaCheckpointLaunchJob(
container.Command, mainContainer.Command,
container.Args, mainContainer.Args,
) )
} }
...@@ -157,18 +176,13 @@ func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) { ...@@ -157,18 +176,13 @@ func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) {
} }
} }
// wrapWithCudaCheckpointLaunchJob rewrites the container's entrypoint so the
// workload is launched under `cuda-checkpoint --launch-job`, required for
// multi-GPU checkpoints. The original command and args are preserved as-is
// (including shell-form entrypoints): workload-to-agent signaling now uses
// file sentinels in the snapshot-control volume, so an intervening shell at
// PID 1 is no longer an issue.
func wrapWithCudaCheckpointLaunchJob(command []string, args []string) ([]string, []string) { func wrapWithCudaCheckpointLaunchJob(command []string, args []string) ([]string, []string) {
// Unwrap "/bin/sh -c <single-string>" so cuda-checkpoint launches the
// actual process directly. Otherwise sh sits between cuda-checkpoint and
// the real process and swallows SIGUSR1.
if len(command) >= 2 && command[len(command)-1] == "-c" && len(args) == 1 {
shell := command[:len(command)-1] // e.g. ["/bin/sh"] — discarded
_ = shell
parts := strings.Fields(args[0])
command = parts[:1] // e.g. ["python3"]
args = parts[1:] // e.g. ["-m", "dynamo.vllm", "--model", ...]
}
wrappedArgs := make([]string, 0, len(command)+len(args)+1) wrappedArgs := make([]string, 0, len(command)+len(args)+1)
wrappedArgs = append(wrappedArgs, "--launch-job") wrappedArgs = append(wrappedArgs, "--launch-job")
wrappedArgs = append(wrappedArgs, command...) wrappedArgs = append(wrappedArgs, command...)
......
...@@ -63,11 +63,27 @@ func TestNewCheckpointJob(t *testing.T) { ...@@ -63,11 +63,27 @@ func TestNewCheckpointJob(t *testing.T) {
if job.Spec.Template.Annotations[CheckpointArtifactVersionAnnotation] != "2" { if job.Spec.Template.Annotations[CheckpointArtifactVersionAnnotation] != "2" {
t.Fatalf("expected checkpoint artifact version annotation on template: %#v", job.Spec.Template.Annotations) t.Fatalf("expected checkpoint artifact version annotation on template: %#v", job.Spec.Template.Annotations)
} }
if len(job.Spec.Template.Spec.Volumes) != 0 { if len(job.Spec.Template.Spec.Volumes) != 1 || job.Spec.Template.Spec.Volumes[0].Name != SnapshotControlVolumeName {
t.Fatalf("expected no checkpoint volume, got %#v", job.Spec.Template.Spec.Volumes) t.Fatalf("expected only %s volume, got %#v", SnapshotControlVolumeName, job.Spec.Template.Spec.Volumes)
} }
if len(job.Spec.Template.Spec.Containers[0].VolumeMounts) != 0 { main := &job.Spec.Template.Spec.Containers[0]
t.Fatalf("expected no checkpoint volume mount, got %#v", job.Spec.Template.Spec.Containers[0].VolumeMounts) if len(main.VolumeMounts) != 1 || main.VolumeMounts[0].MountPath != SnapshotControlMountPath {
t.Fatalf("expected only %s mount at %s, got %#v", SnapshotControlVolumeName, SnapshotControlMountPath, main.VolumeMounts)
}
if main.ReadinessProbe == nil || main.ReadinessProbe.Exec == nil {
t.Fatalf("expected ready-file readiness probe, got %#v", main.ReadinessProbe)
}
expectedProbe := []string{"cat", SnapshotControlMountPath + "/" + ReadyForCheckpointFile}
if len(main.ReadinessProbe.Exec.Command) != len(expectedProbe) {
t.Fatalf("expected readiness probe %#v, got %#v", expectedProbe, main.ReadinessProbe.Exec.Command)
}
for i := range expectedProbe {
if main.ReadinessProbe.Exec.Command[i] != expectedProbe[i] {
t.Fatalf("expected readiness probe %#v, got %#v", expectedProbe, main.ReadinessProbe.Exec.Command)
}
}
if main.LivenessProbe != nil || main.StartupProbe != nil {
t.Fatalf("expected liveness and startup probes cleared, got liveness=%#v startup=%#v", main.LivenessProbe, main.StartupProbe)
} }
if job.Spec.Template.Spec.RestartPolicy != corev1.RestartPolicyNever { if job.Spec.Template.Spec.RestartPolicy != corev1.RestartPolicyNever {
t.Fatalf("expected restartPolicy Never, got %#v", job.Spec.Template.Spec.RestartPolicy) t.Fatalf("expected restartPolicy Never, got %#v", job.Spec.Template.Spec.RestartPolicy)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment