Unverified Commit 38bb9d37 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: clean up checkpoint orchestration (#7309)


Signed-off-by: default avatarSchwinn Saereesitthipitak <schwinns@nvidia.com>
parent 9ea3acad
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# ============================================================================= # =============================================================================
ARG DOCKER_PROXY ARG DOCKER_PROXY
ARG GO_VERSION=1.25 ARG GO_VERSION=1.25
ARG CRIU_VERSION=v4.2 ARG CRIU_REPO=https://github.com/dfeigin-nv/criu.git
ARG CRIU_VERSION=add-aio-and-parallel-memfd
ARG AGENT_BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.11-cuda13.0-devel-ubuntu24.04 ARG AGENT_BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.11-cuda13.0-devel-ubuntu24.04
# For placeholder target only - this default allows agent builds to succeed, # For placeholder target only - this default allows agent builds to succeed,
...@@ -74,6 +75,7 @@ RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-w -s ...@@ -74,6 +75,7 @@ RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-w -s
# ============================================================================= # =============================================================================
FROM ubuntu:24.04 AS criu-builder FROM ubuntu:24.04 AS criu-builder
ARG CRIU_REPO
ARG CRIU_VERSION ARG CRIU_VERSION
RUN apt-get update && apt-get install -y --no-install-recommends \ RUN apt-get update && apt-get install -y --no-install-recommends \
...@@ -97,7 +99,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ...@@ -97,7 +99,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
uuid-dev \ uuid-dev \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
RUN git clone --branch ${CRIU_VERSION} https://github.com/checkpoint-restore/criu.git /tmp/criu \ RUN git clone --depth 1 --branch ${CRIU_VERSION} ${CRIU_REPO} /tmp/criu \
&& cd /tmp/criu \ && cd /tmp/criu \
&& make -j$(nproc) \ && make -j$(nproc) \
&& make DESTDIR=/criu-install install-criu install-lib install-cuda_plugin && make DESTDIR=/criu-install install-criu install-lib install-cuda_plugin
......
// Package main provides the snapshot DaemonSet agent. // Package main provides the snapshot-agent DaemonSet entrypoint.
// The agent watches for pods with checkpoint/restore labels on its node // The agent runs the node-local snapshot controller and delegates CRIU/CUDA
// and triggers operations via the orchestrators. // execution to the snapshot executor workflows.
package main package main
import ( import (
...@@ -13,8 +13,8 @@ import ( ...@@ -13,8 +13,8 @@ import (
"github.com/go-logr/logr" "github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/controller"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/watcher"
) )
func main() { func main() {
...@@ -43,37 +43,36 @@ func main() { ...@@ -43,37 +43,36 @@ func main() {
agentLog.Info("Starting snapshot agent", agentLog.Info("Starting snapshot agent",
"node", cfg.NodeName, "node", cfg.NodeName,
"checkpoint_dir", cfg.BasePath, "restricted_namespace", cfg.RestrictedNamespace,
"watch_namespace", cfg.RestrictedNamespace,
) )
podWatcher, err := watcher.NewWatcher(cfg, ctrd, rootLog.WithName("watcher")) nodeController, err := controller.NewNodeController(cfg, ctrd, rootLog.WithName("controller"))
if err != nil { if err != nil {
fatal(agentLog, err, "Failed to create pod watcher") fatal(agentLog, err, "Failed to create snapshot node controller")
} }
// Run watcher in the background // Run the node-local controller in the background.
watcherDone := make(chan error, 1) controllerDone := make(chan error, 1)
go func() { go func() {
agentLog.Info("Pod watcher started") agentLog.Info("Snapshot node controller started")
watcherDone <- podWatcher.Start(ctx) controllerDone <- nodeController.Run(ctx)
}() }()
// Wait for signal or watcher exit // Wait for signal or controller exit.
select { select {
case <-sigChan: case <-sigChan:
agentLog.Info("Shutting down") agentLog.Info("Shutting down")
cancel() cancel()
select { select {
case err := <-watcherDone: case err := <-controllerDone:
if err != nil { if err != nil {
agentLog.Error(err, "Pod watcher exited with error during shutdown") agentLog.Error(err, "Snapshot node controller exited with error during shutdown")
} }
default: default:
} }
case err := <-watcherDone: case err := <-controllerDone:
if err != nil { if err != nil {
fatal(agentLog, err, "Pod watcher exited with error") fatal(agentLog, err, "Snapshot node controller exited with error")
} }
} }
......
...@@ -8,8 +8,8 @@ import ( ...@@ -8,8 +8,8 @@ import (
"github.com/go-logr/logr" "github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/executor"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/orchestrate"
) )
func main() { func main() {
...@@ -25,13 +25,13 @@ func main() { ...@@ -25,13 +25,13 @@ func main() {
fatal(log, nil, "--checkpoint-path is required") fatal(log, nil, "--checkpoint-path is required")
} }
opts := orchestrate.RestoreOptions{ opts := executor.RestoreOptions{
CheckpointPath: *checkpointPath, CheckpointPath: *checkpointPath,
CUDADeviceMap: *cudaDeviceMap, CUDADeviceMap: *cudaDeviceMap,
CgroupRoot: *cgroupRoot, CgroupRoot: *cgroupRoot,
} }
restoredPID, err := orchestrate.RestoreInNamespace(context.Background(), opts, log) restoredPID, err := executor.RestoreInNamespace(context.Background(), opts, log)
if err != nil { if err != nil {
fatal(log, err, "restore failed") fatal(log, err, "restore failed")
} }
......
...@@ -77,11 +77,7 @@ func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions types.OverlayS ...@@ -77,11 +77,7 @@ func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions types.OverlayS
// buildExclusions merges exclusion lists and normalizes paths for tar --exclude patterns. // buildExclusions merges exclusion lists and normalizes paths for tar --exclude patterns.
func buildExclusions(s types.OverlaySettings) []string { func buildExclusions(s types.OverlaySettings) []string {
total := len(s.SystemDirs) + len(s.CacheDirs) + len(s.AdditionalExclusions) exclusions := append([]string(nil), s.Exclusions...)
exclusions := make([]string, 0, total)
exclusions = append(exclusions, s.SystemDirs...)
exclusions = append(exclusions, s.CacheDirs...)
exclusions = append(exclusions, s.AdditionalExclusions...)
for i, p := range exclusions { for i, p := range exclusions {
if strings.HasPrefix(p, "*") { if strings.HasPrefix(p, "*") {
continue continue
......
...@@ -18,11 +18,9 @@ func TestBuildExclusions(t *testing.T) { ...@@ -18,11 +18,9 @@ func TestBuildExclusions(t *testing.T) {
want map[string]bool // expected entries (true = must be present) want map[string]bool // expected entries (true = must be present)
}{ }{
{ {
name: "merges all lists and normalizes paths", name: "normalizes rooted paths",
settings: types.OverlaySettings{ settings: types.OverlaySettings{
SystemDirs: []string{"/proc", "/sys"}, Exclusions: []string{"/proc", "/sys", "/root/.cache", "/tmp"},
CacheDirs: []string{"/root/.cache"},
AdditionalExclusions: []string{"/tmp"},
}, },
want: map[string]bool{ want: map[string]bool{
"./proc": true, "./proc": true,
...@@ -34,7 +32,7 @@ func TestBuildExclusions(t *testing.T) { ...@@ -34,7 +32,7 @@ func TestBuildExclusions(t *testing.T) {
{ {
name: "strips leading dot and slash before prepending ./", name: "strips leading dot and slash before prepending ./",
settings: types.OverlaySettings{ settings: types.OverlaySettings{
SystemDirs: []string{"./proc", "/sys", "tmp"}, Exclusions: []string{"./proc", "/sys", "tmp"},
}, },
want: map[string]bool{ want: map[string]bool{
"./proc": true, "./proc": true,
...@@ -45,11 +43,13 @@ func TestBuildExclusions(t *testing.T) { ...@@ -45,11 +43,13 @@ func TestBuildExclusions(t *testing.T) {
{ {
name: "glob patterns starting with * are untouched", name: "glob patterns starting with * are untouched",
settings: types.OverlaySettings{ settings: types.OverlaySettings{
AdditionalExclusions: []string{"*.pyc", "*/__pycache__"}, Exclusions: []string{"*/.cache/huggingface", "*/.cache/vllm/torch_compile_cache", "*.pyc", "*/__pycache__"},
}, },
want: map[string]bool{ want: map[string]bool{
"*.pyc": true, "*/.cache/huggingface": true,
"*/__pycache__": true, "*/.cache/vllm/torch_compile_cache": true,
"*.pyc": true,
"*/__pycache__": true,
}, },
}, },
{ {
......
// Package watcher provides Kubernetes pod watching for automatic checkpoint/restore. // Package controller implements the node-local control loop inside snapshot-agent.
// The watcher is the sole entry point for snapshot operations — it detects pods with // It does not own CRDs or replace the operator. Instead it watches pod, job, and
// checkpoint/restore labels and calls the orchestrators directly. // lease state on the current node and delegates CRIU/CUDA execution to the
package watcher // snapshot executor workflows.
package controller
import ( import (
"context" "context"
...@@ -15,6 +16,8 @@ import ( ...@@ -15,6 +16,8 @@ import (
"github.com/containerd/containerd" "github.com/containerd/containerd"
"github.com/go-logr/logr" "github.com/go-logr/logr"
"github.com/google/uuid"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels" "k8s.io/apimachinery/pkg/labels"
...@@ -24,24 +27,29 @@ import ( ...@@ -24,24 +27,29 @@ import (
"k8s.io/client-go/tools/cache" "k8s.io/client-go/tools/cache"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/orchestrate" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/executor"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
) )
const ( const (
kubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source" kubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source"
kubeLabelCheckpointHash = "nvidia.com/snapshot-checkpoint-hash" kubeLabelCheckpointHash = "nvidia.com/snapshot-checkpoint-hash"
kubeLabelIsRestoreTarget = "nvidia.com/snapshot-is-restore-target" kubeLabelIsRestoreTarget = "nvidia.com/snapshot-is-restore-target"
kubeAnnotationCheckpointStatus = "nvidia.com/snapshot-checkpoint-status" kubeAnnotationCheckpointLocation = "nvidia.com/snapshot-checkpoint-location"
kubeAnnotationRestoreStatus = "nvidia.com/snapshot-restore-status" kubeAnnotationCheckpointStorageType = "nvidia.com/snapshot-checkpoint-storage-type"
kubeAnnotationCheckpointStatus = "nvidia.com/snapshot-checkpoint-status"
kubeAnnotationRestoreStatus = "nvidia.com/snapshot-restore-status"
kubeAnnotationRestoreContainerID = "nvidia.com/snapshot-restore-container-id"
) )
// Watcher watches for pods with checkpoint/restore labels and triggers operations. // NodeController watches local-node pods with checkpoint metadata and reconciles
type Watcher struct { // snapshot execution for checkpoint and restore requests.
type NodeController struct {
config *types.AgentConfig config *types.AgentConfig
clientset kubernetes.Interface clientset kubernetes.Interface
containerd *containerd.Client containerd *containerd.Client
log logr.Logger log logr.Logger
holderID string
inFlight map[string]struct{} inFlight map[string]struct{}
inFlightMu sync.Mutex inFlightMu sync.Mutex
...@@ -49,12 +57,12 @@ type Watcher struct { ...@@ -49,12 +57,12 @@ type Watcher struct {
stopCh chan struct{} stopCh chan struct{}
} }
// NewWatcher creates a new pod watcher. // NewNodeController creates the node-local controller that runs inside snapshot-agent.
func NewWatcher( func NewNodeController(
cfg *types.AgentConfig, cfg *types.AgentConfig,
containerd *containerd.Client, containerd *containerd.Client,
log logr.Logger, log logr.Logger,
) (*Watcher, error) { ) (*NodeController, error) {
restConfig, err := rest.InClusterConfig() restConfig, err := rest.InClusterConfig()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get in-cluster config: %w", err) return nil, fmt.Errorf("failed to get in-cluster config: %w", err)
...@@ -65,19 +73,20 @@ func NewWatcher( ...@@ -65,19 +73,20 @@ func NewWatcher(
return nil, fmt.Errorf("failed to create kubernetes client: %w", err) return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
} }
return &Watcher{ return &NodeController{
config: cfg, config: cfg,
clientset: clientset, clientset: clientset,
containerd: containerd, containerd: containerd,
log: log, log: log,
holderID: "snapshot-agent/" + uuid.NewString(),
inFlight: make(map[string]struct{}), inFlight: make(map[string]struct{}),
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
}, nil }, nil
} }
// Start begins watching for pods and processing checkpoint/restore events. // Run starts the local pod informers and processes checkpoint/restore events.
func (w *Watcher) Start(ctx context.Context) error { func (w *NodeController) Run(ctx context.Context) error {
w.log.Info("Starting pod watcher", w.log.Info("Starting snapshot node controller",
"node", w.config.NodeName, "node", w.config.NodeName,
"checkpoint", kubeLabelIsCheckpointSource, "checkpoint", kubeLabelIsCheckpointSource,
"restore", kubeLabelIsRestoreTarget, "restore", kubeLabelIsRestoreTarget,
...@@ -115,14 +124,14 @@ func (w *Watcher) Start(ctx context.Context) error { ...@@ -115,14 +124,14 @@ func (w *Watcher) Start(ctx context.Context) error {
if !ok { if !ok {
return return
} }
w.handleCheckpointPodEvent(ctx, pod) w.reconcileCheckpointPod(ctx, pod)
}, },
UpdateFunc: func(_, newObj interface{}) { UpdateFunc: func(_, newObj interface{}) {
pod, ok := podFromInformerObj(newObj) pod, ok := podFromInformerObj(newObj)
if !ok { if !ok {
return return
} }
w.handleCheckpointPodEvent(ctx, pod) w.reconcileCheckpointPod(ctx, pod)
}, },
}); err != nil { }); err != nil {
return fmt.Errorf("failed to add checkpoint informer handler: %w", err) return fmt.Errorf("failed to add checkpoint informer handler: %w", err)
...@@ -152,14 +161,14 @@ func (w *Watcher) Start(ctx context.Context) error { ...@@ -152,14 +161,14 @@ func (w *Watcher) Start(ctx context.Context) error {
if !ok { if !ok {
return return
} }
w.handleRestorePodEvent(ctx, pod) w.reconcileRestorePod(ctx, pod)
}, },
UpdateFunc: func(_, newObj interface{}) { UpdateFunc: func(_, newObj interface{}) {
pod, ok := podFromInformerObj(newObj) pod, ok := podFromInformerObj(newObj)
if !ok { if !ok {
return return
} }
w.handleRestorePodEvent(ctx, pod) w.reconcileRestorePod(ctx, pod)
}, },
}); err != nil { }); err != nil {
return fmt.Errorf("failed to add restore informer handler: %w", err) return fmt.Errorf("failed to add restore informer handler: %w", err)
...@@ -171,13 +180,13 @@ func (w *Watcher) Start(ctx context.Context) error { ...@@ -171,13 +180,13 @@ func (w *Watcher) Start(ctx context.Context) error {
return fmt.Errorf("failed to sync informer caches") return fmt.Errorf("failed to sync informer caches")
} }
w.log.Info("Pod watcher started and caches synced") w.log.Info("Snapshot node controller started and caches synced")
<-ctx.Done() <-ctx.Done()
close(w.stopCh) close(w.stopCh)
return nil return nil
} }
func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod) { func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1.Pod) {
if pod.Spec.NodeName != w.config.NodeName { if pod.Spec.NodeName != w.config.NodeName {
return return
} }
...@@ -193,8 +202,14 @@ func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod) ...@@ -193,8 +202,14 @@ func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod)
return return
} }
annotationStatus := pod.Annotations[kubeAnnotationCheckpointStatus] job, err := getCheckpointJob(ctx, w.clientset, pod)
if annotationStatus == "completed" || annotationStatus == "in_progress" { if err != nil {
w.log.Error(err, "Failed to resolve checkpoint job", "pod", podKey)
return
}
jobStatus := job.Annotations[kubeAnnotationCheckpointStatus]
if jobStatus == "completed" || jobStatus == "failed" {
return return
} }
...@@ -202,19 +217,37 @@ func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod) ...@@ -202,19 +217,37 @@ func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod)
return return
} }
checkpointLocation, checkpointStorageType, err := checkpointStorageFromPod(pod)
if err != nil {
w.release(podKey)
w.log.Error(err, "Checkpoint pod is missing storage metadata", "pod", podKey, "checkpoint_hash", checkpointHash)
return
}
acquiredLease, err := acquireCheckpointLease(ctx, w.clientset, w.log, job, w.holderID)
if err != nil {
w.release(podKey)
w.log.Error(err, "Failed to acquire checkpoint lease", "pod", podKey, "checkpoint_hash", checkpointHash)
return
}
if !acquiredLease {
w.release(podKey)
return
}
w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash) w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash)
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash)) emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash))
go func() { go func() {
if err := w.doCheckpoint(ctx, pod, checkpointHash, podKey); err != nil { if err := w.runCheckpoint(ctx, pod, job, checkpointHash, checkpointLocation, checkpointStorageType, podKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash) opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
opLog.Error(err, "Checkpoint worker failed") opLog.Error(err, "Checkpoint controller worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error()) emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error())
} }
}() }()
} }
func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) { func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Pod) {
if pod.Spec.NodeName != w.config.NodeName { if pod.Spec.NodeName != w.config.NodeName {
return return
} }
...@@ -225,17 +258,10 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) { ...@@ -225,17 +258,10 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
return return
} }
annotationStatus := pod.Annotations[kubeAnnotationRestoreStatus]
if isPodReady(pod) { if isPodReady(pod) {
return return
} }
// Restore failures require explicit intervention (new label/update) before retry.
if annotationStatus == "completed" || annotationStatus == "in_progress" || annotationStatus == "failed" {
return
}
checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash] checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash]
if !ok || checkpointHash == "" { if !ok || checkpointHash == "" {
w.log.Info("Restore pod has no checkpoint-hash label", "pod", podKey) w.log.Info("Restore pod has no checkpoint-hash label", "pod", podKey)
...@@ -247,13 +273,43 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) { ...@@ -247,13 +273,43 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
return return
} }
checkpointDir := filepath.Join(w.config.BasePath, checkpointHash) checkpointLocation, checkpointStorageType, err := checkpointStorageFromPod(pod)
if _, err := os.Stat(checkpointDir); os.IsNotExist(err) { if err != nil {
w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash) w.log.Error(err, "Restore pod is missing storage metadata", "pod", podKey, "checkpoint_hash", checkpointHash)
return
}
if _, err := os.Stat(checkpointLocation); os.IsNotExist(err) {
w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash, "checkpoint_location", checkpointLocation)
return return
} }
if !w.tryAcquire(podKey) { containerName := resolveMainContainerName(pod)
if containerName == "" {
w.log.Info("Restore pod has no containers", "pod", podKey)
return
}
containerID := ""
for _, cs := range pod.Status.ContainerStatuses {
if cs.Name != containerName || cs.ContainerID == "" {
continue
}
containerID = strings.TrimPrefix(cs.ContainerID, "containerd://")
break
}
if containerID == "" {
w.log.V(1).Info("Restore pod has no running main container yet", "pod", podKey, "container", containerName)
return
}
annotationStatus := pod.Annotations[kubeAnnotationRestoreStatus]
annotationContainerID := pod.Annotations[kubeAnnotationRestoreContainerID]
if annotationContainerID == containerID && (annotationStatus == "completed" || annotationStatus == "in_progress") {
return
}
restoreAttemptKey := fmt.Sprintf("%s/%s", podKey, containerID)
if !w.tryAcquire(restoreAttemptKey) {
return return
} }
...@@ -261,53 +317,56 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) { ...@@ -261,53 +317,56 @@ func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash)) emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash))
go func() { go func() {
if err := w.doRestore(ctx, pod, checkpointHash, podKey); err != nil { if err := w.runRestore(ctx, pod, containerName, containerID, checkpointHash, checkpointLocation, checkpointStorageType, restoreAttemptKey); err != nil {
opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash) opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
opLog.Error(err, "Restore worker failed") opLog.Error(err, "Restore controller worker failed")
emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error()) emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error())
} }
}() }()
} }
// doCheckpoint runs the full checkpoint workflow for a pod: // runCheckpoint runs the full checkpoint workflow for a pod:
// 1. Mark pod as in_progress // 1. Hold and renew the checkpoint lease
// 2. Resolve the container ID and host PID // 2. Resolve the container ID and host PID
// 3. Call orchestrate.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff) // 3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
// 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately) // 4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
// 5. Mark pod as completed or failed // 5. Mark job as completed or failed
func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error { func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointHash, checkpointLocation, checkpointStorageType, podKey string) error {
releaseOnExit := true releasePodOnExit := true
defer func() { defer func() {
if releaseOnExit { if releasePodOnExit {
w.release(podKey) w.release(podKey)
} }
}() }()
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash) log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
setCheckpointStatus := func(value string) error { leaseCtx, stopLease := context.WithCancelCause(ctx)
annotations := map[string]string{ defer stopLease(nil)
kubeAnnotationCheckpointStatus: value,
}
if value == "failed" || value == "completed" { releaseLeaseOnExit := true
if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil { defer func() {
releaseOnExit = false if !releaseLeaseOnExit {
return fmt.Errorf("failed to persist terminal checkpoint status %q: %w", value, err) return
}
return nil
} }
releaseCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := releaseCheckpointLease(releaseCtx, w.clientset, log, job, w.holderID); err != nil {
log.Error(err, "Failed to release checkpoint lease")
}
}()
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil { go w.renewCheckpointLease(leaseCtx, log, job, stopLease)
return fmt.Errorf("failed to update checkpoint status %q: %w", value, err)
setCheckpointStatus := func(value string) error {
if err := annotateJob(ctx, w.clientset, log, job, map[string]string{
kubeAnnotationCheckpointStatus: value,
}); err != nil {
releasePodOnExit = false
releaseLeaseOnExit = false
return fmt.Errorf("failed to persist terminal checkpoint status %q: %w", value, err)
} }
return nil return nil
} }
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
kubeAnnotationCheckpointStatus: "in_progress",
}); err != nil {
return fmt.Errorf("failed to annotate pod with checkpoint in_progress: %w", err)
}
// Resolve the target container // Resolve the target container
containerName := resolveMainContainerName(pod) containerName := resolveMainContainerName(pod)
if containerName == "" { if containerName == "" {
...@@ -346,16 +405,20 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH ...@@ -346,16 +405,20 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
} }
// Step 1: Run the checkpoint orchestrator // Step 1: Run the checkpoint orchestrator
req := orchestrate.CheckpointRequest{ req := executor.CheckpointRequest{
ContainerID: containerID, ContainerID: containerID,
ContainerName: containerName, ContainerName: containerName,
CheckpointHash: checkpointHash, CheckpointHash: checkpointHash,
CheckpointDir: w.config.BasePath, CheckpointLocation: checkpointLocation,
NodeName: w.config.NodeName, CheckpointStorageType: checkpointStorageType,
PodName: pod.Name, NodeName: w.config.NodeName,
PodNamespace: pod.Namespace, PodName: pod.Name,
} PodNamespace: pod.Namespace,
if err := orchestrate.Checkpoint(ctx, w.containerd, log, req, w.config); err != nil { }
if err := executor.Checkpoint(leaseCtx, w.containerd, log, req, w.config); err != nil {
if cause := context.Cause(leaseCtx); cause != nil && cause != context.Canceled {
err = fmt.Errorf("checkpoint lease lost: %w", cause)
}
log.Error(err, "Checkpoint failed") log.Error(err, "Checkpoint failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
// SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately // SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately
...@@ -368,6 +431,24 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH ...@@ -368,6 +431,24 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
return nil return nil
} }
info, err := os.Stat(checkpointLocation)
if err != nil || !info.IsDir() {
if err == nil {
err = fmt.Errorf("published checkpoint path %s is not a directory", checkpointLocation)
} else {
err = fmt.Errorf("published checkpoint path %s is missing: %w", checkpointLocation, err)
}
log.Error(err, "Checkpoint failed verification")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint verification failed"); signalErr != nil {
log.Error(signalErr, "Failed to signal checkpoint verification failure to runtime process")
}
if statusErr := setCheckpointStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed // Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointHash)) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointHash))
if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil { if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
...@@ -385,71 +466,65 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH ...@@ -385,71 +466,65 @@ func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointH
return nil return nil
} }
// doRestore runs the full restore workflow for a pod: // runRestore runs the full restore workflow for a pod:
// 1. Mark pod as in_progress // 1. Mark the current container instance as in_progress
// 2. Call orchestrate.Restore (inspect placeholder → nsrestore inside namespace) // 2. Call executor.Restore (inspect placeholder → nsrestore inside namespace)
// 3. SIGCONT the restored process to wake it up // 3. SIGCONT the restored process to wake it up
// 4. Wait for the pod to become Ready // 4. Wait for the pod to become Ready
// 5. Mark pod as completed or failed // 5. Mark the container instance as completed
func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error { func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointHash, checkpointLocation, checkpointStorageType, restoreAttemptKey string) error {
releaseOnExit := true releaseOnExit := true
defer func() { defer func() {
if releaseOnExit { if releaseOnExit {
w.release(podKey) w.release(restoreAttemptKey)
} }
}() }()
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash) podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash, "container_id", containerID)
setRestoreStatus := func(value string) error { setRestoreStatus := func(value string) error {
annotations := map[string]string{ annotations := map[string]string{
kubeAnnotationRestoreStatus: value, kubeAnnotationRestoreStatus: value,
kubeAnnotationRestoreContainerID: containerID,
} }
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
if value == "failed" || value == "completed" { if value == "completed" {
if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil {
releaseOnExit = false releaseOnExit = false
return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err) return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
} }
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
return fmt.Errorf("failed to update restore status %q: %w", value, err) return fmt.Errorf("failed to update restore status %q: %w", value, err)
} }
return nil return nil
} }
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{ if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
kubeAnnotationRestoreStatus: "in_progress", kubeAnnotationRestoreStatus: "in_progress",
kubeAnnotationRestoreContainerID: containerID,
}); err != nil { }); err != nil {
return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err) return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
} }
containerName := resolveMainContainerName(pod)
if containerName == "" {
err := fmt.Errorf("no containers found in pod spec")
log.Error(err, "Restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus("failed"); statusErr != nil {
return statusErr
}
return nil
}
// Step 1: Run the restore orchestrator (inspect + nsrestore) // Step 1: Run the restore orchestrator (inspect + nsrestore)
req := orchestrate.RestoreRequest{ req := executor.RestoreRequest{
CheckpointHash: checkpointHash, CheckpointHash: checkpointHash,
CheckpointBase: w.config.BasePath, CheckpointLocation: checkpointLocation,
NSRestorePath: w.config.Restore.NSRestorePath, CheckpointStorageType: checkpointStorageType,
PodName: pod.Name, NSRestorePath: w.config.Restore.NSRestorePath,
PodNamespace: pod.Namespace, PodName: pod.Name,
ContainerName: containerName, PodNamespace: pod.Namespace,
} ContainerName: containerName,
restoredPID, err := orchestrate.Restore(ctx, w.containerd, log, req) }
restoredPID, err := executor.Restore(ctx, w.containerd, log, req)
if err != nil { if err != nil {
log.Error(err, "External restore failed") log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus("failed"); statusErr != nil { placeholderHostPID, _, pidErr := common.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
return statusErr if pidErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
}
if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr)
} }
return nil return nil
} }
...@@ -459,18 +534,17 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash ...@@ -459,18 +534,17 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
if err != nil { if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling") log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus("failed"); statusErr != nil { releaseOnExit = false
return statusErr return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
}
return nil
} }
if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil { if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process") log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus("failed"); statusErr != nil { if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
return statusErr log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
} }
return nil releaseOnExit = false
return fmt.Errorf("failed to signal restored runtime process: %w", err)
} }
// Step 3: Wait for the pod to become Ready // Step 3: Wait for the pod to become Ready
...@@ -483,10 +557,11 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash ...@@ -483,10 +557,11 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil { if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed") log.Error(err, "Restore post-signal readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus("failed"); statusErr != nil { if killErr := common.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
return statusErr log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
} }
return nil releaseOnExit = false
return fmt.Errorf("restore post-signal readiness check failed: %w", err)
} }
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash)) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash))
...@@ -496,7 +571,7 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash ...@@ -496,7 +571,7 @@ func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash
return nil return nil
} }
func (w *Watcher) tryAcquire(podKey string) bool { func (w *NodeController) tryAcquire(podKey string) bool {
w.inFlightMu.Lock() w.inFlightMu.Lock()
defer w.inFlightMu.Unlock() defer w.inFlightMu.Unlock()
if _, held := w.inFlight[podKey]; held { if _, held := w.inFlight[podKey]; held {
...@@ -506,8 +581,25 @@ func (w *Watcher) tryAcquire(podKey string) bool { ...@@ -506,8 +581,25 @@ func (w *Watcher) tryAcquire(podKey string) bool {
return true return true
} }
func (w *Watcher) release(podKey string) { func (w *NodeController) release(podKey string) {
w.inFlightMu.Lock() w.inFlightMu.Lock()
defer w.inFlightMu.Unlock() defer w.inFlightMu.Unlock()
delete(w.inFlight, podKey) delete(w.inFlight, podKey)
} }
func checkpointStorageFromPod(pod *corev1.Pod) (string, string, error) {
checkpointLocation := strings.TrimSpace(pod.Annotations[kubeAnnotationCheckpointLocation])
if checkpointLocation == "" {
return "", "", fmt.Errorf("missing %s annotation", kubeAnnotationCheckpointLocation)
}
checkpointStorageType := strings.TrimSpace(pod.Annotations[kubeAnnotationCheckpointStorageType])
if checkpointStorageType == "" {
return "", "", fmt.Errorf("missing %s annotation", kubeAnnotationCheckpointStorageType)
}
if checkpointStorageType != "pvc" {
return "", "", fmt.Errorf("checkpoint storage type %q is not supported", checkpointStorageType)
}
return checkpointLocation, checkpointStorageType, nil
}
package watcher package controller
import ( import (
"context" "context"
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/go-logr/logr/testr" "github.com/go-logr/logr/testr"
batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
...@@ -19,24 +21,42 @@ import ( ...@@ -19,24 +21,42 @@ import (
) )
const testNodeName = "test-node" const testNodeName = "test-node"
const testContainerID = "test-container"
// makeTestWatcher creates a Watcher with a fake k8s client and nil orchestrators. // makeTestController creates a NodeController with a fake k8s client and nil executors.
// The fake clientset is empty so any goroutine launched by doCheckpoint/doRestore // The fake clientset is empty so any goroutine launched by runCheckpoint/runRestore
// will fail on the first annotatePod call and exit cleanly. // will fail on the first annotatePod call and exit cleanly.
func makeTestWatcher(t *testing.T) *Watcher { func makeTestController(t *testing.T, objs ...runtime.Object) *NodeController {
t.Helper() t.Helper()
return &Watcher{ return &NodeController{
config: &types.AgentConfig{ config: &types.AgentConfig{
NodeName: testNodeName, NodeName: testNodeName,
BasePath: t.TempDir(),
}, },
clientset: fake.NewClientset(), clientset: fake.NewClientset(objs...),
log: testr.New(t), log: testr.New(t),
holderID: "test-holder",
inFlight: make(map[string]struct{}), inFlight: make(map[string]struct{}),
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
} }
func makeLease(namespace, name, holder string, renewTime time.Time) *coordinationv1.Lease {
leaseDurationSeconds := int32(checkpointLeaseDuration.Seconds())
renewMicroTime := metav1.NewMicroTime(renewTime)
return &coordinationv1.Lease{
ObjectMeta: metav1.ObjectMeta{
Name: name,
Namespace: namespace,
},
Spec: coordinationv1.LeaseSpec{
HolderIdentity: &holder,
LeaseDurationSeconds: &leaseDurationSeconds,
AcquireTime: &renewMicroTime,
RenewTime: &renewMicroTime,
},
}
}
func makePod(name, namespace, nodeName string, phase corev1.PodPhase, ready bool, labels, annotations map[string]string) *corev1.Pod { func makePod(name, namespace, nodeName string, phase corev1.PodPhase, ready bool, labels, annotations map[string]string) *corev1.Pod {
var conditions []corev1.PodCondition var conditions []corev1.PodCondition
if ready { if ready {
...@@ -65,7 +85,7 @@ func makePod(name, namespace, nodeName string, phase corev1.PodPhase, ready bool ...@@ -65,7 +85,7 @@ func makePod(name, namespace, nodeName string, phase corev1.PodPhase, ready bool
} }
} }
func TestHandleCheckpointPodEvent(t *testing.T) { func TestReconcileCheckpointPod(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
nodeName string nodeName string
...@@ -73,6 +93,7 @@ func TestHandleCheckpointPodEvent(t *testing.T) { ...@@ -73,6 +93,7 @@ func TestHandleCheckpointPodEvent(t *testing.T) {
ready bool ready bool
hash string hash string
annotation string annotation string
lease *coordinationv1.Lease
preSeed bool // pre-populate inFlight to test deduplication preSeed bool // pre-populate inFlight to test deduplication
want bool // true = pod passes filtering and triggers checkpoint want bool // true = pod passes filtering and triggers checkpoint
}{ }{
...@@ -126,14 +147,32 @@ func TestHandleCheckpointPodEvent(t *testing.T) { ...@@ -126,14 +147,32 @@ func TestHandleCheckpointPodEvent(t *testing.T) {
want: false, want: false,
}, },
{ {
name: "already in progress", name: "already failed",
nodeName: testNodeName, nodeName: testNodeName,
phase: corev1.PodRunning, phase: corev1.PodRunning,
ready: true, ready: true,
hash: "abc123", hash: "abc123",
annotation: "in_progress", annotation: "failed",
want: false, want: false,
}, },
{
name: "active lease held elsewhere",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: true,
hash: "abc123",
lease: makeLease("default", "checkpoint-job", "other-holder", time.Now()),
want: false,
},
{
name: "expired lease can be reclaimed",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: true,
hash: "abc123",
lease: makeLease("default", "checkpoint-job", "other-holder", time.Now().Add(-checkpointLeaseDuration-time.Second)),
want: true,
},
{ {
name: "duplicate in-flight", name: "duplicate in-flight",
nodeName: testNodeName, nodeName: testNodeName,
...@@ -148,28 +187,46 @@ func TestHandleCheckpointPodEvent(t *testing.T) { ...@@ -148,28 +187,46 @@ func TestHandleCheckpointPodEvent(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
labels := map[string]string{ labels := map[string]string{
kubeLabelIsCheckpointSource: "true", kubeLabelIsCheckpointSource: "true",
"batch.kubernetes.io/job-name": "checkpoint-job",
} }
if tc.hash != "" { if tc.hash != "" {
labels[kubeLabelCheckpointHash] = tc.hash labels[kubeLabelCheckpointHash] = tc.hash
} }
var annotations map[string]string job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "checkpoint-job",
Namespace: "default",
},
}
if tc.annotation != "" { if tc.annotation != "" {
annotations = map[string]string{ job.Annotations = map[string]string{
kubeAnnotationCheckpointStatus: tc.annotation, kubeAnnotationCheckpointStatus: tc.annotation,
} }
} }
var annotations map[string]string
if tc.hash != "" {
annotations = map[string]string{
kubeAnnotationCheckpointLocation: "/checkpoints/" + tc.hash,
kubeAnnotationCheckpointStorageType: "pvc",
}
}
pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations) pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations)
w := makeTestWatcher(t) objs := []runtime.Object{job}
if tc.lease != nil {
objs = append(objs, tc.lease)
}
w := makeTestController(t, objs...)
ctx := context.Background() ctx := context.Background()
if tc.preSeed { if tc.preSeed {
w.inFlight["default/test-pod"] = struct{}{} w.inFlight["default/test-pod"] = struct{}{}
} }
w.handleCheckpointPodEvent(ctx, pod) w.reconcileCheckpointPod(ctx, pod)
// tryAcquire adds to inFlight synchronously before launching the goroutine. // tryAcquire adds to inFlight synchronously before launching the goroutine.
// For filtered pods, inFlight stays at its original size. // For filtered pods, inFlight stays at its original size.
...@@ -191,17 +248,18 @@ func TestHandleCheckpointPodEvent(t *testing.T) { ...@@ -191,17 +248,18 @@ func TestHandleCheckpointPodEvent(t *testing.T) {
} }
} }
func TestHandleRestorePodEvent(t *testing.T) { func TestReconcileRestorePod(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
nodeName string nodeName string
phase corev1.PodPhase phase corev1.PodPhase
ready bool ready bool
hash string hash string
annotation string annotationStatus string
createDir bool // whether to create the checkpoint dir on disk annotationContainerID string
preSeed bool createDir bool // whether to create the checkpoint dir on disk
want bool preSeed bool
want bool
}{ }{
{ {
name: "happy path", name: "happy path",
...@@ -257,34 +315,48 @@ func TestHandleRestorePodEvent(t *testing.T) { ...@@ -257,34 +315,48 @@ func TestHandleRestorePodEvent(t *testing.T) {
want: false, want: false,
}, },
{ {
name: "already completed", name: "already completed for same container",
nodeName: testNodeName, nodeName: testNodeName,
phase: corev1.PodRunning, phase: corev1.PodRunning,
ready: false, ready: false,
hash: "abc123", hash: "abc123",
annotation: "completed", annotationStatus: "completed",
createDir: true, annotationContainerID: testContainerID,
want: false, createDir: true,
want: false,
}, },
{ {
name: "already in progress", name: "already in progress for same container",
nodeName: testNodeName, nodeName: testNodeName,
phase: corev1.PodRunning, phase: corev1.PodRunning,
ready: false, ready: false,
hash: "abc123", hash: "abc123",
annotation: "in_progress", annotationStatus: "in_progress",
createDir: true, annotationContainerID: testContainerID,
want: false, createDir: true,
want: false,
}, },
{ {
name: "already failed", name: "completed for previous container retries",
nodeName: testNodeName, nodeName: testNodeName,
phase: corev1.PodRunning, phase: corev1.PodRunning,
ready: false, ready: false,
hash: "abc123", hash: "abc123",
annotation: "failed", annotationStatus: "completed",
createDir: true, annotationContainerID: "old-container",
want: false, createDir: true,
want: true,
},
{
name: "in progress for previous container retries",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
hash: "abc123",
annotationStatus: "in_progress",
annotationContainerID: "old-container",
createDir: true,
want: true,
}, },
{ {
name: "checkpoint not on disk", name: "checkpoint not on disk",
...@@ -316,18 +388,33 @@ func TestHandleRestorePodEvent(t *testing.T) { ...@@ -316,18 +388,33 @@ func TestHandleRestorePodEvent(t *testing.T) {
labels[kubeLabelCheckpointHash] = tc.hash labels[kubeLabelCheckpointHash] = tc.hash
} }
w := makeTestController(t)
checkpointDir := t.TempDir()
var annotations map[string]string var annotations map[string]string
if tc.annotation != "" { if tc.annotationStatus != "" {
annotations = map[string]string{ annotations = map[string]string{
kubeAnnotationRestoreStatus: tc.annotation, kubeAnnotationRestoreStatus: tc.annotationStatus,
kubeAnnotationRestoreContainerID: tc.annotationContainerID,
}
}
if tc.hash != "" {
if annotations == nil {
annotations = make(map[string]string)
} }
annotations[kubeAnnotationCheckpointLocation] = filepath.Join(checkpointDir, tc.hash)
annotations[kubeAnnotationCheckpointStorageType] = "pvc"
} }
pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations) pod := makePod("test-pod", "default", tc.nodeName, tc.phase, tc.ready, labels, annotations)
w := makeTestWatcher(t) pod.Status.ContainerStatuses = []corev1.ContainerStatus{{
Name: "main",
Ready: tc.ready,
ContainerID: "containerd://" + testContainerID,
}}
if tc.createDir && tc.hash != "" { if tc.createDir && tc.hash != "" {
dir := filepath.Join(w.config.BasePath, tc.hash) dir := filepath.Join(checkpointDir, tc.hash)
if err := os.MkdirAll(dir, 0o755); err != nil { if err := os.MkdirAll(dir, 0o755); err != nil {
t.Fatalf("failed to create checkpoint dir: %v", err) t.Fatalf("failed to create checkpoint dir: %v", err)
} }
...@@ -336,10 +423,10 @@ func TestHandleRestorePodEvent(t *testing.T) { ...@@ -336,10 +423,10 @@ func TestHandleRestorePodEvent(t *testing.T) {
ctx := context.Background() ctx := context.Background()
if tc.preSeed { if tc.preSeed {
w.inFlight["default/test-pod"] = struct{}{} w.inFlight["default/test-pod/"+testContainerID] = struct{}{}
} }
w.handleRestorePodEvent(ctx, pod) w.reconcileRestorePod(ctx, pod)
triggered := len(w.inFlight) > 0 && !tc.preSeed triggered := len(w.inFlight) > 0 && !tc.preSeed
if tc.preSeed { if tc.preSeed {
...@@ -358,91 +445,60 @@ func TestHandleRestorePodEvent(t *testing.T) { ...@@ -358,91 +445,60 @@ func TestHandleRestorePodEvent(t *testing.T) {
} }
} }
func TestDoCheckpointKeepsInFlightOnTerminalStatusPatchFailure(t *testing.T) { func TestRunCheckpointKeepsLeaseAndInFlightOnTerminalStatusPatchFailure(t *testing.T) {
pod := &corev1.Pod{ pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
Name: "test-pod", Name: "test-pod",
Namespace: "default", Namespace: "default",
Labels: map[string]string{
"batch.kubernetes.io/job-name": "checkpoint-job",
},
},
}
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "checkpoint-job",
Namespace: "default",
}, },
} }
lease := makeLease("default", "checkpoint-job", "test-holder", time.Now())
clientset := fake.NewClientset(pod.DeepCopy()) clientset := fake.NewClientset(pod.DeepCopy(), job, lease)
patchCalls := 0 patchCalls := 0
clientset.PrependReactor("patch", "pods", func(clientgotesting.Action) (bool, runtime.Object, error) { clientset.PrependReactor("patch", "jobs", func(clientgotesting.Action) (bool, runtime.Object, error) {
patchCalls++ patchCalls++
if patchCalls == 1 {
return false, nil, nil
}
return true, nil, errors.New("terminal patch failed") return true, nil, errors.New("terminal patch failed")
}) })
w := &Watcher{ w := &NodeController{
config: &types.AgentConfig{ config: &types.AgentConfig{
NodeName: testNodeName, NodeName: testNodeName,
BasePath: t.TempDir(),
}, },
clientset: clientset, clientset: clientset,
log: testr.New(t), log: testr.New(t),
holderID: "test-holder",
inFlight: map[string]struct{}{ inFlight: map[string]struct{}{
"default/test-pod": {}, "default/test-pod": {},
}, },
stopCh: make(chan struct{}), stopCh: make(chan struct{}),
} }
err := w.doCheckpoint(context.Background(), pod, "abc123", "default/test-pod") err := w.runCheckpoint(context.Background(), pod, job, "abc123", filepath.Join(t.TempDir(), "abc123"), "pvc", "default/test-pod")
if err == nil { if err == nil {
t.Fatal("expected terminal checkpoint status update to fail") t.Fatal("expected terminal checkpoint status update to fail")
} }
if _, ok := w.inFlight["default/test-pod"]; !ok { if _, ok := w.inFlight["default/test-pod"]; !ok {
t.Fatal("checkpoint terminal status failure should keep pod in-flight") t.Fatal("checkpoint terminal status failure should keep pod in-flight")
} }
if patchCalls != 1+terminalStatusPatchRetryAttempts { if patchCalls != 1 {
t.Fatalf("patchCalls = %d, want %d", patchCalls, 1+terminalStatusPatchRetryAttempts) t.Fatalf("patchCalls = %d, want %d", patchCalls, 1)
}
}
func TestDoRestoreKeepsInFlightOnTerminalStatusPatchFailure(t *testing.T) {
pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "test-pod",
Namespace: "default",
},
Status: corev1.PodStatus{
Phase: corev1.PodRunning,
},
}
clientset := fake.NewClientset(pod.DeepCopy())
patchCalls := 0
clientset.PrependReactor("patch", "pods", func(clientgotesting.Action) (bool, runtime.Object, error) {
patchCalls++
if patchCalls == 1 {
return false, nil, nil
}
return true, nil, errors.New("terminal patch failed")
})
w := &Watcher{
config: &types.AgentConfig{
NodeName: testNodeName,
BasePath: t.TempDir(),
},
clientset: clientset,
log: testr.New(t),
inFlight: map[string]struct{}{
"default/test-pod": {},
},
stopCh: make(chan struct{}),
} }
err := w.doRestore(context.Background(), pod, "abc123", "default/test-pod") remainingLease, err := clientset.CoordinationV1().Leases("default").Get(context.Background(), "checkpoint-job", metav1.GetOptions{})
if err == nil { if err != nil {
t.Fatal("expected terminal restore status update to fail") t.Fatalf("expected checkpoint lease to remain after terminal status patch failure: %v", err)
}
if _, ok := w.inFlight["default/test-pod"]; !ok {
t.Fatal("restore terminal status failure should keep pod in-flight")
} }
if patchCalls != 1+terminalStatusPatchRetryAttempts { if remainingLease.Spec.HolderIdentity == nil || *remainingLease.Spec.HolderIdentity != "test-holder" {
t.Fatalf("patchCalls = %d, want %d", patchCalls, 1+terminalStatusPatchRetryAttempts) t.Fatalf("unexpected remaining lease holder: %#v", remainingLease.Spec.HolderIdentity)
} }
} }
package watcher package controller
import ( import (
"context" "context"
...@@ -7,7 +7,10 @@ import ( ...@@ -7,7 +7,10 @@ import (
"time" "time"
"github.com/go-logr/logr" "github.com/go-logr/logr"
batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
ktypes "k8s.io/apimachinery/pkg/types" ktypes "k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
...@@ -15,8 +18,8 @@ import ( ...@@ -15,8 +18,8 @@ import (
) )
const ( const (
terminalStatusPatchRetryAttempts = 3 checkpointLeaseDuration = 30 * time.Second
terminalStatusPatchRetryDelay = 10 * time.Millisecond checkpointLeaseRenewInterval = 10 * time.Second
) )
func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) { func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) {
...@@ -78,30 +81,176 @@ func annotatePod(ctx context.Context, clientset kubernetes.Interface, log logr.L ...@@ -78,30 +81,176 @@ func annotatePod(ctx context.Context, clientset kubernetes.Interface, log logr.L
return err return err
} }
func annotatePodRetry(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, pod *corev1.Pod, annotations map[string]string) error { func getCheckpointJob(ctx context.Context, clientset kubernetes.Interface, pod *corev1.Pod) (*batchv1.Job, error) {
delay := terminalStatusPatchRetryDelay jobName := pod.Labels["batch.kubernetes.io/job-name"]
var lastErr error if jobName == "" {
return nil, fmt.Errorf("pod %s/%s has no batch.kubernetes.io/job-name label", pod.Namespace, pod.Name)
}
for attempt := 1; attempt <= terminalStatusPatchRetryAttempts; attempt++ { job, err := clientset.BatchV1().Jobs(pod.Namespace).Get(ctx, jobName, metav1.GetOptions{})
if err := annotatePod(ctx, clientset, log, pod, annotations); err == nil { if err != nil {
return nil return nil, fmt.Errorf("failed to get checkpoint job %s/%s: %w", pod.Namespace, jobName, err)
} else { }
lastErr = err return job, nil
}
func isLeaseExpired(lease *coordinationv1.Lease, now time.Time) bool {
if lease == nil || lease.Spec.LeaseDurationSeconds == nil {
return true
}
last := lease.Spec.RenewTime
if last == nil {
last = lease.Spec.AcquireTime
}
if last == nil {
return true
}
return now.After(last.Time.Add(time.Duration(*lease.Spec.LeaseDurationSeconds) * time.Second))
}
func acquireCheckpointLease(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, job *batchv1.Job, holderIdentity string) (bool, error) {
leaseName := job.Name
now := metav1.NewMicroTime(time.Now())
leaseDurationSeconds := int32(checkpointLeaseDuration.Seconds())
leaseClient := clientset.CoordinationV1().Leases(job.Namespace)
existingLease, err := leaseClient.Get(ctx, leaseName, metav1.GetOptions{})
if err != nil {
if !apierrors.IsNotFound(err) {
return false, fmt.Errorf("failed to get checkpoint lease %s/%s: %w", job.Namespace, leaseName, err)
}
lease := &coordinationv1.Lease{
ObjectMeta: metav1.ObjectMeta{
Name: leaseName,
Namespace: job.Namespace,
},
Spec: coordinationv1.LeaseSpec{
HolderIdentity: &holderIdentity,
LeaseDurationSeconds: &leaseDurationSeconds,
AcquireTime: &now,
RenewTime: &now,
},
}
if _, err := leaseClient.Create(ctx, lease, metav1.CreateOptions{}); err != nil {
if apierrors.IsAlreadyExists(err) {
return false, nil
}
return false, fmt.Errorf("failed to create checkpoint lease %s/%s: %w", job.Namespace, leaseName, err)
} }
return true, nil
}
if !isLeaseExpired(existingLease, now.Time) &&
existingLease.Spec.HolderIdentity != nil &&
*existingLease.Spec.HolderIdentity != holderIdentity {
return false, nil
}
existingLease.Spec.HolderIdentity = &holderIdentity
existingLease.Spec.LeaseDurationSeconds = &leaseDurationSeconds
if existingLease.Spec.AcquireTime == nil || isLeaseExpired(existingLease, now.Time) {
existingLease.Spec.AcquireTime = &now
}
existingLease.Spec.RenewTime = &now
if attempt == terminalStatusPatchRetryAttempts { if _, err := leaseClient.Update(ctx, existingLease, metav1.UpdateOptions{}); err != nil {
break if apierrors.IsConflict(err) {
log.V(1).Info("Checkpoint lease update conflicted", "lease", fmt.Sprintf("%s/%s", job.Namespace, leaseName))
return false, nil
} }
return false, fmt.Errorf("failed to update checkpoint lease %s/%s: %w", job.Namespace, leaseName, err)
}
return true, nil
}
func renewCheckpointLease(ctx context.Context, clientset kubernetes.Interface, job *batchv1.Job, holderIdentity string) error {
leaseName := job.Name
leaseClient := clientset.CoordinationV1().Leases(job.Namespace)
lease, err := leaseClient.Get(ctx, leaseName, metav1.GetOptions{})
if err != nil {
return fmt.Errorf("failed to get checkpoint lease %s/%s for renewal: %w", job.Namespace, leaseName, err)
}
if lease.Spec.HolderIdentity == nil || *lease.Spec.HolderIdentity != holderIdentity {
return fmt.Errorf("checkpoint lease %s/%s is no longer held by %q", job.Namespace, leaseName, holderIdentity)
}
now := metav1.NewMicroTime(time.Now())
leaseDurationSeconds := int32(checkpointLeaseDuration.Seconds())
lease.Spec.LeaseDurationSeconds = &leaseDurationSeconds
lease.Spec.RenewTime = &now
if _, err := leaseClient.Update(ctx, lease, metav1.UpdateOptions{}); err != nil {
return fmt.Errorf("failed to renew checkpoint lease %s/%s: %w", job.Namespace, leaseName, err)
}
return nil
}
func releaseCheckpointLease(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, job *batchv1.Job, holderIdentity string) error {
leaseName := job.Name
leaseClient := clientset.CoordinationV1().Leases(job.Namespace)
lease, err := leaseClient.Get(ctx, leaseName, metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
return nil
}
return fmt.Errorf("failed to get checkpoint lease %s/%s for release: %w", job.Namespace, leaseName, err)
}
if lease.Spec.HolderIdentity == nil || *lease.Spec.HolderIdentity != holderIdentity {
log.V(1).Info("Skipping checkpoint lease release because another holder owns it",
"lease", fmt.Sprintf("%s/%s", job.Namespace, leaseName),
"holder", holderIdentity,
)
return nil
}
if err := leaseClient.Delete(ctx, leaseName, metav1.DeleteOptions{}); err != nil && !apierrors.IsNotFound(err) {
return fmt.Errorf("failed to delete checkpoint lease %s/%s: %w", job.Namespace, leaseName, err)
}
return nil
}
func (w *NodeController) renewCheckpointLease(ctx context.Context, log logr.Logger, job *batchv1.Job, stopLease context.CancelCauseFunc) {
ticker := time.NewTicker(checkpointLeaseRenewInterval)
defer ticker.Stop()
for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return fmt.Errorf("pod annotation retry interrupted: %w", ctx.Err()) return
case <-time.After(delay): case <-ticker.C:
if err := renewCheckpointLease(ctx, w.clientset, job, w.holderID); err != nil {
log.Error(err, "Failed to renew checkpoint lease")
stopLease(fmt.Errorf("checkpoint lease renewal failed: %w", err))
return
}
} }
delay *= 2
} }
}
return fmt.Errorf("failed to annotate pod after %d attempts: %w", terminalStatusPatchRetryAttempts, lastErr) func annotateJob(ctx context.Context, clientset kubernetes.Interface, log logr.Logger, job *batchv1.Job, annotations map[string]string) error {
patchBytes, err := json.Marshal(map[string]any{
"metadata": map[string]any{
"annotations": annotations,
},
})
if err != nil {
return fmt.Errorf("failed to build job annotation patch payload: %w", err)
}
_, err = clientset.BatchV1().Jobs(job.Namespace).Patch(
ctx, job.Name, ktypes.MergePatchType, patchBytes, metav1.PatchOptions{},
)
if err != nil {
log.Error(err, "Failed to annotate checkpoint job",
"job", fmt.Sprintf("%s/%s", job.Namespace, job.Name),
"annotations", annotations,
)
}
return err
} }
func waitForPodReady(ctx context.Context, clientset kubernetes.Interface, namespace, podName, containerName string) error { func waitForPodReady(ctx context.Context, clientset kubernetes.Interface, namespace, podName, containerName string) error {
......
...@@ -15,7 +15,7 @@ import ( ...@@ -15,7 +15,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
) )
// RestoreLogFilename is the CRIU restore log filename (also used by orchestrate/restore.go). // RestoreLogFilename is the CRIU restore log filename (also used by executor/restore.go).
const RestoreLogFilename = "restore.log" const RestoreLogFilename = "restore.log"
const ( const (
......
...@@ -43,9 +43,14 @@ func shouldSetCgroupRoot(cgMode criurpc.CriuCgMode) bool { ...@@ -43,9 +43,14 @@ func shouldSetCgroupRoot(cgMode criurpc.CriuCgMode) bool {
// applyCommonSettings sets CRIU options shared between dump and restore. // applyCommonSettings sets CRIU options shared between dump and restore.
func applyCommonSettings(opts *criurpc.CriuOpts, settings *types.CRIUSettings) error { func applyCommonSettings(opts *criurpc.CriuOpts, settings *types.CRIUSettings) error {
if settings.TcpClose && settings.TcpEstablished {
return fmt.Errorf("tcpClose and tcpEstablished cannot both be true")
}
opts.LogLevel = proto.Int32(settings.LogLevel) opts.LogLevel = proto.Int32(settings.LogLevel)
opts.ShellJob = proto.Bool(settings.ShellJob) opts.ShellJob = proto.Bool(settings.ShellJob)
opts.TcpClose = proto.Bool(settings.TcpClose) opts.TcpClose = proto.Bool(settings.TcpClose)
opts.TcpEstablished = proto.Bool(settings.TcpEstablished)
opts.FileLocks = proto.Bool(settings.FileLocks) opts.FileLocks = proto.Bool(settings.FileLocks)
opts.ExtUnixSk = proto.Bool(settings.ExtUnixSk) opts.ExtUnixSk = proto.Bool(settings.ExtUnixSk)
opts.LinkRemap = proto.Bool(settings.LinkRemap) opts.LinkRemap = proto.Bool(settings.LinkRemap)
......
...@@ -53,7 +53,7 @@ func TestApplyCommonSettings(t *testing.T) { ...@@ -53,7 +53,7 @@ func TestApplyCommonSettings(t *testing.T) {
settings := &types.CRIUSettings{ settings := &types.CRIUSettings{
LogLevel: 4, LogLevel: 4,
ShellJob: true, ShellJob: true,
TcpClose: true, TcpEstablished: true,
FileLocks: true, FileLocks: true,
ExtUnixSk: true, ExtUnixSk: true,
LinkRemap: true, LinkRemap: true,
...@@ -70,8 +70,11 @@ func TestApplyCommonSettings(t *testing.T) { ...@@ -70,8 +70,11 @@ func TestApplyCommonSettings(t *testing.T) {
if !opts.GetShellJob() { if !opts.GetShellJob() {
t.Error("ShellJob should be true") t.Error("ShellJob should be true")
} }
if !opts.GetTcpClose() { if !opts.GetTcpEstablished() {
t.Error("TcpClose should be true") t.Error("TcpEstablished should be true")
}
if opts.GetTcpClose() {
t.Error("TcpClose should be false")
} }
if !opts.GetFileLocks() { if !opts.GetFileLocks() {
t.Error("FileLocks should be true") t.Error("FileLocks should be true")
...@@ -97,6 +100,17 @@ func TestApplyCommonSettings(t *testing.T) { ...@@ -97,6 +100,17 @@ func TestApplyCommonSettings(t *testing.T) {
t.Error("expected error for invalid ManageCgroupsMode") t.Error("expected error for invalid ManageCgroupsMode")
} }
}) })
t.Run("conflicting tcp settings return error", func(t *testing.T) {
opts := &criurpc.CriuOpts{}
settings := &types.CRIUSettings{
TcpClose: true,
TcpEstablished: true,
}
if err := applyCommonSettings(opts, settings); err == nil {
t.Error("expected error for conflicting tcp settings")
}
})
} }
func TestBuildRestoreExtMounts(t *testing.T) { func TestBuildRestoreExtMounts(t *testing.T) {
......
// Package orchestrate provides the top-level checkpoint and restore orchestrators. // Package executor provides the top-level checkpoint and restore executors.
// These wire together the lib packages (criu, cuda, etc.) into multi-step workflows. // These wire together the lib packages (criu, cuda, etc.) into multi-step workflows.
package orchestrate package executor
import ( import (
"context" "context"
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc" criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/containerd/containerd" "github.com/containerd/containerd"
"github.com/go-logr/logr" "github.com/go-logr/logr"
"github.com/google/uuid"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu" "github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu"
...@@ -21,33 +22,44 @@ import ( ...@@ -21,33 +22,44 @@ import (
// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation. // CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct { type CheckpointRequest struct {
ContainerID string ContainerID string
ContainerName string ContainerName string
CheckpointHash string CheckpointHash string
CheckpointDir string CheckpointLocation string
NodeName string CheckpointStorageType string
PodName string NodeName string
PodNamespace string PodName string
PodNamespace string
} }
// Checkpoint performs a CRIU dump of a container. // Checkpoint performs a CRIU dump of a container.
// The operation has three phases: inspect, configure, capture. // The operation has three phases: inspect, configure, capture.
// //
// The checkpoint directory is staged under tmp/<hash> during the operation. // The checkpoint directory is staged under tmp/<uuid> during the operation.
// On success, it is atomically renamed to <hash> at the base path root. // On success, the previous checkpoint is removed and the staged directory is
// renamed into place at the base path root.
func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error { func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error {
checkpointStart := time.Now() checkpointStart := time.Now()
log.Info("=== Starting checkpoint operation ===") log.Info("=== Starting checkpoint operation ===")
finalDir := filepath.Join(req.CheckpointDir, req.CheckpointHash) if req.CheckpointStorageType != "pvc" {
tmpDir := filepath.Join(req.CheckpointDir, "tmp", req.CheckpointHash) return fmt.Errorf("checkpoint storage type %q is not supported", req.CheckpointStorageType)
if err := os.RemoveAll(tmpDir); err != nil {
return fmt.Errorf("failed to clean checkpoint staging directory: %w", err)
} }
if err := os.MkdirAll(tmpDir, 0700); err != nil { if req.CheckpointLocation == "" {
return fmt.Errorf("failed to create checkpoint directory: %w", err) return fmt.Errorf("checkpoint location is required")
} }
finalDir := req.CheckpointLocation
tmpRoot := filepath.Join(filepath.Dir(finalDir), "tmp")
if err := os.MkdirAll(tmpRoot, 0700); err != nil {
return fmt.Errorf("failed to create checkpoint staging root: %w", err)
}
tmpDir := filepath.Join(tmpRoot, uuid.NewString())
if err := os.Mkdir(tmpDir, 0700); err != nil {
return fmt.Errorf("failed to create checkpoint staging directory: %w", err)
}
defer os.RemoveAll(tmpDir)
// Phase 1: Inspect container state // Phase 1: Inspect container state
state, err := inspectContainer(ctx, ctrd, log, req) state, err := inspectContainer(ctx, ctrd, log, req)
if err != nil { if err != nil {
...@@ -67,7 +79,9 @@ func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, r ...@@ -67,7 +79,9 @@ func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, r
} }
// Remove any previous checkpoint with the same identity hash before finalizing // Remove any previous checkpoint with the same identity hash before finalizing
os.RemoveAll(finalDir) if err := os.RemoveAll(finalDir); err != nil {
return fmt.Errorf("failed to remove previous checkpoint directory: %w", err)
}
if err := os.Rename(tmpDir, finalDir); err != nil { if err := os.Rename(tmpDir, finalDir); err != nil {
return fmt.Errorf("failed to finalize checkpoint directory: %w", err) return fmt.Errorf("failed to finalize checkpoint directory: %w", err)
} }
......
package orchestrate package executor
import ( import (
"context" "context"
......
package orchestrate package executor
import ( import (
"bytes" "bytes"
...@@ -24,12 +24,13 @@ import ( ...@@ -24,12 +24,13 @@ import (
// RestoreRequest holds the parameters for a restore operation. // RestoreRequest holds the parameters for a restore operation.
type RestoreRequest struct { type RestoreRequest struct {
CheckpointHash string CheckpointHash string
CheckpointBase string CheckpointLocation string
NSRestorePath string CheckpointStorageType string
PodName string NSRestorePath string
PodNamespace string PodName string
ContainerName string PodNamespace string
ContainerName string
} }
// Restore performs external restore for the given request. // Restore performs external restore for the given request.
...@@ -72,8 +73,15 @@ func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req ...@@ -72,8 +73,15 @@ func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req
} }
func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) { func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
checkpointPath := filepath.Join(req.CheckpointBase, req.CheckpointHash) if req.CheckpointStorageType != "pvc" {
baseAbs, err := filepath.Abs(req.CheckpointBase) return nil, fmt.Errorf("checkpoint storage type %q is not supported", req.CheckpointStorageType)
}
if req.CheckpointLocation == "" {
return nil, fmt.Errorf("checkpoint location is required")
}
checkpointPath := req.CheckpointLocation
baseAbs, err := filepath.Abs(filepath.Dir(checkpointPath))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to resolve checkpoint base path: %w", err) return nil, fmt.Errorf("failed to resolve checkpoint base path: %w", err)
} }
......
...@@ -4,7 +4,6 @@ package types ...@@ -4,7 +4,6 @@ package types
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
"time" "time"
) )
...@@ -13,7 +12,6 @@ import ( ...@@ -13,7 +12,6 @@ import (
type AgentConfig struct { type AgentConfig struct {
NodeName string `yaml:"-"` NodeName string `yaml:"-"`
RestrictedNamespace string `yaml:"-"` RestrictedNamespace string `yaml:"-"`
BasePath string `yaml:"basePath"`
Overlay OverlaySettings `yaml:"overlay"` Overlay OverlaySettings `yaml:"overlay"`
Restore RestoreSpec `yaml:"restore"` Restore RestoreSpec `yaml:"restore"`
CRIU CRIUSettings `yaml:"criu"` CRIU CRIUSettings `yaml:"criu"`
...@@ -29,8 +27,11 @@ func (c *AgentConfig) LoadEnvOverrides() { ...@@ -29,8 +27,11 @@ func (c *AgentConfig) LoadEnvOverrides() {
} }
func (c *AgentConfig) Validate() error { func (c *AgentConfig) Validate() error {
if strings.TrimSpace(c.BasePath) == "" { if c.CRIU.TcpClose && c.CRIU.TcpEstablished {
return &ConfigError{Field: "basePath", Message: "basePath is required"} return &ConfigError{
Field: "criu",
Message: "tcpClose and tcpEstablished cannot both be true",
}
} }
return c.Restore.Validate() return c.Restore.Validate()
} }
...@@ -65,6 +66,7 @@ type CRIUSettings struct { ...@@ -65,6 +66,7 @@ type CRIUSettings struct {
LeaveRunning bool `yaml:"leaveRunning"` LeaveRunning bool `yaml:"leaveRunning"`
ShellJob bool `yaml:"shellJob"` ShellJob bool `yaml:"shellJob"`
TcpClose bool `yaml:"tcpClose"` TcpClose bool `yaml:"tcpClose"`
TcpEstablished bool `yaml:"tcpEstablished"`
FileLocks bool `yaml:"fileLocks"` FileLocks bool `yaml:"fileLocks"`
OrphanPtsMaster bool `yaml:"orphanPtsMaster"` OrphanPtsMaster bool `yaml:"orphanPtsMaster"`
ExtUnixSk bool `yaml:"extUnixSk"` ExtUnixSk bool `yaml:"extUnixSk"`
...@@ -83,9 +85,7 @@ type CRIUSettings struct { ...@@ -83,9 +85,7 @@ type CRIUSettings struct {
// OverlaySettings is the static config for rootfs exclusions. // OverlaySettings is the static config for rootfs exclusions.
type OverlaySettings struct { type OverlaySettings struct {
SystemDirs []string `yaml:"systemDirs"` Exclusions []string `yaml:"exclusions"`
CacheDirs []string `yaml:"cacheDirs"`
AdditionalExclusions []string `yaml:"additionalExclusions"`
} }
// ConfigError represents a configuration validation error. // ConfigError represents a configuration validation error.
......
...@@ -24,7 +24,7 @@ func TestManifestRoundTrip(t *testing.T) { ...@@ -24,7 +24,7 @@ func TestManifestRoundTrip(t *testing.T) {
}, },
NewSourcePodManifest("ctr-abc", 42, "node-1", "my-pod", "default", []string{"pipe:[111]", "pipe:[222]", "pipe:[333]"}), NewSourcePodManifest("ctr-abc", 42, "node-1", "my-pod", "default", []string{"pipe:[111]", "pipe:[222]", "pipe:[333]"}),
OverlayManifest{ OverlayManifest{
Exclusions: OverlaySettings{SystemDirs: []string{"/proc", "/sys"}}, Exclusions: OverlaySettings{Exclusions: []string{"/proc", "/sys"}},
UpperDir: "/var/lib/containerd/upper", UpperDir: "/var/lib/containerd/upper",
ExternalPaths: []string{"/proc/acpi"}, ExternalPaths: []string{"/proc/acpi"},
BindMountDests: []string{"/data"}, BindMountDests: []string{"/data"},
......
...@@ -262,9 +262,10 @@ _Appears in:_ ...@@ -262,9 +262,10 @@ _Appears in:_
| Field | Description | Default | Validation | | Field | Description | Default | Validation |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| `podTemplateSpec` _[PodTemplateSpec](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#podtemplatespec-v1-core)_ | PodTemplateSpec allows customizing the checkpoint Job pod<br />This should include the container that runs the workload to be checkpointed | | Required: \{\} <br /> | | `podTemplateSpec` _[PodTemplateSpec](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#podtemplatespec-v1-core)_ | PodTemplateSpec allows customizing the checkpoint Job pod<br />This should include the container that runs the workload to be checkpointed | | Required: \{\} <br /> |
| `activeDeadlineSeconds` _integer_ | ActiveDeadlineSeconds specifies the maximum time the Job can run | 3600 | Optional: \{\} <br /> | | `sharedMemory` _[SharedMemorySpec](#sharedmemoryspec)_ | SharedMemory controls the tmpfs mounted at /dev/shm for the checkpoint Job pod.<br />When omitted, checkpoint Jobs use the same default 8Gi tmpfs as Dynamo components. | | Optional: \{\} <br /> |
| `backoffLimit` _integer_ | BackoffLimit specifies the number of retries before marking the Job failed | 3 | Optional: \{\} <br /> | | `activeDeadlineSeconds` _integer_ | ActiveDeadlineSeconds specifies the maximum time the Job can run | 3600 | Minimum: 1 <br />Optional: \{\} <br /> |
| `ttlSecondsAfterFinished` _integer_ | TTLSecondsAfterFinished specifies how long to keep the Job after completion | 300 | Optional: \{\} <br /> | | `backoffLimit` _integer_ | Deprecated: BackoffLimit is ignored. Checkpoint Jobs never retry. | | Minimum: 0 <br />Optional: \{\} <br /> |
| `ttlSecondsAfterFinished` _integer_ | TTLSecondsAfterFinished specifies how long to keep the Job after completion | 300 | Minimum: 0 <br />Optional: \{\} <br /> |
#### DynamoCheckpointPhase #### DynamoCheckpointPhase
...@@ -324,7 +325,7 @@ _Appears in:_ ...@@ -324,7 +325,7 @@ _Appears in:_
| `jobName` _string_ | JobName is the name of the checkpoint creation Job | | Optional: \{\} <br /> | | `jobName` _string_ | JobName is the name of the checkpoint creation Job | | Optional: \{\} <br /> |
| `createdAt` _[Time](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#time-v1-meta)_ | CreatedAt is the timestamp when the checkpoint tar was created | | Optional: \{\} <br /> | | `createdAt` _[Time](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#time-v1-meta)_ | CreatedAt is the timestamp when the checkpoint tar was created | | Optional: \{\} <br /> |
| `message` _string_ | Message provides additional information about the current state | | Optional: \{\} <br /> | | `message` _string_ | Message provides additional information about the current state | | Optional: \{\} <br /> |
| `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#condition-v1-meta) array_ | Conditions represent the latest available observations of the checkpoint's state | | Optional: \{\} <br /> | | `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.28/#condition-v1-meta) array_ | DEPRECATED: Conditions are deprecated. Use status.phase instead. | | Optional: \{\} <br /> |
#### DynamoCheckpointStorageType #### DynamoCheckpointStorageType
...@@ -1155,7 +1156,7 @@ _Appears in:_ ...@@ -1155,7 +1156,7 @@ _Appears in:_
| --- | --- | --- | --- | | --- | --- | --- | --- |
| `enabled` _boolean_ | Enabled indicates whether checkpointing is enabled for this service | false | Optional: \{\} <br /> | | `enabled` _boolean_ | Enabled indicates whether checkpointing is enabled for this service | false | Optional: \{\} <br /> |
| `mode` _[CheckpointMode](#checkpointmode)_ | Mode defines how checkpoint creation is handled<br />- Auto: DGD controller creates Checkpoint CR automatically<br />- Manual: User must create Checkpoint CR | Auto | Enum: [Auto Manual] <br />Optional: \{\} <br /> | | `mode` _[CheckpointMode](#checkpointmode)_ | Mode defines how checkpoint creation is handled<br />- Auto: DGD controller creates Checkpoint CR automatically<br />- Manual: User must create Checkpoint CR | Auto | Enum: [Auto Manual] <br />Optional: \{\} <br /> |
| `checkpointRef` _string_ | CheckpointRef references an existing Checkpoint CR to use<br />If specified, Identity is ignored and this checkpoint is used directly | | Optional: \{\} <br /> | | `checkpointRef` _string_ | CheckpointRef references an existing DynamoCheckpoint CR by metadata.name.<br />If specified, this service's Identity is ignored and the referenced checkpoint is used directly. | | Optional: \{\} <br /> |
| `identity` _[DynamoCheckpointIdentity](#dynamocheckpointidentity)_ | Identity defines the checkpoint identity for hash computation<br />Used when Mode is Auto or when looking up existing checkpoints<br />Required when checkpointRef is not specified | | Optional: \{\} <br /> | | `identity` _[DynamoCheckpointIdentity](#dynamocheckpointidentity)_ | Identity defines the checkpoint identity for hash computation<br />Used when Mode is Auto or when looking up existing checkpoints<br />Required when checkpointRef is not specified | | Optional: \{\} <br /> |
...@@ -1174,7 +1175,7 @@ _Appears in:_ ...@@ -1174,7 +1175,7 @@ _Appears in:_
| --- | --- | --- | --- | | --- | --- | --- | --- |
| `checkpointName` _string_ | CheckpointName is the name of the associated Checkpoint CR | | Optional: \{\} <br /> | | `checkpointName` _string_ | CheckpointName is the name of the associated Checkpoint CR | | Optional: \{\} <br /> |
| `identityHash` _string_ | IdentityHash is the computed hash of the checkpoint identity | | Optional: \{\} <br /> | | `identityHash` _string_ | IdentityHash is the computed hash of the checkpoint identity | | Optional: \{\} <br /> |
| `ready` _boolean_ | Ready indicates if the checkpoint is ready for use | | Optional: \{\} <br /> | | `ready` _boolean_ | Ready indicates if the checkpoint was visible to the worker at startup | | Optional: \{\} <br /> |
#### ServiceReplicaStatus #### ServiceReplicaStatus
...@@ -1208,6 +1209,7 @@ _Appears in:_ ...@@ -1208,6 +1209,7 @@ _Appears in:_
_Appears in:_ _Appears in:_
- [DynamoCheckpointJobConfig](#dynamocheckpointjobconfig)
- [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec) - [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec)
- [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec) - [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec)
...@@ -2349,17 +2351,6 @@ These are injected into all components when the corresponding infrastructure ser ...@@ -2349,17 +2351,6 @@ These are injected into all components when the corresponding infrastructure ser
| --- | --- | --- | --- | --- | | --- | --- | --- | --- | --- |
| `OMPI_MCA_orte_keep_fqdn_hostnames` | Instructs OpenMPI to preserve FQDN hostnames for inter-node communication | `1` | `string` | Multinode deployments only | | `OMPI_MCA_orte_keep_fqdn_hostnames` | Instructs OpenMPI to preserve FQDN hostnames for inter-node communication | `1` | `string` | Multinode deployments only |
### Checkpoint / Restore
These environment variables are injected when checkpoint/restore is enabled for a component.
| Variable | Purpose | Default | Type | Condition |
| --- | --- | --- | --- | --- |
| `DYN_CHECKPOINT_PATH` | Base directory where checkpoint data is stored | From operator checkpoint config `storage.pvc.basePath` | `string` | PVC storage type |
| `DYN_CHECKPOINT_LOCATION` | Full checkpoint URI (for non-PVC backends) | — | `string` | S3 or OCI storage type |
| `DYN_CHECKPOINT_HASH` | Identity hash that uniquely identifies the checkpoint | — | `string` | Always set when checkpoint is enabled |
| `SKIP_WAIT_FOR_CHECKPOINT` | Skips the checkpoint readiness polling loop; checks once and proceeds | — | `string` | Set on restored and DGD pods |
## Service Accounts ## Service Accounts
The following component types automatically receive dedicated service accounts: The following component types automatically receive dedicated service accounts:
......
...@@ -11,7 +11,7 @@ title: Snapshot ...@@ -11,7 +11,7 @@ title: Snapshot
| Startup Type | Time | What Happens | | Startup Type | Time | What Happens |
|--------------|------|--------------| |--------------|------|--------------|
| **Cold Start** | ~1 min | Download model, load to GPU, initialize engine | | **Cold Start** | ~1 min | Download model, load to GPU, initialize engine |
| **Warm Start** (restore from checkpoint) | ~ 10 sec | Restore from checkpoint tar | | **Warm Start** (restore from checkpoint) | ~ 10 sec | Restore from a ready checkpoint directory |
> ⚠️ Restore time may vary depending on cluster configuration (storage bandwidth, GPU model, etc.) > ⚠️ Restore time may vary depending on cluster configuration (storage bandwidth, GPU model, etc.)
...@@ -146,34 +146,13 @@ spec: ...@@ -146,34 +146,13 @@ spec:
args: args:
- --model - --model
- Qwen/Qwen3-0.6B - Qwen/Qwen3-0.6B
- --disable-custom-all-reduce
env: env:
- name: GLOO_SOCKET_IFNAME
value: lo
- name: NCCL_SOCKET_IFNAME
value: lo
- name: NCCL_DEBUG - name: NCCL_DEBUG
value: ERROR value: ERROR
- name: TORCH_CPP_LOG_LEVEL - name: TORCH_CPP_LOG_LEVEL
value: ERROR value: ERROR
- name: TORCH_DISTRIBUTED_DEBUG - name: TORCH_DISTRIBUTED_DEBUG
value: "OFF" value: "OFF"
- name: CUDA_ERROR_LEVEL
value: "10"
- name: NCCL_CUMEM_ENABLE
value: "0"
- name: NCCL_CUMEM_HOST_ENABLE
value: "0"
- name: NCCL_NVLS_ENABLE
value: "0"
- name: NCCL_P2P_DISABLE
value: "0"
- name: NCCL_SHM_DISABLE
value: "1"
- name: NCCL_IB_DISABLE
value: "1"
- name: TORCH_NCCL_ENABLE_MONITORING
value: "0"
``` ```
For SGLang, use `dynamo.sglang`, an SGLang placeholder image, `backendFramework: sglang`, and the matching CLI flags. For SGLang, use `dynamo.sglang`, an SGLang placeholder image, `backendFramework: sglang`, and the matching CLI flags.
...@@ -184,24 +163,26 @@ Apply the manifest: ...@@ -184,24 +163,26 @@ Apply the manifest:
kubectl apply -f vllm-snapshot-demo.yaml -n ${NAMESPACE} kubectl apply -f vllm-snapshot-demo.yaml -n ${NAMESPACE}
``` ```
On the first rollout, the worker cold-starts, the operator creates a `DynamoCheckpoint`, and the checkpoint Job writes data into `snapshot-pvc`. On the first rollout, the worker cold-starts, the operator resolves the checkpoint identity hash, and the checkpoint Job writes a new checkpoint directory into `snapshot-pvc`.
### 5. Wait for the checkpoint to become ready ### 5. Wait for the checkpoint to become ready
Capture the checkpoint name from DGD status, then wait for the `DynamoCheckpoint` phase to become `Ready`: Auto mode resolves checkpoints by identity hash. It may create `checkpoint-<hash>` or reuse an existing checkpoint with a different CR name. For the sample identity above, the hash is `73e74442beb109ed`:
```bash ```bash
CHECKPOINT_NAME=$(kubectl get dgd vllm-snapshot-demo -n ${NAMESPACE} \ kubectl get dckpt -n ${NAMESPACE}
-o jsonpath='{.status.checkpoints.VllmDecodeWorker.checkpointName}')
CKPT_NAME=$(kubectl get dckpt -n ${NAMESPACE} \
-l nvidia.com/snapshot-checkpoint-hash=73e74442beb109ed \
-o jsonpath='{.items[0].metadata.name}')
kubectl wait \ kubectl wait \
--for=jsonpath='{.status.phase}'=Ready \ --for=jsonpath='{.status.phase}'=Ready \
"dynamocheckpoint/${CHECKPOINT_NAME}" \ "dynamocheckpoint/${CKPT_NAME}" \
-n ${NAMESPACE} \ -n ${NAMESPACE} \
--timeout=30m --timeout=5m
``` ```
The DGD status also reports the computed checkpoint hash at `.status.checkpoints.VllmDecodeWorker.identityHash`. If you change the checkpoint identity, the hash changes and so does the checkpoint selected by Auto mode.
### 6. Trigger restore ### 6. Trigger restore
...@@ -218,7 +199,7 @@ New worker pods for `VllmDecodeWorker` will restore from the ready checkpoint au ...@@ -218,7 +199,7 @@ New worker pods for `VllmDecodeWorker` will restore from the ready checkpoint au
### Auto Mode (Recommended) ### Auto Mode (Recommended)
The operator computes the checkpoint identity hash, looks for an existing `DynamoCheckpoint` with a matching `nvidia.com/snapshot-checkpoint-hash` label, and creates one if it does not find one: The operator computes the checkpoint identity hash, looks up an existing `DynamoCheckpoint` by that hash, and creates a new `DynamoCheckpoint` only when no matching checkpoint already exists:
```yaml ```yaml
checkpoint: checkpoint:
...@@ -232,7 +213,12 @@ checkpoint: ...@@ -232,7 +213,12 @@ checkpoint:
maxModelLen: 4096 maxModelLen: 4096
``` ```
When a service uses checkpointing, DGD status reports the resolved `checkpointName`, `identityHash`, and `ready` fields under `.status.checkpoints.<service-name>`. The `DynamoGraphDeployment` mirrors checkpoint resolution state under `.status.checkpoints`, including the resolved checkpoint CR name, identity hash, and whether the checkpoint was visible to the worker when it started:
```bash
kubectl get dgd vllm-snapshot-demo -n ${NAMESPACE} \
-o jsonpath='{.status.checkpoints.VllmDecodeWorker.checkpointName}{"\n"}{.status.checkpoints.VllmDecodeWorker.identityHash}{"\n"}'
```
### Manual Management and `checkpointRef` ### Manual Management and `checkpointRef`
...@@ -241,26 +227,26 @@ Use `checkpointRef` when you want a service to restore from a specific `DynamoCh ...@@ -241,26 +227,26 @@ Use `checkpointRef` when you want a service to restore from a specific `DynamoCh
```yaml ```yaml
checkpoint: checkpoint:
enabled: true enabled: true
checkpointRef: "qwen3-06b-vllm-prewarm" checkpointRef: "qwen3-06b-bf16"
``` ```
This is useful when: This is useful when:
- You want to **pre-warm checkpoints** before creating DGDs - You want to **pre-warm checkpoints** before creating DGDs
- You want **explicit control** over which checkpoint to use - You want **explicit control** over which checkpoint to use
`checkpointRef` resolves by `DynamoCheckpoint.metadata.name`, not by `status.identityHash`. A manual checkpoint can use any valid Kubernetes resource name. `checkpointRef` resolves by `DynamoCheckpoint.metadata.name`. Use a readable CR name when you want an explicit checkpoint that operators can reference directly.
If you are managing checkpoint CRs yourself, set `mode: Manual` on the service to prevent the operator from creating a new `DynamoCheckpoint` when identity-based lookup does not find one. If you are managing checkpoint CRs yourself, set `mode: Manual` on the service to prevent the operator from creating a new `DynamoCheckpoint` when identity-based lookup does not find one.
```bash ```bash
# Check checkpoint status by CR name # Check checkpoint status by CR name
kubectl get dynamocheckpoint qwen3-06b-vllm-prewarm -n ${NAMESPACE} kubectl get dynamocheckpoint qwen3-06b-bf16 -n ${NAMESPACE}
# Now create DGD referencing it # Now create DGD referencing it
kubectl apply -f my-dgd.yaml -n ${NAMESPACE} kubectl apply -f my-dgd.yaml -n ${NAMESPACE}
``` ```
If you want `mode: Auto` DGDs to discover a manually created checkpoint by identity, add the label `nvidia.com/snapshot-checkpoint-hash=<identity-hash>` to that `DynamoCheckpoint`. Auto-created checkpoints already use that label, and currently use the same hash as the CR name. `mode: Auto` still resolves checkpoints by identity hash. The operator backfills `status.identityHash` and the `nvidia.com/snapshot-checkpoint-hash` label on each `DynamoCheckpoint` so auto lookup and uniqueness checks do not depend on the CR name.
## Checkpoint Identity ## Checkpoint Identity
...@@ -309,7 +295,8 @@ The `DynamoCheckpoint` (shortname: `dckpt`) is a Kubernetes Custom Resource that ...@@ -309,7 +295,8 @@ The `DynamoCheckpoint` (shortname: `dckpt`) is a Kubernetes Custom Resource that
- **Pre-warming:** Create checkpoints before deploying DGDs for instant startup - **Pre-warming:** Create checkpoints before deploying DGDs for instant startup
- **Explicit control:** Manage checkpoint lifecycle independently from DGDs - **Explicit control:** Manage checkpoint lifecycle independently from DGDs
The operator requires `spec.identity` and `spec.job.podTemplateSpec`. The pod template should match the worker container you want checkpointed, including image, command, args, secrets, volumes, and resource limits. You do not need to set the checkpoint environment variables manually; the operator injects them for checkpoint jobs and restored pods. The operator requires `spec.identity` and `spec.job.podTemplateSpec`. The pod template should match the worker container you want checkpointed, including image, command, args, secrets, volumes, and resource limits. You do not need to set checkpoint-control plumbing manually; the operator injects the checkpoint-ready signal path for checkpoint Jobs and adds the restore metadata consumed by restored pods and the node-local controller inside the `snapshot-agent` DaemonSet.
`spec.job.backoffLimit` is deprecated and ignored. Checkpoint Jobs are always single-attempt.
**Create a checkpoint:** **Create a checkpoint:**
...@@ -317,9 +304,7 @@ The operator requires `spec.identity` and `spec.job.podTemplateSpec`. The pod te ...@@ -317,9 +304,7 @@ The operator requires `spec.identity` and `spec.job.podTemplateSpec`. The pod te
apiVersion: nvidia.com/v1alpha1 apiVersion: nvidia.com/v1alpha1
kind: DynamoCheckpoint kind: DynamoCheckpoint
metadata: metadata:
name: qwen3-06b-vllm-prewarm name: qwen3-06b-bf16
labels:
nvidia.com/snapshot-checkpoint-hash: "e5962d34ba272638" # Add this if Auto-mode identity lookup should find the CR
spec: spec:
identity: identity:
model: Qwen/Qwen3-0.6B model: Qwen/Qwen3-0.6B
...@@ -330,7 +315,6 @@ spec: ...@@ -330,7 +315,6 @@ spec:
job: job:
activeDeadlineSeconds: 3600 activeDeadlineSeconds: 3600
backoffLimit: 3
ttlSecondsAfterFinished: 300 ttlSecondsAfterFinished: 300
podTemplateSpec: podTemplateSpec:
spec: spec:
...@@ -345,18 +329,19 @@ spec: ...@@ -345,18 +329,19 @@ spec:
args: args:
- --model - --model
- Qwen/Qwen3-0.6B - Qwen/Qwen3-0.6B
- --disable-custom-all-reduce
env: env:
- name: GLOO_SOCKET_IFNAME - name: NCCL_DEBUG
value: lo value: ERROR
- name: NCCL_SOCKET_IFNAME - name: TORCH_CPP_LOG_LEVEL
value: lo value: ERROR
- name: TORCH_DISTRIBUTED_DEBUG
value: "OFF"
resources: resources:
limits: limits:
nvidia.com/gpu: "1" nvidia.com/gpu: "1"
``` ```
You can name the CR however you want if you plan to use `checkpointRef`. If you want `mode: Auto` identity lookup to find a manual CR, set the `nvidia.com/snapshot-checkpoint-hash` label to the computed 16-character identity hash. Using the hash as the CR name is a convenient convention, but it is not required. For this example identity, the operator computes a deterministic identity hash and stores it in `status.identityHash`. Auto mode uses that hash, not the CR name, when it decides whether to reuse or create a checkpoint.
**Check status:** **Check status:**
...@@ -366,9 +351,9 @@ kubectl get dynamocheckpoint -n ${NAMESPACE} ...@@ -366,9 +351,9 @@ kubectl get dynamocheckpoint -n ${NAMESPACE}
# Or use shortname # Or use shortname
kubectl get dckpt -n ${NAMESPACE} kubectl get dckpt -n ${NAMESPACE}
NAME MODEL BACKEND PHASE HASH AGE NAME MODEL BACKEND PHASE HASH AGE
qwen3-06b-vllm-prewarm Qwen/Qwen3-0.6B vllm Ready e5962d34ba272638 5m qwen3-06b-bf16 Qwen/Qwen3-0.6B vllm Ready 3bff874d069f0ed5 5m
llama3-8b-vllm-prewarm meta-llama/Llama-3-8B vllm Creating 7ab4f89c12de3456 2m llama3-8b-bf16 meta-llama/Meta-Llama-3-8B-Instruct vllm Creating 9be4f5574b5a285d 2m
``` ```
**Phases:** **Phases:**
...@@ -380,45 +365,33 @@ llama3-8b-vllm-prewarm meta-llama/Llama-3-8B vllm Creating 7ab4f89c12de ...@@ -380,45 +365,33 @@ llama3-8b-vllm-prewarm meta-llama/Llama-3-8B vllm Creating 7ab4f89c12de
| `Ready` | Checkpoint available for use | | `Ready` | Checkpoint available for use |
| `Failed` | Checkpoint creation failed | | `Failed` | Checkpoint creation failed |
`Ready` is a value in `status.phase`, not a Kubernetes condition. The `conditions` array tracks job lifecycle events:
| Condition Type | Meaning |
|----------------|---------|
| `JobCreated` | The checkpoint Job has been created |
| `JobCompleted` | The checkpoint Job has completed successfully or failed |
Other useful status fields are: Other useful status fields are:
| Field | Meaning | | Field | Meaning |
|-------|---------| |-------|---------|
| `status.identityHash` | Deterministic hash of `spec.identity` used for auto lookup and reuse |
| `status.jobName` | Name of the checkpoint Job | | `status.jobName` | Name of the checkpoint Job |
| `status.identityHash` | Computed 16-character hash for the checkpoint identity |
| `status.location` | Checkpoint location in the configured storage backend | | `status.location` | Checkpoint location in the configured storage backend |
| `status.storageType` | Storage backend type (`pvc`, `s3`, or `oci`) | | `status.storageType` | Storage backend type (`pvc`, `s3`, or `oci`) |
| `status.createdAt` | Timestamp recorded when the checkpoint becomes ready | | `status.createdAt` | Timestamp recorded when the checkpoint becomes ready |
| `status.message` | Failure or progress message when available | | `status.message` | Failure or progress message when available |
`status.conditions` is deprecated for `DynamoCheckpoint`. The legacy condition types `JobCreated` and `JobCompleted` are kept for compatibility only. Prefer `status.phase`, `status.jobName`, and `status.message` when checking checkpoint progress.
**Detailed status:** **Detailed status:**
```bash ```bash
kubectl describe dckpt qwen3-06b-vllm-prewarm -n ${NAMESPACE} kubectl describe dckpt qwen3-06b-bf16 -n ${NAMESPACE}
``` ```
```yaml ```yaml
Status: Status:
Phase: Ready Phase: Ready
IdentityHash: e5962d34ba272638 IdentityHash: 3bff874d069f0ed5
JobName: checkpoint-qwen3-06b-vllm-prewarm JobName: checkpoint-job-3bff874d069f0ed5
Location: /checkpoints/e5962d34ba272638.tar Location: /checkpoints/3bff874d069f0ed5
StorageType: pvc StorageType: pvc
CreatedAt: 2026-01-29T10:05:00Z CreatedAt: 2026-01-29T10:05:00Z
Conditions:
- Type: JobCreated
Status: "True"
Reason: JobCreated
- Type: JobCompleted
Status: "True"
Reason: JobSucceeded
``` ```
**Reference from DGD:** **Reference from DGD:**
...@@ -431,16 +404,16 @@ spec: ...@@ -431,16 +404,16 @@ spec:
VllmDecodeWorker: VllmDecodeWorker:
checkpoint: checkpoint:
enabled: true enabled: true
checkpointRef: "qwen3-06b-vllm-prewarm" checkpointRef: "qwen3-06b-bf16"
``` ```
Or use `mode: Auto` with the same identity and snapshot-hash label, and the operator will reuse it automatically. Or use `mode: Auto` with the same identity, and the operator will reuse the same deterministic checkpoint object automatically.
## Limitations ## Limitations
- **LLM workers only**: Checkpoint/restore supports LLM decode and prefill workers. Specialized workers (multimodal, embedding, diffusion) are not supported. - **LLM workers only**: Checkpoint/restore supports LLM decode and prefill workers. Specialized workers (multimodal, embedding, diffusion) are not supported.
- **Single-GPU only**: Multi-GPU configurations may work in very basic hardware configurations, but are not officially supported yet. - **Single-GPU only**: Multi-GPU configurations may work in very basic hardware configurations, but are not officially supported yet.
- **Network state**: No active TCP connections can be checkpointed - **Network state**: Restore is sensitive to live TCP socket state. Loopback bootstrap/control sockets can work with the supported CRIU TCP policies, but non-loopback or pod-IP-bound connections can still break restore.
- **Security**: Dynamo Snapshot runs as a **privileged DaemonSet** which is required to run CRIU and cuda-checkpoint. However, workload pods do not need to be privileged. - **Security**: Dynamo Snapshot runs as a **privileged DaemonSet** which is required to run CRIU and cuda-checkpoint. However, workload pods do not need to be privileged.
## Troubleshooting ## Troubleshooting
...@@ -451,7 +424,10 @@ Or use `mode: Auto` with the same identity and snapshot-hash label, and the oper ...@@ -451,7 +424,10 @@ Or use `mode: Auto` with the same identity and snapshot-hash label, and the oper
```bash ```bash
kubectl get dckpt -n ${NAMESPACE} kubectl get dckpt -n ${NAMESPACE}
kubectl describe dckpt <checkpoint-name> -n ${NAMESPACE} kubectl describe dckpt <checkpoint-name> -n ${NAMESPACE}
kubectl logs job/$(kubectl get dckpt <checkpoint-name> -n ${NAMESPACE} -o jsonpath='{.status.jobName}') -n ${NAMESPACE} JOB_NAME=$(kubectl get dckpt <checkpoint-name> -n ${NAMESPACE} -o jsonpath='{.status.jobName}')
if [ -n "${JOB_NAME}" ]; then
kubectl logs job/"${JOB_NAME}" -n ${NAMESPACE}
fi
``` ```
2. Check the DaemonSet: 2. Check the DaemonSet:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment