// Package watcher provides Kubernetes pod watching for automatic checkpoint/restore.
// The watcher is the sole entry point for snapshot operations: it detects pods
// carrying checkpoint/restore labels and invokes the orchestrators directly.
package watcher

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
11
	"strings"
12
	"sync"
13
	"syscall"
14
15
	"time"

16
17
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
18
19
20
21
22
23
24
25
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/labels"
	"k8s.io/client-go/informers"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/rest"
	"k8s.io/client-go/tools/cache"

26
27
28
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/orchestrate"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
29
30
)

// Pod labels that opt a pod into snapshot handling and name its checkpoint.
const (
	kubeLabelIsCheckpointSource = "nvidia.com/snapshot-is-checkpoint-source"
	kubeLabelIsRestoreTarget    = "nvidia.com/snapshot-is-restore-target"
	kubeLabelCheckpointHash     = "nvidia.com/snapshot-checkpoint-hash"
)

// Pod annotations the watcher writes to record operation progress; the
// values used in this package are "in_progress", "completed" and "failed".
const (
	kubeAnnotationCheckpointStatus = "nvidia.com/snapshot-checkpoint-status"
	kubeAnnotationRestoreStatus    = "nvidia.com/snapshot-restore-status"
)

39
// Watcher watches for pods with checkpoint/restore labels and triggers operations.
40
type Watcher struct {
41
42
43
44
	config     *types.AgentConfig
	clientset  kubernetes.Interface
	containerd *containerd.Client
	log        logr.Logger
45

46
47
	inFlight   map[string]struct{}
	inFlightMu sync.Mutex
48
49
50
51

	stopCh chan struct{}
}

52
53
54
55
56
57
// NewWatcher creates a new pod watcher.
func NewWatcher(
	cfg *types.AgentConfig,
	containerd *containerd.Client,
	log logr.Logger,
) (*Watcher, error) {
58
59
60
61
62
63
64
65
66
67
68
	restConfig, err := rest.InClusterConfig()
	if err != nil {
		return nil, fmt.Errorf("failed to get in-cluster config: %w", err)
	}

	clientset, err := kubernetes.NewForConfig(restConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create kubernetes client: %w", err)
	}

	return &Watcher{
69
70
71
72
73
		config:     cfg,
		clientset:  clientset,
		containerd: containerd,
		log:        log,
		inFlight:        make(map[string]struct{}),
74
75
76
77
		stopCh:          make(chan struct{}),
	}, nil
}

78
// Start begins watching for pods and processing checkpoint/restore events.
79
func (w *Watcher) Start(ctx context.Context) error {
80
81
82
83
84
	w.log.Info("Starting pod watcher",
		"node", w.config.NodeName,
		"checkpoint", kubeLabelIsCheckpointSource,
		"restore", kubeLabelIsRestoreTarget,
	)
85

86
87
88
89
90
91
	var nsOptions []informers.SharedInformerOption
	if w.config.RestrictedNamespace != "" {
		w.log.Info("Restricting pod watching to namespace", "namespace", w.config.RestrictedNamespace)
		nsOptions = append(nsOptions, informers.WithNamespace(w.config.RestrictedNamespace))
	} else {
		w.log.Info("Watching pods cluster-wide (all namespaces)")
92
93
	}

94
95
96
97
98
	var syncFuncs []cache.InformerSynced

	// Checkpoint informer
	checkpointSelector := labels.SelectorFromSet(labels.Set{
		kubeLabelIsCheckpointSource: "true",
99
100
	}).String()

101
	ckptFactoryOpts := append([]informers.SharedInformerOption{
102
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
103
			opts.LabelSelector = checkpointSelector
104
		}),
105
	}, nsOptions...)
106

107
108
	ckptFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, ckptFactoryOpts...,
109
110
	)

111
112
	ckptInformer := ckptFactory.Core().V1().Pods().Informer()
	ckptInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
113
		AddFunc: func(obj interface{}) {
114
115
116
117
118
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
			w.handleCheckpointPodEvent(ctx, pod)
119
		},
120
121
122
123
124
125
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
			w.handleCheckpointPodEvent(ctx, pod)
126
127
		},
	})
128
129
	go ckptFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, ckptInformer.HasSynced)
130

131
132
133
134
	// Restore informer
	restoreSelector := labels.SelectorFromSet(labels.Set{
		kubeLabelIsRestoreTarget: "true",
	}).String()
135

136
137
138
139
140
	restoreFactoryOpts := append([]informers.SharedInformerOption{
		informers.WithTweakListOptions(func(opts *metav1.ListOptions) {
			opts.LabelSelector = restoreSelector
		}),
	}, nsOptions...)
141

142
143
144
	restoreFactory := informers.NewSharedInformerFactoryWithOptions(
		w.clientset, 30*time.Second, restoreFactoryOpts...,
	)
145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
	restoreInformer := restoreFactory.Core().V1().Pods().Informer()
	restoreInformer.AddEventHandler(cache.ResourceEventHandlerFuncs{
		AddFunc: func(obj interface{}) {
			pod, ok := podFromInformerObj(obj)
			if !ok {
				return
			}
			w.handleRestorePodEvent(ctx, pod)
		},
		UpdateFunc: func(_, newObj interface{}) {
			pod, ok := podFromInformerObj(newObj)
			if !ok {
				return
			}
			w.handleRestorePodEvent(ctx, pod)
		},
162
	})
163
164
	go restoreFactory.Start(w.stopCh)
	syncFuncs = append(syncFuncs, restoreInformer.HasSynced)
165

166
167
	if !cache.WaitForCacheSync(w.stopCh, syncFuncs...) {
		return fmt.Errorf("failed to sync informer caches")
168
169
	}

170
171
	w.log.Info("Pod watcher started and caches synced")
	<-ctx.Done()
172
	close(w.stopCh)
173
	return nil
174
175
}

176
func (w *Watcher) handleCheckpointPodEvent(ctx context.Context, pod *corev1.Pod) {
177
178
179
	if pod.Spec.NodeName != w.config.NodeName {
		return
	}
180
	if !isPodReady(pod) {
181
182
183
184
185
		return
	}

	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)

186
187
188
	checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash]
	if !ok || checkpointHash == "" {
		w.log.Info("Pod has checkpoint label but no checkpoint-hash label", "pod", podKey)
189
190
191
		return
	}

192
193
	annotationStatus := pod.Annotations[kubeAnnotationCheckpointStatus]
	if annotationStatus == "completed" || annotationStatus == "in_progress" {
194
195
196
		return
	}

197
198
	if !w.tryAcquire(podKey) {
		return
199
200
	}

201
	w.log.Info("Pod ready, triggering checkpoint", "pod", podKey, "checkpoint_hash", checkpointHash)
202
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointRequested", fmt.Sprintf("Checkpoint requested: %s", checkpointHash))
203

204
	go w.doCheckpoint(ctx, pod, checkpointHash, podKey)
205
206
207
}


208
209
210
func (w *Watcher) handleRestorePodEvent(ctx context.Context, pod *corev1.Pod) {
	if pod.Spec.NodeName != w.config.NodeName {
		return
211
212
	}

213
	podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
214

215
	if pod.Status.Phase != corev1.PodRunning {
216
217
218
		return
	}

219
220
221
222
	annotationStatus := pod.Annotations[kubeAnnotationRestoreStatus]

	if isPodReady(pod) {
		return
223
224
	}

225
226
227
228
	// Restore failures require explicit intervention (new label/update) before retry.
	if annotationStatus == "completed" || annotationStatus == "in_progress" || annotationStatus == "failed" {
		return
	}
229

230
231
232
	checkpointHash, ok := pod.Labels[kubeLabelCheckpointHash]
	if !ok || checkpointHash == "" {
		w.log.Info("Restore pod has no checkpoint-hash label", "pod", podKey)
233
234
235
		return
	}

236
237
	if strings.ContainsAny(checkpointHash, "/\\") || strings.Contains(checkpointHash, "..") || filepath.Clean(checkpointHash) != checkpointHash {
		w.log.Error(fmt.Errorf("invalid checkpoint hash %q", checkpointHash), "Invalid checkpoint hash on restore pod", "pod", podKey)
238
239
240
		return
	}

241
242
243
244
	checkpointDir := filepath.Join(w.config.BasePath, checkpointHash)
	if _, err := os.Stat(checkpointDir); os.IsNotExist(err) {
		w.log.V(1).Info("Checkpoint not ready on disk, skipping restore", "pod", podKey, "checkpoint_hash", checkpointHash)
		return
245
246
	}

247
	if !w.tryAcquire(podKey) {
248
249
250
		return
	}

251
	w.log.Info("Restore pod running, triggering external restore", "pod", podKey, "checkpoint_hash", checkpointHash)
252
	emitPodEvent(ctx, w.clientset, w.log, pod, "snapshot", corev1.EventTypeNormal, "RestoreRequested", fmt.Sprintf("Restore requested from checkpoint %s", checkpointHash))
253

254
255
	go w.doRestore(ctx, pod, checkpointHash, podKey)
}
256

257
258
259
260
// doCheckpoint runs the full checkpoint workflow for a pod:
//  1. Mark pod as in_progress
//  2. Resolve the container ID and host PID
//  3. Call orchestrate.Checkpoint (inspect → configure → CUDA lock/checkpoint → CRIU dump → rootfs diff)
261
//  4. SIGUSR1 the process on success (notify workload), SIGKILL on failure (terminate immediately)
262
263
264
265
266
267
268
269
270
271
//  5. Mark pod as completed or failed
func (w *Watcher) doCheckpoint(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) {
	defer w.release(podKey)
	log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)

	if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
		kubeAnnotationCheckpointStatus: "in_progress",
	}); err != nil {
		log.Error(err, "Failed to annotate pod with checkpoint in_progress")
		return
272
273
	}

274
275
276
277
278
	// Resolve the target container
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		err := fmt.Errorf("no containers found in pod spec")
		log.Error(err, "Checkpoint failed")
279
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
280
281
282
283
284
285
286
287
288
289
290
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
		return
	}
	var containerID string
	for _, cs := range pod.Status.ContainerStatuses {
		if cs.Name == containerName {
			containerID = strings.TrimPrefix(cs.ContainerID, "containerd://")
			break
		}
	}
	if containerID == "" {
291
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", "Could not resolve target container ID")
292
293
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
		return
294
295
	}

296
297
	// Resolve the container's host PID (needed for signaling after checkpoint)
	containerPID, _, err := common.ResolveContainer(ctx, w.containerd, containerID)
298
	if err != nil {
299
		log.Error(err, "Failed to resolve container")
300
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", fmt.Sprintf("Container resolve failed: %v", err))
301
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
302
303
304
		return
	}

305
306
307
308
309
310
311
312
313
314
315
316
	// Step 1: Run the checkpoint orchestrator
	req := orchestrate.CheckpointRequest{
		ContainerID:    containerID,
		ContainerName:  containerName,
		CheckpointHash: checkpointHash,
		CheckpointDir:  w.config.BasePath,
		NodeName:       w.config.NodeName,
		PodName:        pod.Name,
		PodNamespace:   pod.Namespace,
	}
	if err := orchestrate.Checkpoint(ctx, w.containerd, log, req, w.config); err != nil {
		log.Error(err, "Checkpoint failed")
317
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
318
319
		// SIGKILL on failure: process is unrecoverable (CUDA locked), terminate immediately
		if signalErr := common.SendSignalToPID(log, containerPID, syscall.SIGKILL, "checkpoint failed"); signalErr != nil {
320
321
322
			log.Error(signalErr, "Failed to signal checkpoint failure to runtime process")
		}
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
323
324
325
		return
	}

326
	// Step 2: SIGUSR1 on success: notify the workload that checkpoint completed
327
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "CheckpointSucceeded", fmt.Sprintf("Checkpoint completed: %s", checkpointHash))
328
329
	if err := common.SendSignalToPID(log, containerPID, syscall.SIGUSR1, "checkpoint complete"); err != nil {
		log.Error(err, "Failed to signal checkpoint completion to runtime process")
330
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "CheckpointFailed", err.Error())
331
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "failed"})
332
333
334
		return
	}

335
	annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationCheckpointStatus: "completed"})
336
337
}

338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
// doRestore runs the full restore workflow for a pod:
//  1. Mark pod as in_progress
//  2. Call orchestrate.Restore (inspect placeholder → nsrestore inside namespace)
//  3. SIGCONT the restored process to wake it up
//  4. Wait for the pod to become Ready
//  5. Mark pod as completed or failed
func (w *Watcher) doRestore(ctx context.Context, pod *corev1.Pod, checkpointHash, podKey string) {
	defer w.release(podKey)
	log := w.log.WithValues("pod", podKey, "checkpoint_hash", checkpointHash)

	if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
		kubeAnnotationRestoreStatus: "in_progress",
	}); err != nil {
		log.Error(err, "Failed to annotate pod with restore in_progress")
		return
	}
354

355
356
357
358
	containerName := resolveMainContainerName(pod)
	if containerName == "" {
		err := fmt.Errorf("no containers found in pod spec")
		log.Error(err, "Restore failed")
359
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
		return
	}

	// Step 1: Run the restore orchestrator (inspect + nsrestore)
	req := orchestrate.RestoreRequest{
		CheckpointHash: checkpointHash,
		CheckpointBase: w.config.BasePath,
		NSRestorePath:  w.config.Restore.NSRestorePath,
		PodName:        pod.Name,
		PodNamespace:   pod.Namespace,
		ContainerName:  containerName,
	}
	restoredPID, err := orchestrate.Restore(ctx, w.containerd, log, req)
	if err != nil {
		log.Error(err, "External restore failed")
376
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
377
378
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
		return
379
380
	}

381
382
	// Step 2: SIGCONT the restored process via PID namespace
	placeholderHostPID, _, err := common.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
383
	if err != nil {
384
		log.Error(err, "Failed to resolve placeholder host PID for signaling")
385
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
386
387
388
389
390
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
		return
	}
	if err := common.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
		log.Error(err, "Failed to signal restored runtime process")
391
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
392
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
393
394
395
		return
	}

396
397
398
399
400
401
402
403
404
	// Step 3: Wait for the pod to become Ready
	readyCtx := ctx
	if timeout := w.config.Restore.RestoreReadyTimeout(); timeout > 0 {
		var cancel context.CancelFunc
		readyCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
		log.Error(err, "Restore post-signal readiness check failed")
405
		emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
406
		annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "failed"})
407
408
409
		return
	}

410
	emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeNormal, "RestoreSucceeded", fmt.Sprintf("Restore completed from checkpoint %s", checkpointHash))
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
	annotatePod(ctx, w.clientset, log, pod, map[string]string{kubeAnnotationRestoreStatus: "completed"})
}

// tryAcquire claims podKey for the caller and reports whether the claim
// succeeded. It returns false, leaving state untouched, when an operation on
// the same pod is already in flight. Safe for concurrent use (guarded by
// inFlightMu).
func (w *Watcher) tryAcquire(podKey string) bool {
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()

	_, busy := w.inFlight[podKey]
	if !busy {
		w.inFlight[podKey] = struct{}{}
	}
	return !busy
}

func (w *Watcher) release(podKey string) {
	w.inFlightMu.Lock()
	defer w.inFlightMu.Unlock()
	delete(w.inFlight, podKey)
428
}
429
430
431