"lib/runtime/src/compute/pool.rs" did not exist on "7ebbd001c014e69650993a1c36af421045a6e91e"
checkpoint_test.go 17.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package checkpoint

import (
	"context"
	"testing"

	nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
	"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
26
	gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
27
	snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
28
29
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
30
	appsv1 "k8s.io/api/apps/v1"
31
32
33
34
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/utils/ptr"
35
	"sigs.k8s.io/controller-runtime/pkg/client"
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

const (
	testHash      = "abc123def4567890"
	testNamespace = "default"
)

// testIdentity returns the checkpoint identity shared by these tests: the
// "meta-llama/Llama-2-7b-hf" model with the "vllm" backend framework.
func testIdentity() nvidiacomv1alpha1.DynamoCheckpointIdentity {
	var id nvidiacomv1alpha1.DynamoCheckpointIdentity
	id.Model = "meta-llama/Llama-2-7b-hf"
	id.BackendFramework = "vllm"
	return id
}

// testPodSpec returns a fresh pod spec with a single container named
// consts.MainContainerName that runs `python3 -m dynamo.vllm` from
// test-image:latest.
func testPodSpec() *corev1.PodSpec {
	ctr := corev1.Container{
		Name:    consts.MainContainerName,
		Image:   "test-image:latest",
		Command: []string{"python3"},
		Args:    []string{"-m", "dynamo.vllm"},
	}
	return &corev1.PodSpec{Containers: []corev1.Container{ctr}}
}

// testScheme returns a runtime.Scheme with the nvidia.com v1alpha1 API types,
// core/v1 and apps/v1 registered, so the fake clients in these tests can store
// DynamoCheckpoints, pods-related objects and DaemonSets.
//
// Registration errors panic rather than being discarded: a partially
// populated scheme would otherwise surface later as a confusing
// "no kind is registered" failure far from the real cause.
func testScheme() *runtime.Scheme {
	s := runtime.NewScheme()
	for _, add := range []func(*runtime.Scheme) error{
		nvidiacomv1alpha1.AddToScheme,
		corev1.AddToScheme,
		appsv1.AddToScheme,
	} {
		if err := add(s); err != nil {
			panic(err)
		}
	}
	return s
}

// testInfo returns an enabled CheckpointInfo that carries testHash and leaves
// every other field (including Ready) at its zero value.
func testInfo() *CheckpointInfo {
	info := new(CheckpointInfo)
	info.Enabled = true
	info.Hash = testHash
	return info
}

74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// testSnapshotAgentDaemonSet builds a "snapshot-agent" DaemonSet in
// testNamespace, labelled with the snapshot-agent label, whose agent container
// mounts the "checkpoints" volume at /checkpoints. That volume is backed by the
// "snapshot-pvc" PersistentVolumeClaim. Tests seed fake clients with it so that
// InjectCheckpointIntoPodSpec finds a snapshot-agent DaemonSet.
func testSnapshotAgentDaemonSet() *appsv1.DaemonSet {
	const volumeName = "checkpoints"

	agent := corev1.Container{
		Name: snapshotprotocol.SnapshotAgentContainerName,
		VolumeMounts: []corev1.VolumeMount{
			{Name: volumeName, MountPath: "/checkpoints"},
		},
	}
	storage := corev1.Volume{
		Name: volumeName,
		VolumeSource: corev1.VolumeSource{
			PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ClaimName: "snapshot-pvc"},
		},
	}

	ds := &appsv1.DaemonSet{}
	ds.Name = "snapshot-agent"
	ds.Namespace = testNamespace
	ds.Labels = map[string]string{
		snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
	}
	ds.Spec.Template.Spec.Containers = []corev1.Container{agent}
	ds.Spec.Template.Spec.Volumes = []corev1.Volume{storage}
	return ds
}

107
108
109
110
type createHookClient struct {
	client.Client
	onCreate func(ctx context.Context, obj client.Object) error
}
111

112
113
114
115
116
117
118
119
120
121
func (c *createHookClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error {
	if c.onCreate != nil {
		if err := c.onCreate(ctx, obj); err != nil {
			return err
		}
		c.onCreate = nil
	}

	return c.Client.Create(ctx, obj, opts...)
}
122

123
124
125
126
127
128
129
130
131
132
133
134
135
func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *testing.T) {
	ctx := context.Background()
	s := testScheme()

	identity := testIdentity()
	hash, err := ComputeIdentityHash(identity)
	require.NoError(t, err)

	friendly := &nvidiacomv1alpha1.DynamoCheckpoint{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "friendly-checkpoint",
			Namespace: testNamespace,
			Labels: map[string]string{
136
				snapshotprotocol.CheckpointIDLabel: hash,
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
			},
		},
		Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
			Identity: identity,
			Job: nvidiacomv1alpha1.DynamoCheckpointJobConfig{
				PodTemplateSpec: corev1.PodTemplateSpec{},
			},
		},
		Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
			IdentityHash: hash,
			Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
		},
	}

	baseClient := fake.NewClientBuilder().WithScheme(s).Build()
	c := &createHookClient{
		Client: baseClient,
		onCreate: func(ctx context.Context, obj client.Object) error {
			_, ok := obj.(*nvidiacomv1alpha1.DynamoCheckpoint)
			if !ok {
				return nil
			}
			return baseClient.Create(ctx, friendly.DeepCopy())
		},
	}

163
	ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{}, nil)
164
165
166
167
168
169
170
171
172
	require.NoError(t, err)
	assert.Equal(t, friendly.Name, ckpt.Name)

	list := &nvidiacomv1alpha1.DynamoCheckpointList{}
	require.NoError(t, baseClient.List(ctx, list))
	require.Len(t, list.Items, 1)
	assert.Equal(t, friendly.Name, list.Items[0].Name)
}

173
174
175
176
177
// TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion checks that a
// checkpoint created against an empty store carries the default artifact
// version annotation.
func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
	ctx := context.Background()
	s := testScheme()
	c := fake.NewClientBuilder().WithScheme(s).Build()

	ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{}, nil)
	require.NoError(t, err)
	require.NotNil(t, ckpt.Annotations)
	assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
}

184
185
186
// --- InjectCheckpointIntoPodSpec tests ---

func TestInjectCheckpointIntoPodSpec(t *testing.T) {
187
	t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
188
		podSpec := testPodSpec()
189
190
191
		info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
		reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
		require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
192
193
		assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[0].Command)
		assert.Nil(t, podSpec.Containers[0].Args)
194
		assert.Len(t, info.Hash, 16)
195

196
197
198
		volumes := map[string]corev1.Volume{}
		for _, volume := range podSpec.Volumes {
			volumes[volume.Name] = volume
199
		}
200
201
		require.Contains(t, volumes, consts.PodInfoVolumeName)
		require.NotNil(t, volumes[consts.PodInfoVolumeName].DownwardAPI)
202

203
204
205
206
		fields := map[string]string{}
		for _, item := range volumes[consts.PodInfoVolumeName].DownwardAPI.Items {
			if item.FieldRef != nil {
				fields[item.Path] = item.FieldRef.FieldPath
207
			}
208
		}
209
210
211
212
213
214
215
216
217
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
		assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])

		mountPaths := map[string]string{}
		for _, mount := range podSpec.Containers[0].VolumeMounts {
			mountPaths[mount.Name] = mount.MountPath
218
219
220
221
		}
		assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
	})

222
223
224
	t.Run("ready checkpoint targets the container named main", func(t *testing.T) {
		podSpec := &corev1.PodSpec{
			Containers: []corev1.Container{
225
				{Name: "main", Image: "main:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
226
227
				{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
			},
228
		}
229
230
231
232
		info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash}
		reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()

		require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
233
234
235
236
		assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[0].Command)
		assert.Nil(t, podSpec.Containers[0].Args)
		assert.Equal(t, []string{"sidecar"}, podSpec.Containers[1].Command)
		assert.Equal(t, []string{"run"}, podSpec.Containers[1].Args)
237
238
	})

239
240
241
242
243
244
245
	t.Run("ready gms checkpoint injects restore sidecars and loader mount", func(t *testing.T) {
		podSpec := testPodSpec()
		podSpec.Containers[0].Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
		info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash, GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true}}
		reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()

		require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
246
		gmsServer := findContainer(podSpec, gms.ServerContainerName)
247
248
249
250
		require.NotNil(t, gmsServer)
		loader := findContainer(podSpec, GMSLoaderContainer)
		require.NotNil(t, loader)

251
252
253
		// Restore: server and loader are init sidecars (restartPolicy=Always)
		assert.NotNil(t, gmsServer.RestartPolicy, "restore gms-server should have RestartPolicy")
		assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gmsServer.RestartPolicy)
254
		assert.Nil(t, gmsServer.StartupProbe, "restore gms-server should not have StartupProbe")
255
256
		assert.NotNil(t, loader.RestartPolicy, "restore gms-loader should have RestartPolicy")
		assert.Equal(t, corev1.ContainerRestartPolicyAlways, *loader.RestartPolicy)
257
258
259
260
261
262

		mounts := map[string]string{}
		for _, mount := range loader.VolumeMounts {
			mounts[mount.Name] = mount.MountPath
		}
		assert.Equal(t, "/checkpoints", mounts[snapshotprotocol.CheckpointVolumeName])
263
		assert.Equal(t, gms.SharedMountPath, mounts[gms.SharedVolumeName])
264
265
266
267
268
269
270
271
272
273

		env := map[string]string{}
		for _, item := range loader.Env {
			env[item.Name] = item.Value
		}
		assert.Equal(t, "/checkpoints/gms/"+testHash+"/versions/1", env["GMS_CHECKPOINT_DIR"])
		assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gmsServer.Command)
		assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.snapshot.loader"}, loader.Command)
	})

274
275
276
277
278
	t.Run("error cases", func(t *testing.T) {
		for _, tc := range []struct {
			name    string
			podSpec *corev1.PodSpec
			info    *CheckpointInfo
279
			reader  client.Reader
280
281
			errMsg  string
		}{
282
			{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
283
			{"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container named"},
284
			{"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"},
285
286
		} {
			t.Run(tc.name, func(t *testing.T) {
287
				err := InjectCheckpointIntoPodSpec(context.Background(), tc.reader, testNamespace, tc.podSpec, tc.info)
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.errMsg)
			})
		}
	})
}

// --- ResolveCheckpointForService tests ---

func TestResolveCheckpointForService(t *testing.T) {
	ctx := context.Background()
	s := testScheme()

	t.Run("nil or disabled config returns disabled", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		for _, cfg := range []*nvidiacomv1alpha1.ServiceCheckpointConfig{nil, {Enabled: false}} {
			info, err := ResolveCheckpointForService(ctx, c, testNamespace, cfg)
			require.NoError(t, err)
			assert.False(t, info.Enabled)
		}
	})

	t.Run("checkpointRef resolves ready CR", func(t *testing.T) {
311
312
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
313
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
314
			ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
315
316
317
318
			Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
				Identity:         testIdentity(),
				GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true},
			},
319
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
320
321
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
				IdentityHash: hash,
322
323
324
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
325
		ref := hash
326
327
328
329
330

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
331
		assert.True(t, info.Exists)
332
		assert.True(t, info.Ready)
333
334
		assert.Equal(t, hash, info.Hash)
		assert.Equal(t, hash, info.CheckpointName)
335
336
		require.NotNil(t, info.GPUMemoryService)
		assert.True(t, info.GPUMemoryService.Enabled)
337
338
339
	})

	t.Run("checkpointRef resolves not-ready CR", func(t *testing.T) {
340
341
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
342
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
343
			ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
344
345
346
347
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
			Status:     nvidiacomv1alpha1.DynamoCheckpointStatus{Phase: nvidiacomv1alpha1.DynamoCheckpointPhaseCreating},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
348
		ref := hash
349
350
351
352
353

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
354
		assert.True(t, info.Exists)
355
356
357
358
359
360
361
362
363
364
365
366
		assert.False(t, info.Ready)
	})

	t.Run("checkpointRef errors when CR not found", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		ref := "nonexistent"
		_, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		assert.ErrorContains(t, err, "nonexistent")
	})

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
	t.Run("checkpointRef resolves human-readable checkpoint names", func(t *testing.T) {
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
			ObjectMeta: metav1.ObjectMeta{Name: "not-the-hash", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
				IdentityHash: hash,
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
		ref := "not-the-hash"

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
		assert.Equal(t, "not-the-hash", info.CheckpointName)
		assert.Equal(t, hash, info.Hash)
	})

	t.Run("identity lookup finds existing checkpoint by identity hash", func(t *testing.T) {
389
390
391
392
393
		identity := testIdentity()
		hash, err := ComputeIdentityHash(identity)
		require.NoError(t, err)

		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
394
395
			ObjectMeta: metav1.ObjectMeta{Name: "friendly-name", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: identity},
396
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
397
398
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
				IdentityHash: hash,
399
400
401
402
403
404
405
406
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
407
		assert.True(t, info.Exists)
408
409
		assert.True(t, info.Ready)
		assert.Equal(t, hash, info.Hash)
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
		assert.Equal(t, "friendly-name", info.CheckpointName)
	})

	t.Run("identity lookup returns existing not-ready checkpoint", func(t *testing.T) {
		identity := testIdentity()
		hash, err := ComputeIdentityHash(identity)
		require.NoError(t, err)

		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
			ObjectMeta: metav1.ObjectMeta{Name: "friendly-name", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: identity},
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseCreating,
				IdentityHash: hash,
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
		assert.True(t, info.Exists)
		assert.False(t, info.Ready)
		assert.Equal(t, hash, info.Hash)
435
436
437
438
439
440
441
442
443
	})

	t.Run("identity lookup returns not-ready when no CR found", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		identity := testIdentity()
		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
444
		assert.False(t, info.Exists)
445
446
447
448
449
450
451
452
453
454
		assert.False(t, info.Ready)
		assert.Len(t, info.Hash, 16)
	})

	t.Run("errors when enabled but no ref and no identity", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		_, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{Enabled: true})
		assert.ErrorContains(t, err, "no checkpointRef or identity")
	})
}
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470

// findContainer returns a pointer to the first container called name. It
// searches podSpec.Containers before podSpec.InitContainers. The pointer
// aliases the element inside podSpec, so callers observe the injected state
// directly. It returns nil when no container matches.
func findContainer(podSpec *corev1.PodSpec, name string) *corev1.Container {
	for _, group := range [][]corev1.Container{podSpec.Containers, podSpec.InitContainers} {
		// group shares its backing array with the PodSpec field, so &group[i]
		// points into podSpec rather than at a copy.
		for i := range group {
			if group[i].Name == name {
				return &group[i]
			}
		}
	}
	return nil
}