watcher.go 17 KB
Newer Older
// Package watcher provides Kubernetes pod watching for automatic checkpoint/restore.
// The watcher is the sole entry point for snapshot operations — it detects pods with
// checkpoint/restore labels and calls the orchestrators directly.
package watcher

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
11
	"strings"
12
	"sync"
13
	"syscall"
14
15
	"time"

16
17
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
18
19
20
21
22
23
24
25
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/client-go/informers"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/cache"

26
27
28
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/orchestrate"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
29
30
)

31
const (
	// Pod labels read by the watcher. The two "is-*" labels are matched
	// against the value "true" by the informer label selectors; the hash
	// label names the checkpoint to create or restore from.
	kubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source"
	kubeLabelCheckpointHash     = "nvidia.com/snapshot-checkpoint-hash"
	kubeLabelIsRestoreTarget    = "nvidia.com/snapshot-is-restore-target"

	// Pod annotations written by the watcher to record operation progress.
	// Values used in this file are "in_progress", "completed" and "failed".
	kubeAnnotationCheckpointStatus = "nvidia.com/snapshot-checkpoint-status"
	kubeAnnotationRestoreStatus    = "nvidia.com/snapshot-restore-status"
)
38

39
// Watcher watches for pods with checkpoint/restore labels and triggers operations.
40
type Watcher struct {
41
42
43
44
	config     *types.AgentConfig
	clientset  kubernetes.Interface
	containerd *containerd.Client
	log        logr.Logger
45

46
47
	inFlight   map[string]struct{}
	inFlightMu sync.Mutex
48
49
50
51

	stopCh chan struct{}
}

52
53
54
55
56
57
// NewWatcher creates a new pod watcher.
func NewWatcher(
	cfg *types.AgentConfig,
	containerd *containerd.Client,
	log logr.Logger,
) (*Watcher, error) {
58
59
60
61
62
63
64
65
66
67
68
	restConfig, err := rest.InClusterConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get in-cluster config: %w", err)
	}

	clientset, err := kubernetes.NewForConfig(restConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
	}

	return &Watcher{
69
70
71
72
		config:     cfg,
		clientset:  clientset,
		containerd: containerd,
		log:        log,
73
74
		inFlight:   make(map[string]struct{}),
		stopCh:     make(chan struct{}),
75
76
77
	}, nil
}

78
// Start begins watching for pods and processing checkpoint/restore events.
79
func (w *Watcher) Start(ctx context.Context) error {
80
81
82
83
84
	w.log.Info("Starting pod watcher",
		"node", w.config.NodeName,
		"checkpoint", kubeLabelIsCheckpointSource,
		"restore", kubeLabelIsRestoreTarget,
	)
85

86
87
88
89
90
91
	var nsOptions []informers.SharedInformerOption
	if w.config.RestrictedNamespace != "" {
		w.log.Info("Restricting pod watching to namespace", "namespace", w.config.RestrictedNamespace)
		nsOptions = append(nsOptions, informers.WithNamespace(w.config.RestrictedNamespace))
	} else {
		w.log.Info("Watching pods cluster-wide (all namespaces)")
92
93
	}

94
95
96
97
98
	var syncFuncs []cache.InformerSynced

	// Checkpoint informer
	checkpointSelector := labels.SelectorFromSet(labels.Set{
		kubeLabelIsCheckpointSource: "true",
99
100
	}).String()

101
	ckptFactoryOpts := append([]informers.SharedInformerOption{
102
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
103
			opts.LabelSelector = checkpointSelector
104
		}),
105
	}, nsOptions...)
106

107
108
	ckptFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, ckptFactoryOpts...,
109
110
	)

111
	ckptInformer := ckptFactory.Core().V1().Pods().Informer()
112
	if _, err := ckptInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
113
		AddFunc: func(obj interface{}) {
114
115
116
117
118
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
			w.handleCheckpointPodEvent(ctx, pod)
119
		},
120
121
122
123
124
125
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
			w.handleCheckpointPodEvent(ctx, pod)
126
		},
127
128
129
	}); err != nil {
		return fmt.Errorf("failed to add checkpoint informer handler: %w", err)
	}
130
131
	go ckptFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, ckptInformer.HasSynced)
132

133
134
135
136
	// Restore informer
	restoreSelector := labels.SelectorFromSet(labels.Set{
		kubeLabelIsRestoreTarget: "true",
	}).String()
137

138
139
140
141
142
	restoreFactoryOpts := append([]informers.SharedInformerOption{
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
			opts.LabelSelector = restoreSelector
		}),
	}, nsOptions...)
143

144
145
146
	restoreFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, restoreFactoryOpts...,
	)
147

148
	restoreInformer := restoreFactory.Core().V1().Pods().Informer()
149
	if _, err := restoreInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
150
151
152
153
154
155
156
157
158
159
160
161
162
163
		AddFunc: func(obj interface{}) {
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
			w.handleRestorePodEvent(ctx, pod)
		},
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
			w.handleRestorePodEvent(ctx, pod)
		},
164
165
166
	}); err != nil {
		return fmt.Errorf("failed to add restore informer handler: %w", err)
	}
167
168
	go restoreFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, restoreInformer.HasSynced)
169

170
171
	if !cache.WaitForCacheSync(w.stopCh, syncFuncs...) {
		return fmt.Errorf("failed to sync informer caches")
172
173
	}

174
175
	w.log.Info("Pod watcher started and caches synced")
	<-ctx.Done()
176
	close(w.stopCh)
177
	return nil
178
179
}

180
// handleCheckpointPodEvent decides whether an add/update event for a
// checkpoint-source pod should start a checkpoint and, if so, launches the
// checkpoint worker in its own goroutine.
//
// The event is ignored unless all of the following hold:
//   - the pod is scheduled on this agent's node and is Ready;
//   - it carries a non-empty checkpoint-hash label that is a plain path
//     element (no separators, no "..");
//   - its checkpoint status annotation is neither "completed" nor
//     "in_progress" (a "failed" status does not block a new attempt);
//   - no operation for the pod is already in flight.
func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod) {
	if pod.Spec.NodeName != w.config.NodeName || !isPodReady(pod) {
		return
	}

	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)

	checkpointHash := pod.Labels[kubeLabelCheckpointHash]
	if checkpointHash == "" {
		w.log.Info("Pod has checkpoint label but no checkpoint-hash label", "pod", podKey)
		return
	}

	// The hash is combined with the checkpoint base path further down the
	// pipeline, so apply the same path-element validation the restore handler
	// uses before acting on it.
	if strings.ContainsAny(checkpointHash, "/\\") || strings.Contains(checkpointHash, "..") || filepath.Clean(checkpointHash) != checkpointHash {
		w.log.Error(fmt.Errorf("invalid checkpoint hash %q", checkpointHash), "Invalid checkpoint hash on checkpoint pod", "pod", podKey)
		return
	}

	switch pod.Annotations[kubeAnnotationCheckpointStatus] {
	case "completed", "in_progress":
		return
	}

	// The in-flight slot is handed to the worker, which releases it.
	if !w.tryAcquire(podKey) {
		return
	}

	w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash)
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash))

	go func() {
		err := w.doCheckpoint(ctx, pod, checkpointHash, podKey)
		if err == nil {
			return
		}
		opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
		opLog.Error(err, "Checkpoint worker failed")
		emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "CheckpointWorkerFailed", err.Error())
	}()
}

217
218
219
// handleRestorePodEvent decides whether an add/update event for a
// restore-target pod should start a restore and, if so, launches the restore
// worker in its own goroutine.
//
// The event is ignored unless all of the following hold:
//   - the pod is scheduled on this agent's node, is Running and is not yet
//     Ready;
//   - its restore status annotation is not "completed", "in_progress" or
//     "failed";
//   - it carries a non-empty checkpoint-hash label that is a plain path
//     element (no separators, no "..");
//   - the checkpoint directory <BasePath>/<hash> can be stat'ed on this node;
//   - no operation for the pod is already in flight.
func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
	if pod.Spec.NodeName != w.config.NodeName {
		return
	}

	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)

	if pod.Status.Phase != corev1.PodRunning {
		return
	}
	if isPodReady(pod) {
		return
	}

	// Restore failures require explicit intervention (new label/update) before retry.
	switch pod.Annotations[kubeAnnotationRestoreStatus] {
	case "completed", "in_progress", "failed":
		return
	}

	checkpointHash := pod.Labels[kubeLabelCheckpointHash]
	if checkpointHash == "" {
		w.log.Info("Restore pod has no checkpoint-hash label", "pod", podKey)
		return
	}

	// The hash becomes a path element under BasePath; reject anything that
	// could escape it.
	if strings.ContainsAny(checkpointHash, "/\\") || strings.Contains(checkpointHash, "..") || filepath.Clean(checkpointHash) != checkpointHash {
		w.log.Error(fmt.Errorf("invalid checkpoint hash %q", checkpointHash), "Invalid checkpoint hash on restore pod", "pod", podKey)
		return
	}

	checkpointDir := filepath.Join(w.config.BasePath, checkpointHash)
	if _, err := os.Stat(checkpointDir); err != nil {
		if os.IsNotExist(err) {
			w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash)
		} else {
			// Any other stat error (e.g. permissions, I/O) also means the
			// checkpoint cannot be used right now. Skip instead of starting a
			// restore that would mark the pod "failed"; a later event can
			// try again.
			w.log.Error(err, "Failed to stat checkpoint directory, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash, "path", checkpointDir)
		}
		return
	}

	// The in-flight slot is handed to the worker, which releases it.
	if !w.tryAcquire(podKey) {
		return
	}

	w.log.Info("Restore pod running, triggering external restore", "pod", podKey, "checkpoint_hash", checkpointHash)
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash))

	go func() {
		err := w.doRestore(ctx, pod, checkpointHash, podKey)
		if err == nil {
			return
		}
		opLog := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
		opLog.Error(err, "Restore worker failed")
		emitPodEvent(ctx, w.clientset, opLog, pod, "snapshot", corev1.EventTypeWarning, "RestoreWorkerFailed", err.Error())
	}()
}
271

272
273
274
275
// doCheckpoint runs the full checkpoint workflow for a pod:
//  1. Mark pod as in_progress
//  2. Resolve the container ID and host PID
//  3. Call orchestrate.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
276
//  4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
277
//  5. Mark pod as completed or failed
278
279
280
281
282
283
284
func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error {
	releaseOnExit := true
	defer func() {
		if releaseOnExit {
			w.release(podKey)
		}
	}()
285
	log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
	setCheckpointStatus := func(value string) error {
		annotations := map[string]string{
			kubeAnnotationCheckpointStatus: value,
		}

		if value == "failed" || value == "completed" {
			if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil {
				releaseOnExit = false
				return fmt.Errorf("failed to persist terminal checkpoint status %q: %w", value, err)
			}
			return nil
		}

		if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
			return fmt.Errorf("failed to update checkpoint status %q: %w", value, err)
		}
		return nil
	}
304
305
306
307

	if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
		kubeAnnotationCheckpointStatus: "in_progress",
	}); err != nil {
308
		return fmt.Errorf("failed to annotate pod with checkpoint in_progress: %w", err)
309
310
	}

311
312
313
314
315
	// Resolve the target container
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		err := fmt.Errorf("no containers found in pod spec")
		log.Error(err, "Checkpoint failed")
316
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
317
318
319
320
		if statusErr := setCheckpointStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
321
322
323
324
325
326
327
328
329
	}
	var containerID string
	for _, cs := range pod.Status.ContainerStatuses {
		if cs.Name == containerName {
			containerID = strings.TrimPrefix(cs.ContainerID, "containerd://")
			break
		}
	}
	if containerID == "" {
330
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", "Could not resolve target container ID")
331
332
333
334
		if statusErr := setCheckpointStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
335
336
	}

337
338
	// Resolve the container's host PID (needed for signaling after checkpoint)
	containerPID, _, err := common.ResolveContainer(ctx, w.containerd, containerID)
339
	if err != nil {
340
		log.Error(err, "Failed to resolve container")
341
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", fmt.Sprintf("Container resolve failed: %v", err))
342
343
344
345
		if statusErr := setCheckpointStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
346
347
	}

348
349
350
351
352
353
354
355
356
357
358
359
	// Step 1: Run the checkpoint orchestrator
	req := orchestrate.CheckpointRequest{
		ContainerID:    containerID,
		ContainerName:  containerName,
		CheckpointHash: checkpointHash,
		CheckpointDir:  w.config.BasePath,
		NodeName:       w.config.NodeName,
		PodName:        pod.Name,
		PodNamespace:   pod.Namespace,
	}
	if err := orchestrate.Checkpoint(ctx, w.containerd, log, req, w.config); err != nil {
		log.Error(err, "Checkpoint failed")
360
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
361
362
		// SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately
		if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
363
364
			log.Error(signalErr, "Failed to signal checkpoint failure to runtime process")
		}
365
366
367
368
		if statusErr := setCheckpointStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
369
370
	}

371
	// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
372
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointHash))
373
374
	if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
		log.Error(err, "Failed to signal checkpoint completion to runtime process")
375
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
376
377
378
379
		if statusErr := setCheckpointStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
380
381
	}

382
383
384
385
	if err := setCheckpointStatus("completed"); err != nil {
		return err
	}
	return nil
386
387
}

388
389
390
391
392
393
// doRestore runs the full restore workflow for a pod:
//  1. Mark pod as in_progress
//  2. Call orchestrate.Restore (inspect placeholder → nsrestore inside namespace)
//  3. SIGCONT the restored process to wake it up
//  4. Wait for the pod to become Ready
//  5. Mark pod as completed or failed
394
395
396
397
398
399
400
func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) error {
	releaseOnExit := true
	defer func() {
		if releaseOnExit {
			w.release(podKey)
		}
	}()
401
	log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
	setRestoreStatus := func(value string) error {
		annotations := map[string]string{
			kubeAnnotationRestoreStatus: value,
		}

		if value == "failed" || value == "completed" {
			if err := annotatePodRetry(ctx, w.clientset, log, pod, annotations); err != nil {
				releaseOnExit = false
				return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
			}
			return nil
		}

		if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
			return fmt.Errorf("failed to update restore status %q: %w", value, err)
		}
		return nil
	}
420
421
422
423

	if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
		kubeAnnotationRestoreStatus: "in_progress",
	}); err != nil {
424
		return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
425
	}
426

427
428
429
430
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		err := fmt.Errorf("no containers found in pod spec")
		log.Error(err, "Restore failed")
431
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
432
433
434
435
		if statusErr := setRestoreStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
436
437
438
439
440
441
442
443
444
445
446
447
448
449
	}

	// Step 1: Run the restore orchestrator (inspect + nsrestore)
	req := orchestrate.RestoreRequest{
		CheckpointHash: checkpointHash,
		CheckpointBase: w.config.BasePath,
		NSRestorePath:  w.config.Restore.NSRestorePath,
		PodName:        pod.Name,
		PodNamespace:   pod.Namespace,
		ContainerName:  containerName,
	}
	restoredPID, err := orchestrate.Restore(ctx, w.containerd, log, req)
	if err != nil {
		log.Error(err, "External restore failed")
450
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
451
452
453
454
		if statusErr := setRestoreStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
455
456
	}

457
458
	// Step 2: SIGCONT the restored process via PID namespace
	placeholderHostPID, _, err := common.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
459
	if err != nil {
460
		log.Error(err, "Failed to resolve placeholder host PID for signaling")
461
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
462
463
464
465
		if statusErr := setRestoreStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
466
467
468
	}
	if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
		log.Error(err, "Failed to signal restored runtime process")
469
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
470
471
472
473
		if statusErr := setRestoreStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
474
475
	}

476
477
478
479
480
481
482
483
484
	// Step 3: Wait for the pod to become Ready
	readyCtx := ctx
	if timeout := w.config.Restore.RestoreReadyTimeout(); timeout > 0 {
		var cancel context.CancelFunc
		readyCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
		log.Error(err, "Restore post-signal readiness check failed")
485
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
486
487
488
489
		if statusErr := setRestoreStatus("failed"); statusErr != nil {
			return statusErr
		}
		return nil
490
491
	}

492
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash))
493
494
495
496
	if err := setRestoreStatus("completed"); err != nil {
		return err
	}
	return nil
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
}

// tryAcquire claims the in-flight slot for podKey. It reports true when the
// slot was free and is now held by the caller, and false when an operation
// for that pod is already in flight.
func (w *Watcher) tryAcquire(podKey string) bool {
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()

	_, busy := w.inFlight[podKey]
	if !busy {
		w.inFlight[podKey] = struct{}{}
	}
	return !busy
}

// release frees the in-flight slot for podKey so that later events for the
// pod can start a new operation. Releasing a key that is not held is a no-op.
func (w *Watcher) release(podKey string) {
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()
	delete(w.inFlight, podKey)
}