restore_test.go 17.2 KB
Newer Older
1
2
3
package protocol

import (
4
	"fmt"
5
	"math"
6
7
8
9
10
11
12
13
	"testing"

	appsv1 "k8s.io/api/apps/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestNewRestorePod(t *testing.T) {
14
15
16
	readinessProbe := &corev1.Probe{PeriodSeconds: 7, TimeoutSeconds: 3}
	livenessProbe := &corev1.Probe{InitialDelaySeconds: 11}
	startupProbe := &corev1.Probe{FailureThreshold: 120}
17
18
19
20
21
22
23
24
25
26
27
28
29
	restorePod := NewRestorePod(&corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{
			Name:        "worker",
			Labels:      map[string]string{"existing": "label"},
			Annotations: map[string]string{"existing": "annotation"},
		},
		Spec: corev1.PodSpec{
			RestartPolicy: corev1.RestartPolicyAlways,
			Containers: []corev1.Container{{
				Name:           "main",
				Image:          "test:latest",
				Command:        []string{"python3", "-m", "dynamo.vllm"},
				Args:           []string{"--model", "Qwen"},
30
31
32
				ReadinessProbe: readinessProbe.DeepCopy(),
				LivenessProbe:  livenessProbe.DeepCopy(),
				StartupProbe:   startupProbe.DeepCopy(),
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
			}},
		},
	}, PodOptions{
		Namespace:       "test-ns",
		CheckpointID:    "hash",
		ArtifactVersion: "2",
		Storage: Storage{
			Type:     StorageTypePVC,
			PVCName:  "snapshot-pvc",
			BasePath: "/checkpoints",
		},
		SeccompProfile: DefaultSeccompLocalhostProfile,
	})

	if restorePod.Name != "worker" || restorePod.Namespace != "test-ns" {
		t.Fatalf("unexpected restore pod identity: %#v", restorePod.ObjectMeta)
	}
	if restorePod.Labels[RestoreTargetLabel] != "true" {
		t.Fatalf("expected restore target label: %#v", restorePod.Labels)
	}
	if restorePod.Labels[CheckpointIDLabel] != "hash" {
		t.Fatalf("expected checkpoint id label: %#v", restorePod.Labels)
	}
	if restorePod.Annotations[CheckpointArtifactVersionAnnotation] != "2" {
		t.Fatalf("expected checkpoint artifact version annotation: %#v", restorePod.Annotations)
	}
	if restorePod.Spec.RestartPolicy != corev1.RestartPolicyNever {
		t.Fatalf("expected restartPolicy Never, got %#v", restorePod.Spec.RestartPolicy)
	}
	if len(restorePod.Spec.Containers[0].Command) != 2 || restorePod.Spec.Containers[0].Command[0] != "sleep" || restorePod.Spec.Containers[0].Command[1] != "infinity" {
		t.Fatalf("expected placeholder command, got %#v", restorePod.Spec.Containers[0].Command)
	}
	if restorePod.Spec.Containers[0].Args != nil {
		t.Fatalf("expected restore args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
	}
68
69
	if restorePod.Spec.Containers[0].ReadinessProbe == nil {
		t.Fatalf("expected readiness probe to be preserved")
70
	}
71
72
	if got := restorePod.Spec.Containers[0].ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
		t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
73
	}
74
75
76
77
78
79
80
81
82
83
84
	if restorePod.Spec.Containers[0].LivenessProbe == nil {
		t.Fatalf("expected liveness probe to be preserved")
	}
	if got := restorePod.Spec.Containers[0].LivenessProbe.InitialDelaySeconds; got != livenessProbe.InitialDelaySeconds {
		t.Fatalf("expected liveness initial delay %d, got %d", livenessProbe.InitialDelaySeconds, got)
	}
	if restorePod.Spec.Containers[0].StartupProbe == nil {
		t.Fatalf("expected startup probe to be preserved")
	}
	if got := restorePod.Spec.Containers[0].StartupProbe.FailureThreshold; got != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
85
86
87
88
	}
	if restorePod.Spec.SecurityContext == nil || restorePod.Spec.SecurityContext.SeccompProfile == nil {
		t.Fatalf("expected seccomp profile to be injected: %#v", restorePod.Spec.SecurityContext)
	}
89
90
	if len(restorePod.Spec.Volumes) != 2 {
		t.Fatalf("expected checkpoint and snapshot-control volumes, got %#v", restorePod.Spec.Volumes)
91
	}
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
	if len(restorePod.Spec.Containers[0].VolumeMounts) != 2 {
		t.Fatalf("expected checkpoint and snapshot-control mounts, got %#v", restorePod.Spec.Containers[0].VolumeMounts)
	}
	foundMount := false
	for _, m := range restorePod.Spec.Containers[0].VolumeMounts {
		if m.Name == SnapshotControlVolumeName {
			foundMount = true
			break
		}
	}
	if !foundMount {
		t.Fatalf("expected %s mount, got %#v", SnapshotControlVolumeName, restorePod.Spec.Containers[0].VolumeMounts)
	}
	foundEnv := false
	for _, e := range restorePod.Spec.Containers[0].Env {
		if e.Name == SnapshotControlDirEnv {
			foundEnv = true
			break
		}
	}
	if !foundEnv {
		t.Fatalf("expected %s env, got %#v", SnapshotControlDirEnv, restorePod.Spec.Containers[0].Env)
114
115
116
117
118
	}
}

func TestPrepareRestorePodSpec(t *testing.T) {
	podSpec := corev1.PodSpec{}
119
120
121
	readinessProbe := &corev1.Probe{PeriodSeconds: 13, SuccessThreshold: 1}
	livenessProbe := &corev1.Probe{TimeoutSeconds: 5}
	startupProbe := &corev1.Probe{FailureThreshold: 60}
122
123
124
	container := corev1.Container{
		Command:        []string{"python3", "-m", "dynamo.vllm"},
		Args:           []string{"--model", "Qwen"},
125
126
127
		ReadinessProbe: readinessProbe.DeepCopy(),
		LivenessProbe:  livenessProbe.DeepCopy(),
		StartupProbe:   startupProbe.DeepCopy(),
128
129
130
131
132
133
134
135
136
137
138
139
140
	}

	storage := Storage{
		Type:     StorageTypePVC,
		PVCName:  "snapshot-pvc",
		BasePath: "/checkpoints",
	}
	PrepareRestorePodSpec(&podSpec, &container, storage, DefaultSeccompLocalhostProfile, true)
	PrepareRestorePodSpec(&podSpec, &container, storage, DefaultSeccompLocalhostProfile, true)

	if podSpec.SecurityContext == nil || podSpec.SecurityContext.SeccompProfile == nil {
		t.Fatalf("expected seccomp profile to be injected: %#v", podSpec.SecurityContext)
	}
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
	if len(podSpec.Volumes) != 2 {
		t.Fatalf("expected checkpoint and snapshot-control volumes, got %#v", podSpec.Volumes)
	}
	if len(container.VolumeMounts) != 2 {
		t.Fatalf("expected checkpoint and snapshot-control mounts, got %#v", container.VolumeMounts)
	}
	volCount := 0
	for _, v := range podSpec.Volumes {
		if v.Name == SnapshotControlVolumeName {
			volCount++
		}
	}
	if volCount != 1 {
		t.Fatalf("expected single %s volume after repeated calls, got %#v", SnapshotControlVolumeName, podSpec.Volumes)
	}
	mountCount := 0
	for _, m := range container.VolumeMounts {
		if m.Name == SnapshotControlVolumeName {
			mountCount++
		}
161
	}
162
163
164
165
166
167
168
169
170
171
172
	if mountCount != 1 {
		t.Fatalf("expected single %s mount after repeated calls, got %#v", SnapshotControlVolumeName, container.VolumeMounts)
	}
	envCount := 0
	for _, e := range container.Env {
		if e.Name == SnapshotControlDirEnv {
			envCount++
		}
	}
	if envCount != 1 {
		t.Fatalf("expected single %s env after repeated calls, got %#v", SnapshotControlDirEnv, container.Env)
173
174
175
176
177
178
179
	}
	if len(container.Command) != 2 || container.Command[0] != "sleep" || container.Command[1] != "infinity" {
		t.Fatalf("expected placeholder command, got %#v", container.Command)
	}
	if container.Args != nil {
		t.Fatalf("expected restore args to be cleared: %#v", container.Args)
	}
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
	if container.ReadinessProbe == nil {
		t.Fatalf("expected readiness probe to be preserved")
	}
	if got := container.ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
		t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
	}
	if container.LivenessProbe == nil {
		t.Fatalf("expected liveness probe to be preserved")
	}
	if got := container.LivenessProbe.TimeoutSeconds; got != livenessProbe.TimeoutSeconds {
		t.Fatalf("expected liveness timeout %d, got %d", livenessProbe.TimeoutSeconds, got)
	}
	if container.StartupProbe == nil {
		t.Fatalf("expected startup probe to be preserved")
	}
	if got := container.StartupProbe.FailureThreshold; got != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
	}
	if got := container.StartupProbe.SuccessThreshold; got != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", got)
	}
}

// TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness feeds
// PrepareRestorePodSpec a container that has only an HTTP liveness probe and
// checks that the liveness probe is kept while a startup probe appears with
// the same HTTP path, FailureThreshold math.MaxInt32, and SuccessThreshold 1.
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T) {
	liveness := &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			HTTPGet: &corev1.HTTPGetAction{Path: "/livez"},
		},
		PeriodSeconds:    5,
		TimeoutSeconds:   4,
		FailureThreshold: 2,
	}
	worker := corev1.Container{
		Command:       []string{"python3", "-m", "dynamo.vllm"},
		Args:          []string{"--model", "Qwen"},
		LivenessProbe: liveness.DeepCopy(),
	}
	spec := corev1.PodSpec{}

	PrepareRestorePodSpec(&spec, &worker, Storage{}, "", true)

	if worker.LivenessProbe == nil {
		t.Fatal("expected liveness probe to be preserved")
	}
	startup := worker.StartupProbe
	if startup == nil {
		t.Fatal("expected startup probe to be synthesized")
	}
	if httpGet := startup.HTTPGet; httpGet == nil || httpGet.Path != "/livez" {
		t.Fatalf("expected startup probe HTTP path /livez, got %#v", httpGet)
	}
	if startup.FailureThreshold != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, startup.FailureThreshold)
	}
	if startup.SuccessThreshold != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", startup.SuccessThreshold)
	}
}

238
func TestNewRestorePodTargetsFirstContainerWhenSidecarsPresent(t *testing.T) {
239
240
241
242
	restorePod := NewRestorePod(&corev1.Pod{
		ObjectMeta: metav1.ObjectMeta{Name: "worker"},
		Spec: corev1.PodSpec{
			Containers: []corev1.Container{
243
				{Name: "worker", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
				{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
			},
		},
	}, PodOptions{
		Namespace:       "test-ns",
		CheckpointID:    "hash",
		ArtifactVersion: "2",
		Storage: Storage{
			Type:     StorageTypePVC,
			PVCName:  "snapshot-pvc",
			BasePath: "/checkpoints",
		},
		SeccompProfile: DefaultSeccompLocalhostProfile,
	})

259
260
	if got := restorePod.Spec.Containers[0].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
		t.Fatalf("expected first container placeholder command, got %#v", got)
261
	}
262
263
	if restorePod.Spec.Containers[0].Args != nil {
		t.Fatalf("expected first container args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
264
	}
265
266
	if got := restorePod.Spec.Containers[1].Command; len(got) != 1 || got[0] != "sidecar" {
		t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
267
268
269
	}
}

270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) {
	podSpec := corev1.PodSpec{}
	readinessProbe := &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			Exec: &corev1.ExecAction{Command: []string{"cat", "/tmp/ready"}},
		},
		PeriodSeconds:    13,
		SuccessThreshold: 3,
		FailureThreshold: 4,
	}
	container := corev1.Container{
		Command:        []string{"python3", "-m", "dynamo.vllm"},
		Args:           []string{"--model", "Qwen"},
		ReadinessProbe: readinessProbe.DeepCopy(),
	}

	PrepareRestorePodSpec(&podSpec, &container, Storage{}, "", true)

	if container.ReadinessProbe == nil {
		t.Fatalf("expected readiness probe to be preserved")
	}
	if got := container.ReadinessProbe.SuccessThreshold; got != readinessProbe.SuccessThreshold {
		t.Fatalf("expected readiness success threshold %d, got %d", readinessProbe.SuccessThreshold, got)
	}
	if container.StartupProbe == nil {
		t.Fatalf("expected startup probe to be synthesized")
	}
	if container.StartupProbe.Exec == nil || len(container.StartupProbe.Exec.Command) != 2 || container.StartupProbe.Exec.Command[0] != "cat" || container.StartupProbe.Exec.Command[1] != "/tmp/ready" {
		t.Fatalf("expected startup probe exec command to match readiness probe: %#v", container.StartupProbe.Exec)
	}
	if got := container.StartupProbe.FailureThreshold; got != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
	}
	if got := container.StartupProbe.SuccessThreshold; got != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", got)
305
306
307
	}
}

308
309
func validRestoreSpecFixture(profile string) *corev1.PodSpec {
	return &corev1.PodSpec{
310
311
312
313
314
315
		SecurityContext: &corev1.PodSecurityContext{
			SeccompProfile: &corev1.SeccompProfile{
				Type:             corev1.SeccompProfileTypeLocalhost,
				LocalhostProfile: &profile,
			},
		},
316
317
318
319
320
321
322
		Volumes: []corev1.Volume{
			{
				Name: CheckpointVolumeName,
				VolumeSource: corev1.VolumeSource{
					PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
						ClaimName: "snapshot-pvc",
					},
323
324
				},
			},
325
326
327
328
329
			{
				Name:         SnapshotControlVolumeName,
				VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
			},
		},
330
331
		Containers: []corev1.Container{{
			Name: "main",
332
333
334
335
336
			VolumeMounts: []corev1.VolumeMount{
				{Name: CheckpointVolumeName, MountPath: "/checkpoints"},
				{Name: SnapshotControlVolumeName, MountPath: SnapshotControlMountPath},
			},
			Env: []corev1.EnvVar{{Name: SnapshotControlDirEnv, Value: SnapshotControlMountPath}},
337
338
		}},
	}
339
340
341
342
343
}

func TestValidateRestorePodSpec(t *testing.T) {
	profile := DefaultSeccompLocalhostProfile
	podSpec := validRestoreSpecFixture(profile)
344
345
346
347
348
349
350
351
352
353
354
	storage := Storage{
		Type:     StorageTypePVC,
		PVCName:  "snapshot-pvc",
		BasePath: "/checkpoints",
	}

	if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
		t.Fatalf("expected restore pod spec to be valid, got %v", err)
	}

	badSpec := podSpec.DeepCopy()
355
	badSpec.Volumes = []corev1.Volume{badSpec.Volumes[1]}
356
357
358
359
360
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "missing checkpoint-storage volume for PVC snapshot-pvc" {
		t.Fatalf("expected missing volume error, got %v", err)
	}

	badSpec = podSpec.DeepCopy()
361
	badSpec.Containers[0].VolumeMounts = []corev1.VolumeMount{badSpec.Containers[0].VolumeMounts[1]}
362
363
364
365
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "missing checkpoint-storage mount at /checkpoints" {
		t.Fatalf("expected missing mount error, got %v", err)
	}

366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
	badSpec = podSpec.DeepCopy()
	badSpec.Volumes = []corev1.Volume{badSpec.Volumes[0]}
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != fmt.Sprintf("missing %s emptyDir volume; add it via snapshotprotocol.EnsureControlVolume", SnapshotControlVolumeName) {
		t.Fatalf("expected missing control volume error, got %v", err)
	}

	badSpec = podSpec.DeepCopy()
	badSpec.Containers[0].VolumeMounts = []corev1.VolumeMount{badSpec.Containers[0].VolumeMounts[0]}
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != fmt.Sprintf("missing %s mount at %s", SnapshotControlVolumeName, SnapshotControlMountPath) {
		t.Fatalf("expected missing control mount error, got %v", err)
	}

	badSpec = podSpec.DeepCopy()
	badSpec.Containers[0].Env = nil
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != fmt.Sprintf("missing %s env var on worker container", SnapshotControlDirEnv) {
		t.Fatalf("expected missing control env error, got %v", err)
	}

384
385
386
387
388
389
390
	badSpec = podSpec.DeepCopy()
	badSpec.SecurityContext = nil
	if err := ValidateRestorePodSpec(badSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "missing localhost seccomp profile" {
		t.Fatalf("expected missing seccomp error, got %v", err)
	}
}

391
func TestValidateRestorePodSpecAcceptsFirstContainerAsWorker(t *testing.T) {
392
393
394
	podSpec := validRestoreSpecFixture(DefaultSeccompLocalhostProfile)
	podSpec.Containers[0].Name = "worker"
	podSpec.Containers = append(podSpec.Containers, corev1.Container{Name: "sidecar"})
395
396
397
398
399
400
401

	storage := Storage{
		Type:     StorageTypePVC,
		PVCName:  "snapshot-pvc",
		BasePath: "/checkpoints",
	}

402
403
404
	// Containers[0] is always the worker, regardless of name
	if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
		t.Fatalf("expected validation to pass for first container as worker, got %v", err)
405
406
407
	}
}

408
func TestValidateRestorePodSpecAllowsWorkerWithSidecars(t *testing.T) {
409
410
411
	podSpec := validRestoreSpecFixture(DefaultSeccompLocalhostProfile)
	podSpec.Containers[0].Name = "worker"
	podSpec.Containers = append(podSpec.Containers, corev1.Container{Name: "sidecar"})
412
413
414
415
416
417
418
419

	storage := Storage{
		Type:     StorageTypePVC,
		PVCName:  "snapshot-pvc",
		BasePath: "/checkpoints",
	}

	if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
420
		t.Fatalf("expected worker with sidecars to validate, got %v", err)
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
	}
}

func TestDiscoverStorageFromDaemonSetsUsesCheckpointsVolume(t *testing.T) {
	daemonSet := appsv1.DaemonSet{
		ObjectMeta: metav1.ObjectMeta{Name: "snapshot-agent", Namespace: "test-ns"},
		Spec: appsv1.DaemonSetSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{{
						Name: SnapshotAgentContainerName,
						VolumeMounts: []corev1.VolumeMount{
							{Name: "cache", MountPath: "/cache"},
							{Name: SnapshotAgentVolumeName, MountPath: "/checkpoints"},
						},
					}},
					Volumes: []corev1.Volume{
						{
							Name: "cache",
							VolumeSource: corev1.VolumeSource{
								PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ClaimName: "cache-pvc"},
							},
						},
						{
							Name: SnapshotAgentVolumeName,
							VolumeSource: corev1.VolumeSource{
								PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ClaimName: "snapshot-pvc"},
							},
						},
					},
				},
			},
		},
	}

	storage, err := DiscoverStorageFromDaemonSets("test-ns", []appsv1.DaemonSet{daemonSet})
	if err != nil {
		t.Fatalf("expected daemonset storage discovery to succeed, got %v", err)
	}
	if storage.PVCName != "snapshot-pvc" || storage.BasePath != "/checkpoints" {
		t.Fatalf("expected snapshot PVC discovery, got %#v", storage)
	}
}