controller.go 21.5 KB
Newer Older
1
2
3
4
5
// Package controller implements the node-local control loop inside snapshot-agent.
// It does not own CRDs or replace the operator. Instead it watches pod, job, and
// lease state on the current node and delegates CRIU/CUDA execution to the
// snapshot executor workflows.
package controller
6
7
8
9
10
11

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
12
	"strings"
13
	"sync"
14
	"syscall"
15
16
	"time"

17
18
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
19
20
	"github.com/google/uuid"
	batchv1 "k8s.io/api/batch/v1"
21
22
23
24
25
26
27
28
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/client-go/informers"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/cache"

29
30
31
32
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/executor"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
	snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
33
)
34

35
36
37
// NodeController watches local-node pods with checkpoint metadata and reconciles
// snapshot execution for checkpoint and restore requests.
type NodeController struct {
38
39
40
41
	config     *types.AgentConfig
	clientset  kubernetes.Interface
	containerd *containerd.Client
	log        logr.Logger
42
	holderID   string
43

44
45
	inFlight   map[string]struct{}
	inFlightMu sync.Mutex
46
47
48
49

	stopCh chan struct{}
}

50
51
// NewNodeController creates the node-local controller that runs inside snapshot-agent.
func NewNodeController(
52
53
54
	cfg *types.AgentConfig,
	containerd *containerd.Client,
	log logr.Logger,
55
) (*NodeController, error) {
56
57
58
59
60
61
62
63
64
65
	restConfig, err := rest.InClusterConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get in-cluster config: %w", err)
	}

	clientset, err := kubernetes.NewForConfig(restConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
	}

66
	return &NodeController{
67
68
69
70
		config:     cfg,
		clientset:  clientset,
		containerd: containerd,
		log:        log,
71
		holderID:   "snapshot-agent/" + uuid.NewString(),
72
73
		inFlight:   make(map[string]struct{}),
		stopCh:     make(chan struct{}),
74
75
76
	}, nil
}

77
78
79
// Run starts the local pod informers and processes checkpoint/restore events.
func (w *NodeController) Run(ctx context.Context) error {
	w.log.Info("Starting snapshot node controller",
80
		"node", w.config.NodeName,
81
82
		"checkpoint", snapshotprotocol.CheckpointSourceLabel,
		"restore", snapshotprotocol.RestoreTargetLabel,
83
	)
84

85
86
87
88
89
90
	var nsOptions []informers.SharedInformerOption
	if w.config.RestrictedNamespace != "" {
		w.log.Info("Restricting pod watching to namespace", "namespace", w.config.RestrictedNamespace)
		nsOptions = append(nsOptions, informers.WithNamespace(w.config.RestrictedNamespace))
	} else {
		w.log.Info("Watching pods cluster-wide (all namespaces)")
91
92
	}

93
94
95
96
	var syncFuncs []cache.InformerSynced

	// Checkpoint informer
	checkpointSelector := labels.SelectorFromSet(labels.Set{
97
		snapshotprotocol.CheckpointSourceLabel: "true",
98
99
	}).String()

100
	ckptFactoryOpts := append([]informers.SharedInformerOption{
101
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
102
			opts.LabelSelector = checkpointSelector
103
		}),
104
	}, nsOptions...)
105

106
107
	ckptFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, ckptFactoryOpts...,
108
109
	)

110
	ckptInformer := ckptFactory.Core().V1().Pods().Informer()
111
	if _, err := ckptInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
112
		AddFunc: func(obj interface{}) {
113
114
115
116
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
117
			w.reconcileCheckpointPod(ctx, pod)
118
		},
119
120
121
122
123
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
124
			w.reconcileCheckpointPod(ctx, pod)
125
		},
126
127
128
	}); err != nil {
		return fmt.Errorf("failed to add checkpoint informer handler: %w", err)
	}
129
130
	go ckptFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, ckptInformer.HasSynced)
131

132
133
	// Restore informer
	restoreSelector := labels.SelectorFromSet(labels.Set{
134
		snapshotprotocol.RestoreTargetLabel: "true",
135
	}).String()
136

137
138
139
140
141
	restoreFactoryOpts := append([]informers.SharedInformerOption{
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
			opts.LabelSelector = restoreSelector
		}),
	}, nsOptions...)
142

143
144
145
	restoreFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, restoreFactoryOpts...,
	)
146

147
	restoreInformer := restoreFactory.Core().V1().Pods().Informer()
148
	if _, err := restoreInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
149
150
151
152
153
		AddFunc: func(obj interface{}) {
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
154
			w.reconcileRestorePod(ctx, pod)
155
156
157
158
159
160
		},
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
161
			w.reconcileRestorePod(ctx, pod)
162
		},
163
164
165
	}); err != nil {
		return fmt.Errorf("failed to add restore informer handler: %w", err)
	}
166
167
	go restoreFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, restoreInformer.HasSynced)
168

169
170
	if !cache.WaitForCacheSync(w.stopCh, syncFuncs...) {
		return fmt.Errorf("failed to sync informer caches")
171
172
	}

173
	w.log.Info("Snapshot node controller started and caches synced")
174
	<-ctx.Done()
175
	close(w.stopCh)
176
	return nil
177
178
}

179
func (w *NodeController) reconcileCheckpointPod(ctx context.Context, pod *corev1.Pod) {
180
181
182
	if pod.Spec.NodeName != w.config.NodeName {
		return
	}
183
	if !isPodReady(pod) {
184
185
186
187
188
		return
	}

	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)

189
190
191
	checkpointID, ok := pod.Labels[snapshotprotocol.CheckpointIDLabel]
	if !ok || checkpointID == "" {
		w.log.Info("Pod has checkpoint label but no checkpoint-id label", "pod", podKey)
192
193
194
		return
	}

195
196
197
198
199
200
	job, err := getCheckpointJob(ctx, w.clientset, pod)
	if err != nil {
		w.log.Error(err, "Failed to resolve checkpoint job", "pod", podKey)
		return
	}

201
202
	jobStatus := job.Annotations[snapshotprotocol.CheckpointStatusAnnotation]
	if jobStatus == snapshotprotocol.CheckpointStatusCompleted || jobStatus == snapshotprotocol.CheckpointStatusFailed {
203
204
205
		return
	}

206
207
	if !w.tryAcquire(podKey) {
		return
208
209
	}

210
	checkpointLocation, err := w.checkpointLocationFromPod(pod, checkpointID)
211
212
	if err != nil {
		w.release(podKey)
213
		w.log.Error(err, "Checkpoint pod is missing storage metadata", "pod", podKey, "checkpoint_id", checkpointID)
214
215
216
217
218
219
		return
	}

	acquiredLease, err := acquireCheckpointLease(ctx, w.clientset, w.log, job, w.holderID)
	if err != nil {
		w.release(podKey)
220
		w.log.Error(err, "Failed to acquire checkpoint lease", "pod", podKey, "checkpoint_id", checkpointID)
221
222
223
224
225
226
227
		return
	}
	if !acquiredLease {
		w.release(podKey)
		return
	}

228
229
	startedAt := time.Now()
	w.log.Info("Checkpoint target detected, triggering checkpoint", "pod", podKey, "checkpoint_id", checkpointID)
230
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointID))
231

232
	go func() {
233
		if err := w.runCheckpoint(ctx, pod, job, checkpointID, checkpointLocation, podKey, startedAt); err != nil {
234
			opLog := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
235
			opLog.Error(err, "Checkpoint controller worker failed")
236
237
238
			emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error())
		}
	}()
239
240
}

241
func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Pod) {
242
243
	if pod.Spec.NodeName != w.config.NodeName {
		return
244
245
	}

246
	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
247

248
	if pod.Status.Phase != corev1.PodRunning {
249
250
251
		return
	}

252
253
254
	checkpointID, ok := pod.Labels[snapshotprotocol.CheckpointIDLabel]
	if !ok || checkpointID == "" {
		w.log.Info("Restore pod has no checkpoint-id label", "pod", podKey)
255
256
257
		return
	}

258
259
	if strings.ContainsAny(checkpointID, "/\\") || strings.Contains(checkpointID, "..") || filepath.Clean(checkpointID) != checkpointID {
		w.log.Error(fmt.Errorf("invalid checkpoint id %q", checkpointID), "Invalid checkpoint id on restore pod", "pod", podKey)
260
261
262
		return
	}

263
	checkpointLocation, err := w.checkpointLocationFromPod(pod, checkpointID)
264
	if err != nil {
265
		w.log.Error(err, "Restore pod is missing storage metadata", "pod", podKey, "checkpoint_id", checkpointID)
266
267
268
		return
	}
	if _, err := os.Stat(checkpointLocation); os.IsNotExist(err) {
269
		w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_id", checkpointID, "checkpoint_location", checkpointLocation)
270
		return
271
272
	}

273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		w.log.Info("Restore pod has no containers", "pod", podKey)
		return
	}

	containerID := ""
	for _, cs := range pod.Status.ContainerStatuses {
		if cs.Name != containerName || cs.ContainerID == "" {
			continue
		}
		containerID = strings.TrimPrefix(cs.ContainerID, "containerd://")
		break
	}
	if containerID == "" {
		w.log.V(1).Info("Restore pod has no running main container yet", "pod", podKey, "container", containerName)
		return
	}

292
293
294
	annotationStatus := pod.Annotations[snapshotprotocol.RestoreStatusAnnotation]
	annotationContainerID := pod.Annotations[snapshotprotocol.RestoreContainerIDAnnotation]
	if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusInProgress) {
295
296
297
298
299
		return
	}

	restoreAttemptKey := fmt.Sprintf("%s/%s", podKey, containerID)
	if !w.tryAcquire(restoreAttemptKey) {
300
301
302
		return
	}

303
304
	startedAt := time.Now()
	w.log.Info("Restore target detected, triggering external restore", "pod", podKey, "checkpoint_id", checkpointID)
305
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointID))
306

307
	go func() {
308
		if err := w.runRestore(ctx, pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey, startedAt); err != nil {
309
			opLog := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
310
			opLog.Error(err, "Restore controller worker failed")
311
312
313
			emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error())
		}
	}()
314
}
315

316
317
// runCheckpoint runs the full checkpoint workflow for a pod:
//  1. Hold and renew the checkpoint lease
318
//  2. Resolve the container ID and host PID
319
//  3. Call executor.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
320
//  4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
321
//  5. Mark job as completed or failed
322
func (w *NodeController) runCheckpoint(ctx context.Context, pod *corev1.Pod, job *batchv1.Job, checkpointID, checkpointLocation, podKey string, startedAt time.Time) error {
323
	releasePodOnExit := true
324
	defer func() {
325
		if releasePodOnExit {
326
327
328
			w.release(podKey)
		}
	}()
329
	log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID)
330
331
	leaseCtx, stopLease := context.WithCancelCause(ctx)
	defer stopLease(nil)
332

333
334
335
336
	releaseLeaseOnExit := true
	defer func() {
		if !releaseLeaseOnExit {
			return
337
		}
338
339
340
341
342
343
		releaseCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := releaseCheckpointLease(releaseCtx, w.clientset, log, job, w.holderID); err != nil {
			log.Error(err, "Failed to release checkpoint lease")
		}
	}()
344

345
346
347
348
	go w.renewCheckpointLease(leaseCtx, log, job, stopLease)

	setCheckpointStatus := func(value string) error {
		if err := annotateJob(ctx, w.clientset, log, job, map[string]string{
349
			snapshotprotocol.CheckpointStatusAnnotation: value,
350
351
352
353
		}); err != nil {
			releasePodOnExit = false
			releaseLeaseOnExit = false
			return fmt.Errorf("failed to persist terminal checkpoint status %q: %w", value, err)
354
355
356
		}
		return nil
	}
357
358
359
360
361
362

	// Resolve the target container
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		err := fmt.Errorf("no containers found in pod spec")
		log.Error(err, "Checkpoint failed")
363
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
364
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
365
366
367
			return statusErr
		}
		return nil
368
369
370
371
372
373
374
375
376
	}
	var containerID string
	for _, cs := range pod.Status.ContainerStatuses {
		if cs.Name == containerName {
			containerID = strings.TrimPrefix(cs.ContainerID, "containerd://")
			break
		}
	}
	if containerID == "" {
377
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", "Could not resolve target container ID")
378
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
379
380
381
			return statusErr
		}
		return nil
382
383
	}

384
	// Resolve the container's host PID (needed for signaling after checkpoint)
385
	containerPID, _, err := snapshotruntime.ResolveContainer(ctx, w.containerd, containerID)
386
	if err != nil {
387
		log.Error(err, "Failed to resolve container")
388
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", fmt.Sprintf("Container resolve failed: %v", err))
389
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
390
391
392
			return statusErr
		}
		return nil
393
394
	}

395
	// Step 1: Run the checkpoint orchestrator
396
	req := executor.CheckpointRequest{
397
398
399
400
		ContainerID:        containerID,
		ContainerName:      containerName,
		CheckpointID:       checkpointID,
		CheckpointLocation: checkpointLocation,
401
		StartedAt:          startedAt,
402
403
404
405
		NodeName:           w.config.NodeName,
		PodName:            pod.Name,
		PodNamespace:       pod.Namespace,
		Clientset:          w.clientset,
406
407
408
409
410
	}
	if err := executor.Checkpoint(leaseCtx, w.containerd, log, req, w.config); err != nil {
		if cause := context.Cause(leaseCtx); cause != nil && cause != context.Canceled {
			err = fmt.Errorf("checkpoint lease lost: %w", cause)
		}
411
		log.Error(err, "Checkpoint failed")
412
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
413
		// SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately
414
		if signalErr := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
415
416
			log.Error(signalErr, "Failed to signal checkpoint failure to runtime process")
		}
417
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
418
419
420
			return statusErr
		}
		return nil
421
422
	}

423
424
425
426
427
428
429
430
431
	info, err := os.Stat(checkpointLocation)
	if err != nil || !info.IsDir() {
		if err == nil {
			err = fmt.Errorf("published checkpoint path %s is not a directory", checkpointLocation)
		} else {
			err = fmt.Errorf("published checkpoint path %s is missing: %w", checkpointLocation, err)
		}
		log.Error(err, "Checkpoint failed verification")
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
432
		if signalErr := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint verification failed"); signalErr != nil {
433
434
			log.Error(signalErr, "Failed to signal checkpoint verification failure to runtime process")
		}
435
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
436
437
438
439
440
			return statusErr
		}
		return nil
	}

441
	// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
442
443
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointID))
	if err := snapshotruntime.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
444
		log.Error(err, "Failed to signal checkpoint completion to runtime process")
445
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
446
		if statusErr := setCheckpointStatus(snapshotprotocol.CheckpointStatusFailed); statusErr != nil {
447
448
449
			return statusErr
		}
		return nil
450
451
	}

452
	if err := setCheckpointStatus(snapshotprotocol.CheckpointStatusCompleted); err != nil {
453
454
455
		return err
	}
	return nil
456
457
}

458
459
460
// runRestore runs the full restore workflow for a pod:
//  1. Mark the current container instance as in_progress
//  2. Call executor.Restore (inspect placeholder → nsrestore inside namespace)
461
462
//  3. SIGCONT the restored process to wake it up
//  4. Wait for the pod to become Ready
463
//  5. Mark the container instance as completed
464
func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, containerName, containerID, checkpointID, checkpointLocation, restoreAttemptKey string, startedAt time.Time) error {
465
466
467
	releaseOnExit := true
	defer func() {
		if releaseOnExit {
468
			w.release(restoreAttemptKey)
469
470
		}
	}()
471
	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
472
	log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID, "container_id", containerID)
473
474
	setRestoreStatus := func(value string) error {
		annotations := map[string]string{
475
476
			snapshotprotocol.RestoreStatusAnnotation:      value,
			snapshotprotocol.RestoreContainerIDAnnotation: containerID,
477
		}
478
		if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
479
			if value == snapshotprotocol.RestoreStatusCompleted {
480
481
482
483
484
485
486
				releaseOnExit = false
				return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
			}
			return fmt.Errorf("failed to update restore status %q: %w", value, err)
		}
		return nil
	}
487
488

	if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
489
490
		snapshotprotocol.RestoreStatusAnnotation:      snapshotprotocol.RestoreStatusInProgress,
		snapshotprotocol.RestoreContainerIDAnnotation: containerID,
491
	}); err != nil {
492
		return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
493
	}
494

495
	// Step 1: Run the restore orchestrator (inspect + nsrestore)
496
	req := executor.RestoreRequest{
497
498
		CheckpointID:       checkpointID,
		CheckpointLocation: checkpointLocation,
499
		StartedAt:          startedAt,
500
501
502
503
504
		NSRestorePath:      w.config.Restore.NSRestorePath,
		PodName:            pod.Name,
		PodNamespace:       pod.Namespace,
		ContainerName:      containerName,
		Clientset:          w.clientset,
505
506
	}
	restoredPID, err := executor.Restore(ctx, w.containerd, log, req)
507
508
	if err != nil {
		log.Error(err, "External restore failed")
509
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
510
		placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
511
512
513
514
		if pidErr != nil {
			releaseOnExit = false
			return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
		}
515
		if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
516
517
			releaseOnExit = false
			return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr)
518
519
		}
		return nil
520
521
	}

522
	// Step 2: SIGCONT the restored process via PID namespace
523
	placeholderHostPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
524
	if err != nil {
525
		log.Error(err, "Failed to resolve placeholder host PID for signaling")
526
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
527
528
		releaseOnExit = false
		return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
529
	}
530
	if err := snapshotruntime.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
531
		log.Error(err, "Failed to signal restored runtime process")
532
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
533
		if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
534
			log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
535
		}
536
537
		releaseOnExit = false
		return fmt.Errorf("failed to signal restored runtime process: %w", err)
538
539
	}

540
541
542
543
544
545
546
547
548
	// Step 3: Wait for the pod to become Ready
	readyCtx := ctx
	if timeout := w.config.Restore.RestoreReadyTimeout(); timeout > 0 {
		var cancel context.CancelFunc
		readyCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
		log.Error(err, "Restore post-signal readiness check failed")
549
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
550
		if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
551
			log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
552
		}
553
554
		releaseOnExit = false
		return fmt.Errorf("restore post-signal readiness check failed: %w", err)
555
556
	}

557
558
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointID))
	if err := setRestoreStatus(snapshotprotocol.RestoreStatusCompleted); err != nil {
559
560
561
		return err
	}
	return nil
562
563
}

564
func (w *NodeController) tryAcquire(podKey string) bool {
565
566
567
568
569
570
571
572
573
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()
	if _, held := w.inFlight[podKey]; held {
		return false
	}
	w.inFlight[podKey] = struct{}{}
	return true
}

574
func (w *NodeController) release(podKey string) {
575
576
577
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()
	delete(w.inFlight, podKey)
578
}
579

580
581
582
583
584
585
586
587
588
589
590
// checkpointLocationFromPod resolves where the artifact for checkpointID lives.
// The location is derived from the agent's configured storage backend and the
// pod's checkpoint artifact-version annotation (surrounding whitespace
// ignored). Any error from snapshotprotocol.ResolveCheckpointStorage is
// returned unchanged.
func (w *NodeController) checkpointLocationFromPod(pod *corev1.Pod, checkpointID string) (string, error) {
	artifactVersion := strings.TrimSpace(pod.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
	storage := snapshotprotocol.Storage{
		Type:     w.config.Storage.Type,
		BasePath: w.config.Storage.BasePath,
	}

	resolved, err := snapshotprotocol.ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
	if err != nil {
		return "", err
	}
	return resolved.Location, nil
}