"docs/vscode:/vscode.git/clone" did not exist on "2fa4623d9e673faddb05018ac3f3d65780bb083d"
checkpoint_test.go 15.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package checkpoint

import (
	"context"
	"testing"

	nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
	"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
26
	snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
27
28
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
29
	appsv1 "k8s.io/api/apps/v1"
30
31
32
33
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/utils/ptr"
34
	"sigs.k8s.io/controller-runtime/pkg/client"
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
	"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

// Shared fixture values used by the checkpoint tests in this file.
const (
	// testHash is a fixed 16-character string used where a test needs a
	// CheckpointInfo with a pre-populated Hash.
	testHash      = "abc123def4567890"
	// testNamespace is the namespace given to fixture objects and passed to
	// the functions under test.
	testNamespace = "default"
)

// testIdentity builds the checkpoint identity shared by the tests in this
// file: a fixed model name paired with the "vllm" backend framework.
func testIdentity() nvidiacomv1alpha1.DynamoCheckpointIdentity {
	var id nvidiacomv1alpha1.DynamoCheckpointIdentity
	id.Model = "meta-llama/Llama-2-7b-hf"
	id.BackendFramework = "vllm"
	return id
}

// testPodSpec returns a fresh pod spec holding a single container named
// consts.MainContainerName with a python3 command and dynamo.vllm arguments.
// A new value is allocated on every call so tests can mutate it freely.
func testPodSpec() *corev1.PodSpec {
	mainContainer := corev1.Container{
		Name:    consts.MainContainerName,
		Image:   "test-image:latest",
		Command: []string{"python3"},
		Args:    []string{"-m", "dynamo.vllm"},
	}

	spec := new(corev1.PodSpec)
	spec.Containers = []corev1.Container{mainContainer}
	return spec
}

// testScheme returns a new runtime.Scheme with the Dynamo v1alpha1, core/v1
// and apps/v1 API types registered, suitable for building fake clients.
//
// It panics if any registration fails: a half-populated scheme would only
// surface later as confusing "no kind registered" errors inside the tests.
func testScheme() *runtime.Scheme {
	s := runtime.NewScheme()

	registrations := []func(*runtime.Scheme) error{
		nvidiacomv1alpha1.AddToScheme,
		corev1.AddToScheme,
		appsv1.AddToScheme,
	}
	for _, register := range registrations {
		if err := register(s); err != nil {
			panic(err)
		}
	}

	return s
}

// testInfo returns a new CheckpointInfo that is enabled and carries the fixed
// testHash; all other fields are left at their zero values.
func testInfo() *CheckpointInfo {
	info := new(CheckpointInfo)
	info.Enabled = true
	info.Hash = testHash
	return info
}

73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// testSnapshotAgentDaemonSet builds the snapshot-agent DaemonSet fixture in
// testNamespace. It carries the snapshot-agent label, one agent container that
// mounts the "checkpoints" volume at /checkpoints, and that volume backed by
// the "snapshot-pvc" PersistentVolumeClaim.
func testSnapshotAgentDaemonSet() *appsv1.DaemonSet {
	const checkpointVolume = "checkpoints"

	agent := corev1.Container{
		Name: snapshotprotocol.SnapshotAgentContainerName,
		VolumeMounts: []corev1.VolumeMount{
			{Name: checkpointVolume, MountPath: "/checkpoints"},
		},
	}

	storage := corev1.Volume{Name: checkpointVolume}
	storage.PersistentVolumeClaim = &corev1.PersistentVolumeClaimVolumeSource{
		ClaimName: "snapshot-pvc",
	}

	ds := new(appsv1.DaemonSet)
	ds.Name = "snapshot-agent"
	ds.Namespace = testNamespace
	ds.Labels = map[string]string{
		snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
	}
	ds.Spec.Template.Spec.Containers = []corev1.Container{agent}
	ds.Spec.Template.Spec.Volumes = []corev1.Volume{storage}
	return ds
}

106
107
108
109
// createHookClient wraps a client.Client so a test can run a callback right
// before a Create call is forwarded to the wrapped client.
type createHookClient struct {
	// Client is the embedded client that Create delegates to.
	client.Client
	// onCreate, when non-nil, is invoked by Create before delegation. An error
	// it returns aborts the Create; after a call that returns nil the hook is
	// cleared, so it fires successfully at most once.
	onCreate func(ctx context.Context, obj client.Object) error
}
110

111
112
113
114
115
116
117
118
119
120
// Create runs the pending onCreate hook, if any, and then forwards the call to
// the embedded client. A hook error is returned as-is and leaves the hook
// armed; a hook that succeeds is disarmed before delegation.
func (c *createHookClient) Create(ctx context.Context, obj client.Object, opts ...client.CreateOption) error {
	hook := c.onCreate
	if hook == nil {
		return c.Client.Create(ctx, obj, opts...)
	}

	if hookErr := hook(ctx, obj); hookErr != nil {
		return hookErr
	}
	c.onCreate = nil

	return c.Client.Create(ctx, obj, opts...)
}
121

122
123
124
125
126
127
128
129
130
131
132
133
134
func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *testing.T) {
	ctx := context.Background()
	s := testScheme()

	identity := testIdentity()
	hash, err := ComputeIdentityHash(identity)
	require.NoError(t, err)

	friendly := &nvidiacomv1alpha1.DynamoCheckpoint{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "friendly-checkpoint",
			Namespace: testNamespace,
			Labels: map[string]string{
135
				consts.KubeLabelCheckpointID: hash,
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
			},
		},
		Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
			Identity: identity,
			Job: nvidiacomv1alpha1.DynamoCheckpointJobConfig{
				PodTemplateSpec: corev1.PodTemplateSpec{},
			},
		},
		Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
			IdentityHash: hash,
			Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
		},
	}

	baseClient := fake.NewClientBuilder().WithScheme(s).Build()
	c := &createHookClient{
		Client: baseClient,
		onCreate: func(ctx context.Context, obj client.Object) error {
			_, ok := obj.(*nvidiacomv1alpha1.DynamoCheckpoint)
			if !ok {
				return nil
			}
			return baseClient.Create(ctx, friendly.DeepCopy())
		},
	}

	ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{})
	require.NoError(t, err)
	assert.Equal(t, friendly.Name, ckpt.Name)

	list := &nvidiacomv1alpha1.DynamoCheckpointList{}
	require.NoError(t, baseClient.List(ctx, list))
	require.Len(t, list.Items, 1)
	assert.Equal(t, friendly.Name, list.Items[0].Name)
}

172
173
174
175
176
177
178
179
180
181
182
func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
	ctx := context.Background()
	s := testScheme()
	c := fake.NewClientBuilder().WithScheme(s).Build()

	ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{})
	require.NoError(t, err)
	require.NotNil(t, ckpt.Annotations)
	assert.Equal(t, consts.DefaultCheckpointArtifactVersion, ckpt.Annotations[consts.KubeAnnotationCheckpointArtifactVersion])
}

183
184
185
// --- InjectCheckpointIntoPodSpec tests ---

func TestInjectCheckpointIntoPodSpec(t *testing.T) {
186
	t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
187
		podSpec := testPodSpec()
188
189
190
		info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
		reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
		require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
191
192
		assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[0].Command)
		assert.Nil(t, podSpec.Containers[0].Args)
193
		assert.Len(t, info.Hash, 16)
194

195
196
197
		volumes := map[string]corev1.Volume{}
		for _, volume := range podSpec.Volumes {
			volumes[volume.Name] = volume
198
		}
199
200
		require.Contains(t, volumes, consts.PodInfoVolumeName)
		require.NotNil(t, volumes[consts.PodInfoVolumeName].DownwardAPI)
201

202
203
204
205
		fields := map[string]string{}
		for _, item := range volumes[consts.PodInfoVolumeName].DownwardAPI.Items {
			if item.FieldRef != nil {
				fields[item.Path] = item.FieldRef.FieldPath
206
			}
207
		}
208
209
210
211
212
213
214
215
216
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
		assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
		assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])

		mountPaths := map[string]string{}
		for _, mount := range podSpec.Containers[0].VolumeMounts {
			mountPaths[mount.Name] = mount.MountPath
217
218
219
220
		}
		assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
	})

221
222
223
224
225
226
	t.Run("ready checkpoint targets the container named main", func(t *testing.T) {
		podSpec := &corev1.PodSpec{
			Containers: []corev1.Container{
				{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
				{Name: consts.MainContainerName, Image: "main:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
			},
227
		}
228
229
230
231
232
233
234
235
		info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash}
		reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()

		require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
		assert.Equal(t, []string{"sidecar"}, podSpec.Containers[0].Command)
		assert.Equal(t, []string{"run"}, podSpec.Containers[0].Args)
		assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[1].Command)
		assert.Nil(t, podSpec.Containers[1].Args)
236
237
238
239
240
241
242
	})

	t.Run("error cases", func(t *testing.T) {
		for _, tc := range []struct {
			name    string
			podSpec *corev1.PodSpec
			info    *CheckpointInfo
243
			reader  client.Reader
244
245
			errMsg  string
		}{
246
247
248
249
			{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
			{"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container found"},
			{"main container missing", &corev1.PodSpec{Containers: []corev1.Container{{Name: "sidecar", Image: "img", Command: []string{"python3"}}}}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "main container not found"},
			{"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"},
250
251
		} {
			t.Run(tc.name, func(t *testing.T) {
252
				err := InjectCheckpointIntoPodSpec(context.Background(), tc.reader, testNamespace, tc.podSpec, tc.info)
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
				require.Error(t, err)
				assert.Contains(t, err.Error(), tc.errMsg)
			})
		}
	})
}

// --- ResolveCheckpointForService tests ---

func TestResolveCheckpointForService(t *testing.T) {
	ctx := context.Background()
	s := testScheme()

	t.Run("nil or disabled config returns disabled", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		for _, cfg := range []*nvidiacomv1alpha1.ServiceCheckpointConfig{nil, {Enabled: false}} {
			info, err := ResolveCheckpointForService(ctx, c, testNamespace, cfg)
			require.NoError(t, err)
			assert.False(t, info.Enabled)
		}
	})

	t.Run("checkpointRef resolves ready CR", func(t *testing.T) {
276
277
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
278
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
279
			ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
280
281
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
282
283
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
				IdentityHash: hash,
284
285
286
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
287
		ref := hash
288
289
290
291
292

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
293
		assert.True(t, info.Exists)
294
		assert.True(t, info.Ready)
295
296
		assert.Equal(t, hash, info.Hash)
		assert.Equal(t, hash, info.CheckpointName)
297
298
299
	})

	t.Run("checkpointRef resolves not-ready CR", func(t *testing.T) {
300
301
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
302
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
303
			ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
304
305
306
307
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
			Status:     nvidiacomv1alpha1.DynamoCheckpointStatus{Phase: nvidiacomv1alpha1.DynamoCheckpointPhaseCreating},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
308
		ref := hash
309
310
311
312
313

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
314
		assert.True(t, info.Exists)
315
316
317
318
319
320
321
322
323
324
325
326
		assert.False(t, info.Ready)
	})

	t.Run("checkpointRef errors when CR not found", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		ref := "nonexistent"
		_, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		assert.ErrorContains(t, err, "nonexistent")
	})

327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
	t.Run("checkpointRef resolves human-readable checkpoint names", func(t *testing.T) {
		hash, err := ComputeIdentityHash(testIdentity())
		require.NoError(t, err)
		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
			ObjectMeta: metav1.ObjectMeta{Name: "not-the-hash", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
				IdentityHash: hash,
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()
		ref := "not-the-hash"

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, CheckpointRef: &ref,
		})
		require.NoError(t, err)
		assert.Equal(t, "not-the-hash", info.CheckpointName)
		assert.Equal(t, hash, info.Hash)
	})

	t.Run("identity lookup finds existing checkpoint by identity hash", func(t *testing.T) {
349
350
351
352
353
		identity := testIdentity()
		hash, err := ComputeIdentityHash(identity)
		require.NoError(t, err)

		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
354
355
			ObjectMeta: metav1.ObjectMeta{Name: "friendly-name", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: identity},
356
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
357
358
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
				IdentityHash: hash,
359
360
361
362
363
364
365
366
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
367
		assert.True(t, info.Exists)
368
369
		assert.True(t, info.Ready)
		assert.Equal(t, hash, info.Hash)
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
		assert.Equal(t, "friendly-name", info.CheckpointName)
	})

	t.Run("identity lookup returns existing not-ready checkpoint", func(t *testing.T) {
		identity := testIdentity()
		hash, err := ComputeIdentityHash(identity)
		require.NoError(t, err)

		ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
			ObjectMeta: metav1.ObjectMeta{Name: "friendly-name", Namespace: testNamespace},
			Spec:       nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: identity},
			Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
				Phase:        nvidiacomv1alpha1.DynamoCheckpointPhaseCreating,
				IdentityHash: hash,
			},
		}
		c := fake.NewClientBuilder().WithScheme(s).WithObjects(ckpt).WithStatusSubresource(ckpt).Build()

		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
		assert.True(t, info.Exists)
		assert.False(t, info.Ready)
		assert.Equal(t, hash, info.Hash)
395
396
397
398
399
400
401
402
403
	})

	t.Run("identity lookup returns not-ready when no CR found", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		identity := testIdentity()
		info, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{
			Enabled: true, Identity: &identity,
		})
		require.NoError(t, err)
404
		assert.False(t, info.Exists)
405
406
407
408
409
410
411
412
413
414
		assert.False(t, info.Ready)
		assert.Len(t, info.Hash, 16)
	})

	t.Run("errors when enabled but no ref and no identity", func(t *testing.T) {
		c := fake.NewClientBuilder().WithScheme(s).Build()
		_, err := ResolveCheckpointForService(ctx, c, testNamespace, &nvidiacomv1alpha1.ServiceCheckpointConfig{Enabled: true})
		assert.ErrorContains(t, err, "no checkpointRef or identity")
	})
}