Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
......@@ -67,6 +67,31 @@ spec:
spec:
description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint
properties:
gpuMemoryService:
description: |-
GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
It is intentionally outside spec.identity, so it does not affect the
checkpoint identity hash or deduplication.
properties:
deviceClassName:
default: gpu.nvidia.com
description: DeviceClassName is the DRA DeviceClass to request GPUs from.
type: string
enabled:
description: |-
Enabled activates the GMS sidecar. GPU resources on the main container
are replaced with a DRA ResourceClaim for shared GPU access.
type: boolean
mode:
default: intraPod
description: Mode selects the GMS deployment topology.
enum:
- intraPod
- interPod
type: string
required:
- enabled
type: object
identity:
description: Identity defines the inputs that determine checkpoint equivalence
properties:
......
......@@ -124,6 +124,12 @@ type DynamoCheckpointSpec struct {
// +kubebuilder:validation:Required
Identity DynamoCheckpointIdentity `json:"identity"`
// GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
// It is intentionally outside spec.identity, so it does not affect the
// checkpoint identity hash or deduplication.
// +optional
GPUMemoryService *GPUMemoryServiceSpec `json:"gpuMemoryService,omitempty"`
// Job defines the configuration for the checkpoint creation Job
// +kubebuilder:validation:Required
Job DynamoCheckpointJobConfig `json:"job"`
......
......@@ -340,6 +340,11 @@ func (in *DynamoCheckpointList) DeepCopyObject() runtime.Object {
func (in *DynamoCheckpointSpec) DeepCopyInto(out *DynamoCheckpointSpec) {
*out = *in
in.Identity.DeepCopyInto(&out.Identity)
if in.GPUMemoryService != nil {
in, out := &in.GPUMemoryService, &out.GPUMemoryService
*out = new(GPUMemoryServiceSpec)
**out = **in
}
in.Job.DeepCopyInto(&out.Job)
}
......
......@@ -67,6 +67,31 @@ spec:
spec:
description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint
properties:
gpuMemoryService:
description: |-
GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
It is intentionally outside spec.identity, so it does not affect the
checkpoint identity hash or deduplication.
properties:
deviceClassName:
default: gpu.nvidia.com
description: DeviceClassName is the DRA DeviceClass to request GPUs from.
type: string
enabled:
description: |-
Enabled activates the GMS sidecar. GPU resources on the main container
are replaced with a DRA ResourceClaim for shared GPU access.
type: boolean
mode:
default: intraPod
description: Mode selects the GMS deployment topology.
enum:
- intraPod
- interPod
type: string
required:
- enabled
type: object
identity:
description: Identity defines the inputs that determine checkpoint equivalence
properties:
......
......@@ -27,6 +27,10 @@ spec:
dtype: "bfloat16"
maxModelLen: 2048
# Optional: enable GMS-specific checkpoint capture and restore helpers.
gpuMemoryService:
enabled: false
# Job configuration for checkpoint creation
job:
activeDeadlineSeconds: 3600
......
......@@ -23,6 +23,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -159,7 +160,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te
},
}
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{})
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{}, nil)
require.NoError(t, err)
assert.Equal(t, friendly.Name, ckpt.Name)
......@@ -174,7 +175,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
s := testScheme()
c := fake.NewClientBuilder().WithScheme(s).Build()
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{})
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{}, nil)
require.NoError(t, err)
require.NotNil(t, ckpt.Annotations)
assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
......@@ -182,6 +183,50 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
// --- InjectCheckpointIntoPodSpec tests ---
func TestEnsurePodInfoVolumeMergesExistingDownwardAPIItems(t *testing.T) {
podSpec := &corev1.PodSpec{
Volumes: []corev1.Volume{{
Name: consts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
{
Path: "pod_name",
FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"},
},
{
Path: "custom",
FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.labels['custom']"},
},
},
},
},
}},
}
EnsurePodInfoVolume(podSpec)
require.Len(t, podSpec.Volumes, 1)
require.NotNil(t, podSpec.Volumes[0].DownwardAPI)
fields := map[string]string{}
for _, item := range podSpec.Volumes[0].DownwardAPI.Items {
if item.FieldRef != nil {
fields[item.Path] = item.FieldRef.FieldPath
}
}
assert.Equal(t, consts.PodInfoFieldPodName, fields["pod_name"])
assert.Equal(t, consts.PodInfoFieldPodUID, fields["pod_uid"])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields["pod_namespace"])
assert.Equal(t, "metadata.labels['custom']", fields["custom"])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])
}
func TestInjectCheckpointIntoPodSpec(t *testing.T) {
t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
podSpec := testPodSpec()
......@@ -218,6 +263,50 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
})
t.Run("ready checkpoint augments existing podinfo volume", func(t *testing.T) {
podSpec := testPodSpec()
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: consts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
{Path: "pod_name", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodName}},
{Path: "pod_uid", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodUID}},
{Path: "pod_namespace", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodNamespace}},
},
},
},
})
info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
var podInfoVolume *corev1.Volume
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name == consts.PodInfoVolumeName {
podInfoVolume = &podSpec.Volumes[i]
break
}
}
require.NotNil(t, podInfoVolume)
require.NotNil(t, podInfoVolume.DownwardAPI)
fields := map[string]string{}
for _, item := range podInfoVolume.DownwardAPI.Items {
if item.FieldRef != nil {
fields[item.Path] = item.FieldRef.FieldPath
}
}
assert.Equal(t, consts.PodInfoFieldPodName, fields["pod_name"])
assert.Equal(t, consts.PodInfoFieldPodUID, fields["pod_uid"])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields["pod_namespace"])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])
})
t.Run("ready checkpoint targets the container named main", func(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{
......@@ -235,6 +324,39 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
assert.Nil(t, podSpec.Containers[1].Args)
})
t.Run("ready gms checkpoint injects restore sidecars and loader mount", func(t *testing.T) {
podSpec := testPodSpec()
podSpec.Containers[0].Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash, GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true}}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
gmsServer := findContainer(podSpec, gmsruntime.ServerContainerName)
require.NotNil(t, gmsServer)
loader := findContainer(podSpec, GMSLoaderContainer)
require.NotNil(t, loader)
// Restore: gms-server should be a regular container, not an init container
assert.Empty(t, podSpec.InitContainers, "restore pods should not have gms-server as init container")
assert.Nil(t, gmsServer.RestartPolicy, "restore gms-server should not have RestartPolicy")
assert.Nil(t, gmsServer.StartupProbe, "restore gms-server should not have StartupProbe")
mounts := map[string]string{}
for _, mount := range loader.VolumeMounts {
mounts[mount.Name] = mount.MountPath
}
assert.Equal(t, "/checkpoints", mounts[snapshotprotocol.CheckpointVolumeName])
assert.Equal(t, gmsruntime.SharedMountPath, mounts[gmsruntime.SharedVolumeName])
env := map[string]string{}
for _, item := range loader.Env {
env[item.Name] = item.Value
}
assert.Equal(t, "/checkpoints/gms/"+testHash+"/versions/1", env["GMS_CHECKPOINT_DIR"])
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gmsServer.Command)
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.snapshot.loader"}, loader.Command)
})
t.Run("error cases", func(t *testing.T) {
for _, tc := range []struct {
name string
......@@ -277,7 +399,10 @@ func TestResolveCheckpointForService(t *testing.T) {
require.NoError(t, err)
ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()},
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
Identity: testIdentity(),
GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true},
},
Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
Phase: nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
IdentityHash: hash,
......@@ -294,6 +419,8 @@ func TestResolveCheckpointForService(t *testing.T) {
assert.True(t, info.Ready)
assert.Equal(t, hash, info.Hash)
assert.Equal(t, hash, info.CheckpointName)
require.NotNil(t, info.GPUMemoryService)
assert.True(t, info.GPUMemoryService.Enabled)
})
t.Run("checkpointRef resolves not-ready CR", func(t *testing.T) {
......@@ -412,3 +539,19 @@ func TestResolveCheckpointForService(t *testing.T) {
assert.ErrorContains(t, err, "no checkpointRef or identity")
})
}
// findContainer is a test helper that locates a container by name across both
// regular containers and init containers.
func findContainer(podSpec *corev1.PodSpec, name string) *corev1.Container {
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == name {
return &podSpec.Containers[i]
}
}
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == name {
return &podSpec.InitContainers[i]
}
}
return nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package checkpoint
import (
"context"
"fmt"
"path/filepath"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)
const (
GMSLoaderContainer = "gms-loader"
GMSSaverContainer = "gms-saver"
gmsCheckpointLoaderModule = "gpu_memory_service.cli.snapshot.loader"
gmsCheckpointSaverModule = "gpu_memory_service.cli.snapshot.saver"
)
func ResolveGMSCheckpointStorage(
ctx context.Context,
reader ctrlclient.Reader,
namespace string,
checkpointID string,
artifactVersion string,
) (snapshotprotocol.Storage, error) {
if reader == nil {
return snapshotprotocol.Storage{}, fmt.Errorf("checkpoint client is required")
}
daemonSets := &appsv1.DaemonSetList{}
if err := reader.List(
ctx,
daemonSets,
ctrlclient.InNamespace(namespace),
ctrlclient.MatchingLabels{snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue},
); err != nil {
return snapshotprotocol.Storage{}, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
}
storage, err := snapshotprotocol.DiscoverStorageFromDaemonSets(namespace, daemonSets.Items)
if err != nil {
return snapshotprotocol.Storage{}, err
}
return snapshotprotocol.ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
}
// BuildGMSRestoreSidecars prepares GMS infrastructure for a restore pod and
// returns the additional containers the caller must append to podSpec.Containers.
//
// The GMS server runs as a regular container (not init) because the CRIU-restored
// main process already has GPU memory mapped and does not need sockets at
// startup. The gms-loader polls for sockets internally via wait_for_weights_socket.
func BuildGMSRestoreSidecars(
podSpec *corev1.PodSpec,
mainContainer *corev1.Container,
storage snapshotprotocol.Storage,
) []corev1.Container {
if podSpec == nil || mainContainer == nil {
return nil
}
// Remove gms-server from initContainers if the DGD-level
// applyGPUMemoryService already placed it there. For restore pods the
// server runs as a regular container so that all containers start in
// parallel — the restored main process does not need sockets at startup.
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == gmsruntime.ServerContainerName {
podSpec.InitContainers = append(podSpec.InitContainers[:i], podSpec.InitContainers[i+1:]...)
break
}
}
server := gmsruntime.BuildServerContainer(podSpec, mainContainer)
loader := gmsCheckpointLoaderContainer(mainContainer.Image)
copyGMSDeviceClaims(mainContainer, &loader)
ensureCheckpointVolume(podSpec, storage.PVCName)
loader.VolumeMounts = append(loader.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath})
loader.Env = append(loader.Env, corev1.EnvVar{Name: "GMS_CHECKPOINT_DIR", Value: resolveGMSArtifactDir(storage)})
return []corev1.Container{server, loader}
}
// BuildGMSCheckpointJobSidecars prepares GMS infrastructure for a checkpoint
// job and returns the additional containers the caller must append to
// podSpec.Containers.
func BuildGMSCheckpointJobSidecars(
podSpec *corev1.PodSpec,
mainContainer *corev1.Container,
storage snapshotprotocol.Storage,
) ([]corev1.Container, error) {
if podSpec == nil || mainContainer == nil {
return nil, nil
}
if len(mainContainer.Resources.Claims) == 0 {
return nil, fmt.Errorf("gms sidecars require main container resource claims")
}
if storage.PVCName == "" || storage.BasePath == "" || storage.Location == "" {
return nil, fmt.Errorf("gms checkpoint jobs require resolved checkpoint storage")
}
gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
ensureGMSCheckpointControl(podSpec)
saver := gmsCheckpointSaverContainer(mainContainer.Image)
copyGMSDeviceClaims(mainContainer, &saver)
ensureCheckpointVolume(podSpec, storage.PVCName)
saver.VolumeMounts = append(saver.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath})
saver.Env = append(saver.Env, corev1.EnvVar{Name: "GMS_CHECKPOINT_DIR", Value: resolveGMSArtifactDir(storage)})
return []corev1.Container{saver}, nil
}
func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string {
// GMS data lives under /checkpoints/gms/<hash>/versions/<version>
// separate from the CRIU tree (/checkpoints/<hash>/versions/<version>)
// so the non-root saver can create directories at the PVC root.
artifactVersion := filepath.Base(storage.Location)
checkpointID := filepath.Base(filepath.Dir(filepath.Dir(storage.Location)))
return filepath.Join(storage.BasePath, "gms", checkpointID, "versions", artifactVersion)
}
func gmsCheckpointLoaderContainer(image string) corev1.Container {
container := corev1.Container{
Name: GMSLoaderContainer,
Image: image,
Command: []string{"python3", "-m", gmsCheckpointLoaderModule},
Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: gmsruntime.SharedMountPath},
},
VolumeMounts: []corev1.VolumeMount{
{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath},
},
}
return container
}
func gmsCheckpointSaverContainer(image string) corev1.Container {
container := corev1.Container{
Name: GMSSaverContainer,
Image: image,
Command: []string{"python3", "-m", gmsCheckpointSaverModule},
Env: []corev1.EnvVar{
{Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}},
{Name: "POD_NAMESPACE", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.namespace"}}},
{Name: "TMPDIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir},
},
VolumeMounts: []corev1.VolumeMount{
{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath},
{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir},
},
}
return container
}
// ensureGMSCheckpointControl adds the control volume and injects
// GMS_CONTROL_DIR into the GMS server container for checkpoint coordination.
func ensureGMSCheckpointControl(podSpec *corev1.PodSpec) {
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: gmsruntime.ControlVolumeName,
VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
})
server := gmsruntime.FindServerContainer(podSpec)
if server != nil {
server.VolumeMounts = append(server.VolumeMounts, corev1.VolumeMount{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir})
server.Env = append(server.Env, corev1.EnvVar{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir})
}
}
func copyGMSDeviceClaims(mainContainer *corev1.Container, container *corev1.Container) {
if mainContainer == nil || container == nil || len(mainContainer.Resources.Claims) == 0 {
return
}
container.Resources.Claims = append([]corev1.ResourceClaim{}, mainContainer.Resources.Claims...)
}
func ensureCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
if pvcName == "" {
return
}
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name == snapshotprotocol.CheckpointVolumeName {
return
}
}
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: snapshotprotocol.CheckpointVolumeName,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ClaimName: pvcName},
},
})
}
......@@ -9,17 +9,42 @@ import (
)
func EnsurePodInfoVolume(podSpec *corev1.PodSpec) {
for _, volume := range podSpec.Volumes {
if volume.Name == commonconsts.PodInfoVolumeName {
return
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name != commonconsts.PodInfoVolumeName {
continue
}
if podSpec.Volumes[i].DownwardAPI == nil {
podSpec.Volumes[i].VolumeSource.DownwardAPI = &corev1.DownwardAPIVolumeSource{}
}
// Merge required items into existing downwardAPI volume.
source := podSpec.Volumes[i].DownwardAPI
pathToIndex := make(map[string]int, len(source.Items))
for j := range source.Items {
pathToIndex[source.Items[j].Path] = j
}
for _, item := range podInfoItems() {
if idx, ok := pathToIndex[item.Path]; ok {
source.Items[idx] = item
continue
}
source.Items = append(source.Items, item)
pathToIndex[item.Path] = len(source.Items) - 1
}
return
}
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: commonconsts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
Items: podInfoItems(),
},
},
})
}
func podInfoItems() []corev1.DownwardAPIVolumeFile {
return []corev1.DownwardAPIVolumeFile{
{
Path: "pod_name",
FieldRef: &corev1.ObjectFieldSelector{
......@@ -68,10 +93,7 @@ func EnsurePodInfoVolume(podSpec *corev1.PodSpec) {
FieldPath: commonconsts.PodInfoFieldPodNamespace,
},
},
},
},
},
})
}
}
func EnsurePodInfoMount(container *corev1.Container) {
......
......@@ -94,5 +94,26 @@ func InjectCheckpointIntoPodSpec(
EnsurePodInfoVolume(podSpec)
EnsurePodInfoMount(mainContainer)
// GMS restore sidecars (server + loader) are only needed when the checkpoint
// is ready and the pod will actually be CRIU-restored.
if info.Ready && info.GPUMemoryService != nil && info.GPUMemoryService.Enabled {
if len(mainContainer.Resources.Claims) == 0 {
return fmt.Errorf("gms sidecars require main container resource claims")
}
storage, err := ResolveGMSCheckpointStorage(
ctx,
reader,
namespace,
info.Hash,
info.ArtifactVersion,
)
if err != nil {
return err
}
gmsSidecars := BuildGMSRestoreSidecars(podSpec, mainContainer, storage)
podSpec.Containers = append(podSpec.Containers, gmsSidecars...)
}
return nil
}
......@@ -31,6 +31,7 @@ type CheckpointInfo struct {
Enabled bool
Exists bool
Identity *nvidiacomv1alpha1.DynamoCheckpointIdentity
GPUMemoryService *nvidiacomv1alpha1.GPUMemoryServiceSpec
Hash string
ArtifactVersion string
CheckpointName string
......@@ -47,6 +48,7 @@ func checkpointInfoFromObject(ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (*Checkp
Enabled: true,
Exists: true,
Identity: &ckpt.Spec.Identity,
GPUMemoryService: ckpt.Spec.GPUMemoryService,
Hash: hash,
ArtifactVersion: checkpointArtifactVersion(ckpt),
CheckpointName: ckpt.Name,
......
......@@ -107,6 +107,7 @@ func CreateOrGetAutoCheckpoint(
namespace string,
identity nvidiacomv1alpha1.DynamoCheckpointIdentity,
podTemplate corev1.PodTemplateSpec,
gpuMemoryService *nvidiacomv1alpha1.GPUMemoryServiceSpec,
) (*nvidiacomv1alpha1.DynamoCheckpoint, error) {
hash, err := ComputeIdentityHash(identity)
if err != nil {
......@@ -126,6 +127,7 @@ func CreateOrGetAutoCheckpoint(
},
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
Identity: identity,
GPUMemoryService: gpuMemoryService,
Job: nvidiacomv1alpha1.DynamoCheckpointJobConfig{
PodTemplateSpec: podTemplate,
},
......
......@@ -4,6 +4,7 @@
package controller
import (
"context"
"fmt"
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
......@@ -16,6 +17,7 @@ import (
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)
func buildCheckpointWorkerDefaultEnv(
......@@ -51,6 +53,8 @@ func buildCheckpointWorkerDefaultEnv(
}
func buildCheckpointJob(
ctx context.Context,
reader ctrlclient.Reader,
config *configv1alpha1.OperatorConfiguration,
ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
jobName string,
......@@ -77,8 +81,10 @@ func buildCheckpointJob(
checkpoint.EnsurePodInfoVolume(&podTemplate.Spec)
if len(podTemplate.Spec.Containers) > 0 {
mainContainer := &podTemplate.Spec.Containers[0]
mainContainer, err := snapshotprotocol.ResolveCheckpointWorkerContainer(&podTemplate.Spec)
if err != nil {
return nil, err
}
mainContainer.Env = dynamo.MergeEnvs(
buildCheckpointWorkerDefaultEnv(ckpt, podTemplate),
mainContainer.Env,
......@@ -101,7 +107,25 @@ func buildCheckpointJob(
mainContainer.StartupProbe = nil
checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
var gmsSidecars []corev1.Container
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
storage, err := checkpoint.ResolveGMSCheckpointStorage(
ctx,
reader,
ckpt.Namespace,
hash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
if err != nil {
return nil, err
}
gmsSidecars, err = checkpoint.BuildGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage)
if err != nil {
return nil, err
}
}
podTemplate.Spec.Containers = append(podTemplate.Spec.Containers, gmsSidecars...)
activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds
if activeDeadlineSeconds == nil {
......@@ -110,11 +134,9 @@ func buildCheckpointJob(
}
wrapLaunchJob := false
if len(podTemplate.Spec.Containers) != 0 {
if gpus, ok := podTemplate.Spec.Containers[0].Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]; ok {
if gpus, ok := mainContainer.Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]; ok {
wrapLaunchJob = gpus.Cmp(*resource.NewQuantity(1, resource.DecimalSI)) > 0
}
}
ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds
return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{
......
......@@ -197,7 +197,7 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
// Use SyncResource to create/update the checkpoint Job
modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) {
job, err := buildCheckpointJob(r.Config, ckpt, jobName)
job, err := buildCheckpointJob(ctx, r.Client, r.Config, ckpt, jobName)
return job, false, err
})
if err != nil {
......
......@@ -26,9 +26,11 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1"
......@@ -65,6 +67,7 @@ var defaultCheckpointJobName = snapshotprotocol.GetCheckpointJobName(testHash, s
func checkpointTestScheme() *runtime.Scheme {
s := runtime.NewScheme()
_ = nvidiacomv1alpha1.AddToScheme(s)
_ = appsv1.AddToScheme(s)
_ = corev1.AddToScheme(s)
_ = batchv1.AddToScheme(s)
_ = coordinationv1.AddToScheme(s)
......@@ -130,6 +133,17 @@ func makeCheckpointLease(name string, renewTime time.Time, durationSeconds int32
}
}
func requireCheckpointContainer(t *testing.T, containers []corev1.Container, name string) *corev1.Container {
t.Helper()
for i := range containers {
if containers[i].Name == name {
return &containers[i]
}
}
t.Fatalf("container %q not found", name)
return nil
}
func TestBuildCheckpointJob(t *testing.T) {
s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
......@@ -139,7 +153,7 @@ func TestBuildCheckpointJob(t *testing.T) {
}
r := makeCheckpointReconciler(s, ckpt)
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
podSpec := job.Spec.Template.Spec
main := podSpec.Containers[0]
......@@ -236,7 +250,7 @@ func TestBuildCheckpointJob(t *testing.T) {
backoff := int32(5)
ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline
ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs.
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err = buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
......@@ -247,12 +261,142 @@ func TestBuildCheckpointJob(t *testing.T) {
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
},
}
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err = buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
}
func TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst(t *testing.T) {
s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
ckpt.Spec.Job.PodTemplateSpec.Spec.Containers = []corev1.Container{
{
Name: "sidecar",
Image: "sidecar:latest",
Command: []string{"sleep"},
Args: []string{"infinity"},
},
{
Name: consts.MainContainerName,
Image: "test-image:latest",
Command: []string{"python3", "-m", "dynamo.vllm"},
Env: []corev1.EnvVar{{Name: "HF_TOKEN", Value: "secret"}},
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceName(consts.KubeResourceGPUNvidia): resource.MustParse("2"),
},
},
},
}
r := makeCheckpointReconciler(s, ckpt)
job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName)
assert.Equal(t, []string{"cuda-checkpoint"}, main.Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args)
require.NotNil(t, main.ReadinessProbe)
assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, main.ReadinessProbe.Exec.Command)
assert.Nil(t, main.LivenessProbe)
assert.Nil(t, main.StartupProbe)
mainEnv := map[string]string{}
for _, env := range main.Env {
mainEnv[env.Name] = env.Value
}
assert.Equal(t, "/tmp/ready-for-checkpoint", mainEnv[consts.EnvReadyForCheckpointFile])
assert.Equal(t, "secret", mainEnv["HF_TOKEN"])
sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar")
assert.Equal(t, []string{"sleep"}, sidecar.Command)
assert.Equal(t, []string{"infinity"}, sidecar.Args)
assert.Nil(t, sidecar.ReadinessProbe)
assert.Nil(t, sidecar.LivenessProbe)
assert.Nil(t, sidecar.StartupProbe)
for _, env := range sidecar.Env {
assert.NotEqual(t, consts.EnvReadyForCheckpointFile, env.Name)
}
}
func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) {
s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
ckpt.Spec.GPUMemoryService = &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true}
ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
snapshotAgentDaemonSet := &appsv1.DaemonSet{
ObjectMeta: metav1.ObjectMeta{
Name: "snapshot-agent",
Namespace: testNamespace,
Labels: map[string]string{
snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
},
},
Spec: appsv1.DaemonSetSpec{
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: snapshotprotocol.SnapshotAgentContainerName,
VolumeMounts: []corev1.VolumeMount{{
Name: snapshotprotocol.SnapshotAgentVolumeName,
MountPath: "/checkpoints",
}},
}},
Volumes: []corev1.Volume{{
Name: snapshotprotocol.SnapshotAgentVolumeName,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "snapshot-pvc",
},
},
}},
},
},
},
}
reader := fake.NewClientBuilder().WithScheme(s).WithObjects(snapshotAgentDaemonSet).Build()
r := makeCheckpointReconciler(s, ckpt)
job, err := buildCheckpointJob(context.Background(), reader, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName)
weightsServer := requireCheckpointContainer(t, job.Spec.Template.Spec.InitContainers, gmsruntime.ServerContainerName)
saver := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, checkpoint.GMSSaverContainer)
volNames := map[string]bool{}
for _, v := range job.Spec.Template.Spec.Volumes {
volNames[v.Name] = true
}
assert.True(t, volNames[gmsruntime.SharedVolumeName])
assert.True(t, volNames[gmsruntime.ControlVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
mainMounts := map[string]string{}
for _, m := range main.VolumeMounts {
mainMounts[m.Name] = m.MountPath
}
assert.Equal(t, gmsruntime.SharedMountPath, mainMounts[gmsruntime.SharedVolumeName])
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, weightsServer.Command)
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *weightsServer.RestartPolicy)
require.NotNil(t, weightsServer.StartupProbe)
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.snapshot.saver"}, saver.Command)
saverMounts := map[string]string{}
for _, m := range saver.VolumeMounts {
saverMounts[m.Name] = m.MountPath
}
assert.Equal(t, "/checkpoints", saverMounts[snapshotprotocol.CheckpointVolumeName])
saverEnv := map[string]string{}
for _, env := range saver.Env {
saverEnv[env.Name] = env.Value
}
assert.Equal(t, "/checkpoints/gms/"+testHash+"/versions/1", saverEnv["GMS_CHECKPOINT_DIR"])
}
func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
......@@ -272,7 +416,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
customShmSize := resource.MustParse("16Gi")
ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize}
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName)
job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err)
foundCustomShmVolume := false
for _, v := range job.Spec.Template.Spec.Volumes {
......
......@@ -29,10 +29,12 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/google/go-cmp/cmp"
"github.com/onsi/gomega"
"github.com/onsi/gomega/format"
"github.com/stretchr/testify/require"
istioNetworking "istio.io/api/networking/v1beta1"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
appsv1 "k8s.io/api/apps/v1"
......@@ -1248,7 +1250,7 @@ func TestDynamoComponentDeploymentReconciler_createOrUpdateOrDeleteDeployments_R
g.Expect(deployment3).NotTo(gomega.BeNil())
}
func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabels(t *testing.T) {
func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabels(t *testing.T) { //nolint:gocyclo
s := scheme.Scheme
if err := v1alpha1.AddToScheme(s); err != nil {
t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
......@@ -1376,6 +1378,129 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
}
})
t.Run("ready gms checkpoint injects gms restore sidecars", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity)
if err != nil {
t.Fatalf("ComputeIdentityHash failed: %v", err)
}
dcd := makeDCD(checkpointName)
dcd.Spec.ExtraPodSpec.MainContainer.Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
ckpt := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: checkpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{
Identity: identity,
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
},
Status: v1alpha1.DynamoCheckpointStatus{
Phase: v1alpha1.DynamoCheckpointPhaseReady,
},
}
r := makeReconciler(dcd, ckpt)
podTemplateSpec, err := r.generatePodTemplateSpec(
context.Background(),
generateResourceOption{dynamoComponentDeployment: dcd},
dynamo.RoleMain,
)
if err != nil {
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
find := func(name string) *corev1.Container {
for i := range podTemplateSpec.Spec.Containers {
if podTemplateSpec.Spec.Containers[i].Name == name {
return &podTemplateSpec.Spec.Containers[i]
}
}
for i := range podTemplateSpec.Spec.InitContainers {
if podTemplateSpec.Spec.InitContainers[i].Name == name {
return &podTemplateSpec.Spec.InitContainers[i]
}
}
return nil
}
gmsServer := find(gmsruntime.ServerContainerName)
require.NotNil(t, gmsServer)
loader := find(checkpoint.GMSLoaderContainer)
require.NotNil(t, loader)
mounts := map[string]string{}
for _, mount := range loader.VolumeMounts {
mounts[mount.Name] = mount.MountPath
}
if got := mounts[snapshotprotocol.CheckpointVolumeName]; got != "/checkpoints" {
t.Fatalf("expected gms loader checkpoint mount at /checkpoints, got %q", got)
}
if got := gmsServer.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.server" { //nolint:goconst
t.Fatalf("expected weights server to run python module, got %#v", got)
}
// Restore: gms-server should be a regular container, not an init container
if gmsServer.RestartPolicy != nil {
t.Fatalf("expected restore gms-server to have no RestartPolicy (regular container), got %#v", gmsServer.RestartPolicy)
}
if gmsServer.StartupProbe != nil {
t.Fatalf("expected restore gms-server to have no StartupProbe")
}
if got := loader.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.snapshot.loader" {
t.Fatalf("expected loader to run python module, got %#v", got)
}
})
t.Run("ready checkpoint rewrites only main when extra sidecars are present", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity)
if err != nil {
t.Fatalf("ComputeIdentityHash failed: %v", err)
}
dcd := makeDCD(checkpointName)
dcd.Spec.ExtraPodSpec.PodSpec = &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "gms-loader",
Image: "sidecar:latest",
Command: []string{"python3"},
Args: []string{"-m", "sidecar"},
}},
}
ckpt := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: checkpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{Identity: identity},
Status: v1alpha1.DynamoCheckpointStatus{
Phase: v1alpha1.DynamoCheckpointPhaseReady,
},
}
r := makeReconciler(dcd, ckpt)
podTemplateSpec, err := r.generatePodTemplateSpec(
context.Background(),
generateResourceOption{dynamoComponentDeployment: dcd},
dynamo.RoleMain,
)
if err != nil {
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
if got := podTemplateSpec.Spec.Containers[0]; got.Name != "gms-loader" || len(got.Command) != 1 || got.Command[0] != "python3" {
t.Fatalf("expected sidecar container to remain unchanged, got %#v", got)
}
if got := podTemplateSpec.Spec.Containers[1]; got.Name != commonconsts.MainContainerName || len(got.Command) != 2 || got.Command[0] != "sleep" || got.Command[1] != "infinity" {
t.Fatalf("expected main container to be rewritten for restore, got %#v", got)
}
if podTemplateSpec.Spec.Containers[1].Args != nil {
t.Fatalf("expected main container args to be cleared, got %#v", podTemplateSpec.Spec.Containers[1].Args)
}
if got := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", snapshotprotocol.RestoreTargetLabel, got)
}
})
t.Run("operator reasserts restore identity labels after metadata merge", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity)
......
......@@ -1380,18 +1380,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
return nil, fmt.Errorf("checkpoint identity is required for Auto mode")
}
identity := component.Checkpoint.Identity
checkpointIdentity := nvidiacomv1alpha1.DynamoCheckpointIdentity{
Model: identity.Model,
BackendFramework: identity.BackendFramework,
DynamoVersion: identity.DynamoVersion,
TensorParallelSize: identity.TensorParallelSize,
PipelineParallelSize: identity.PipelineParallelSize,
Dtype: identity.Dtype,
MaxModelLen: identity.MaxModelLen,
ExtraParameters: identity.ExtraParameters,
}
checkpointIdentity := *component.Checkpoint.Identity.DeepCopy()
// Capture config is not part of the checkpoint identity. Once a checkpoint object exists for a
// hash, later reconcilers must reuse it instead of racing to overwrite the capture pod template.
......@@ -1399,7 +1388,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
dynamoDeployment,
component,
serviceName,
identity.BackendFramework,
checkpointIdentity.BackendFramework,
)
if err != nil {
return nil, fmt.Errorf("failed to build checkpoint job pod template: %w", err)
......@@ -1411,6 +1400,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
dynamoDeployment.Namespace,
checkpointIdentity,
podTemplate,
component.GPUMemoryService,
)
}
......
......@@ -456,7 +456,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
referenced := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: "friendly-checkpoint",
Name: friendlyCheckpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{
......@@ -526,7 +526,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
if info.Hash != hash {
t.Fatalf("checkpoint hash = %s, want %s", info.Hash, hash)
}
if checkpointStatuses["worker"].CheckpointName != "friendly-checkpoint" {
if checkpointStatuses["worker"].CheckpointName != friendlyCheckpointName {
t.Fatalf("checkpoint status name = %s, want friendly-checkpoint", checkpointStatuses["worker"].CheckpointName)
}
......@@ -537,11 +537,96 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
if len(checkpoints.Items) != 1 {
t.Fatalf("expected only the referenced checkpoint to exist, found %d", len(checkpoints.Items))
}
if checkpoints.Items[0].Name != "friendly-checkpoint" {
if checkpoints.Items[0].Name != friendlyCheckpointName {
t.Fatalf("unexpected checkpoint %s", checkpoints.Items[0].Name)
}
}
func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefUsesReadyReferencedCR(t *testing.T) {
if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil {
t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
}
ctx := context.Background()
identity := v1alpha1.DynamoCheckpointIdentity{
Model: "meta-llama/Llama-2-7b-hf",
BackendFramework: "vllm",
}
hash, err := checkpoint.ComputeIdentityHash(identity)
if err != nil {
t.Fatalf("Failed to compute checkpoint hash: %v", err)
}
referenced := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: friendlyCheckpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{
Identity: identity,
},
Status: v1alpha1.DynamoCheckpointStatus{
Phase: v1alpha1.DynamoCheckpointPhaseReady,
IdentityHash: hash,
},
}
reconciler := &DynamoGraphDeploymentReconciler{
Client: fake.NewClientBuilder().
WithScheme(scheme.Scheme).
WithObjects(referenced).
WithStatusSubresource(referenced).
Build(),
Config: &configv1alpha1.OperatorConfiguration{},
Recorder: record.NewFakeRecorder(10),
}
ref := friendlyCheckpointName
dgd := &v1alpha1.DynamoGraphDeployment{
ObjectMeta: metav1.ObjectMeta{
Name: "test-dgd",
Namespace: "default",
},
Spec: v1alpha1.DynamoGraphDeploymentSpec{
Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{
"worker": {
ComponentType: string(commonconsts.ComponentTypeWorker),
Checkpoint: &v1alpha1.ServiceCheckpointConfig{
Enabled: true,
Mode: v1alpha1.CheckpointModeAuto,
CheckpointRef: &ref,
},
},
},
},
}
checkpointStatuses, checkpointInfos, err := reconciler.reconcileCheckpoints(ctx, dgd)
if err != nil {
t.Fatalf("reconcileCheckpoints() error = %v", err)
}
info, ok := checkpointInfos["worker"]
if !ok {
t.Fatalf("expected checkpoint info for worker service")
}
if !info.Ready {
t.Fatalf("expected referenced checkpoint to be ready")
}
if !info.Exists {
t.Fatalf("expected referenced checkpoint to exist")
}
if info.Hash != hash {
t.Fatalf("checkpoint hash = %s, want %s", info.Hash, hash)
}
if checkpointStatuses["worker"].CheckpointName != friendlyCheckpointName {
t.Fatalf("checkpoint status name = %s, want friendly-checkpoint", checkpointStatuses["worker"].CheckpointName)
}
if !checkpointStatuses["worker"].Ready {
t.Fatalf("expected checkpoint status to be ready")
}
}
func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_autoModeWaitsForExistingCreatingCheckpoint(t *testing.T) {
if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil {
t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
......
......@@ -10,30 +10,24 @@ import (
"fmt"
"strconv"
"strings"
"time"
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
corev1 "k8s.io/api/core/v1"
resourcev1 "k8s.io/api/resource/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client"
)
const (
gmsSharedVolumeName = "gms-shared"
gmsSharedMountPath = "/shared"
gmsDRAClaimName = "shared-gpu"
defaultDeviceClassName = "gpu.nvidia.com"
gmsProcessesPerGPU = 2
gmsStartupProbeTimeout = 2 * time.Minute
gmsStartupProbePeriodSec = 2
)
func isGMSEnabled(component *v1alpha1.DynamoComponentDeploymentSharedSpec) bool {
// IsGMSEnabled reports whether GPU Memory Service is requested for the component.
func IsGMSEnabled(component *v1alpha1.DynamoComponentDeploymentSharedSpec) bool {
return component.GPUMemoryService != nil && component.GPUMemoryService.Enabled
}
......@@ -58,6 +52,9 @@ func getGPUCount(component *v1alpha1.DynamoComponentDeploymentSharedSpec) (int,
if err != nil {
return 0, fmt.Errorf("invalid GPU count %q: %w", gpuStr, err)
}
if count <= 0 {
return 0, fmt.Errorf("GPU count must be greater than 0 when GPU memory service is enabled")
}
return count, nil
}
......@@ -70,49 +67,49 @@ func getDeviceClassName(component *v1alpha1.DynamoComponentDeploymentSharedSpec)
return defaultDeviceClassName
}
// applyGPUMemoryService transforms a pod spec to include a GMS sidecar with
// DRA shared GPU access. The main container's GPU resources are replaced with
// a DRA ResourceClaim, and a GMS init container is added.
//
// claimTemplateName is the name of the ResourceClaimTemplate that will provide
// shared GPU access; callers should compute it via GMSResourceClaimTemplateName.
func applyGPUMemoryService(
// resolveMainContainer finds the container named "main" in the pod spec.
// Falls back to Containers[0] when there is no container named "main"
// (e.g. failover pods with engine-0/engine-1 naming).
func resolveMainContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
if len(podSpec.Containers) == 0 {
return nil, fmt.Errorf("pod spec must have at least one container for GPU memory service")
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == commonconsts.MainContainerName {
return &podSpec.Containers[i], nil
}
}
return &podSpec.Containers[0], nil
}
// ApplyGPUMemoryService transforms a pod spec to include GMS server sidecars
// with DRA shared GPU access. The main container's GPU resources are replaced
// with a DRA ResourceClaim.
func ApplyGPUMemoryService(
podSpec *corev1.PodSpec,
component *v1alpha1.DynamoComponentDeploymentSharedSpec,
claimTemplateName string,
) error {
if len(podSpec.Containers) == 0 {
return fmt.Errorf("pod spec must have at least one container for GPU memory service")
}
gpuCount, err := getGPUCount(component)
if err != nil {
return err
}
_ = gpuCount // GPU count is used for DRA claim template; sidecar discovers devices via pynvml
mainContainer := &podSpec.Containers[0]
mainContainer, err := resolveMainContainer(podSpec)
if err != nil {
return err
}
// Replace GPU resources with DRA claim on main container
removeGPUResources(mainContainer)
mainContainer.Resources.Claims = append(mainContainer.Resources.Claims, corev1.ResourceClaim{
Name: gmsDRAClaimName,
})
// Add shared volume mount and TMPDIR to main container
mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, corev1.VolumeMount{
Name: gmsSharedVolumeName,
MountPath: gmsSharedMountPath,
Name: gmsruntime.DRAClaimName,
})
mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{
Name: "TMPDIR", Value: gmsSharedMountPath,
})
// Add GMS sidecar
gmsSidecar := buildGMSSidecar(mainContainer.Image, gpuCount)
podSpec.InitContainers = append(podSpec.InitContainers, gmsSidecar)
// Add shared volume
podSpec.Volumes = append(podSpec.Volumes, gmsSharedVolume())
// Add GMS server sidecar, shared volume, and socket env vars.
// The sidecar gets DRA claims copied from main automatically.
gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
// GPU nodes are typically tainted with nvidia.com/gpu=NoSchedule. With
// traditional scheduling the device-plugin injects the matching toleration,
......@@ -126,7 +123,7 @@ func applyGPUMemoryService(
// Add pod-level DRA resource claim referencing the ResourceClaimTemplate
podSpec.ResourceClaims = append(podSpec.ResourceClaims, corev1.PodResourceClaim{
Name: gmsDRAClaimName,
Name: gmsruntime.DRAClaimName,
ResourceClaimTemplateName: &claimTemplateName,
})
......@@ -145,85 +142,6 @@ func removeGPUResources(container *corev1.Container) {
}
}
// buildGMSSidecar creates the GMS weight server as a sidecar init container
// (restartPolicy: Always). kubelet starts it before regular containers and
// keeps it running for the pod's lifetime.
//
// Each GPU gets two GMS subprocesses (weights + kv_cache) via a bash wrapper
// that forwards signals and exits if any child dies. TMPDIR is set so
// UUID-based sockets land in the shared volume.
func buildGMSSidecar(image string, gpuCount int) corev1.Container {
return corev1.Container{
Name: "gms-weights",
Image: image,
Command: []string{"bash", "-c"},
Args: []string{gmsWrapperScript(gpuCount)},
RestartPolicy: ptr.To(corev1.ContainerRestartPolicyAlways),
Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: gmsSharedMountPath},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: gmsSharedVolumeName,
MountPath: gmsSharedMountPath,
},
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: gmsReadyCheckCommand(gpuCount),
},
},
PeriodSeconds: int32(gmsStartupProbePeriodSec),
FailureThreshold: int32(gmsStartupProbeTimeout/time.Second) / int32(gmsStartupProbePeriodSec),
},
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{
{Name: gmsDRAClaimName},
},
},
}
}
// gmsWrapperScript generates a bash script that launches two GMS subprocesses
// per GPU device (one for weights, one for kv_cache), waits for any to exit,
// then tears down the process group.
func gmsWrapperScript(gpuCount int) string {
devList := make([]string, gpuCount)
for i := range gpuCount {
devList[i] = strconv.Itoa(i)
}
return fmt.Sprintf(
`trap 'kill 0 2>/dev/null || true' EXIT
for dev in %s; do
python3 -m gpu_memory_service --device "$dev" --tag weights &
echo "Started GMS device=$dev tag=weights pid=$!"
python3 -m gpu_memory_service --device "$dev" --tag kv_cache &
echo "Started GMS device=$dev tag=kv_cache pid=$!"
done
wait -n
echo "A GMS subprocess exited, shutting down"`, strings.Join(devList, " "))
}
// gmsReadyCheckCommand returns the exec probe command that verifies the
// expected number of GMS UDS sockets exist on the shared volume.
// With 2-tag GMS (weights + kv_cache), there are 2 sockets per GPU.
func gmsReadyCheckCommand(gpuCount int) []string {
return []string{
"sh", "-c",
fmt.Sprintf("test $(ls %s/gms_*.sock 2>/dev/null | wc -l) -ge %d", gmsSharedMountPath, gpuCount*gmsProcessesPerGPU),
}
}
func gmsSharedVolume() corev1.Volume {
return corev1.Volume{
Name: gmsSharedVolumeName,
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
}
}
// GMSResourceClaimTemplateName returns the deterministic name for the
// ResourceClaimTemplate associated with a GMS-enabled component.
func GMSResourceClaimTemplateName(parentName, serviceName string) string {
......@@ -254,7 +172,7 @@ func GenerateGMSResourceClaimTemplate(
},
}
if !isGMSEnabled(component) {
if !IsGMSEnabled(component) {
return template, true, nil
}
......
......@@ -12,6 +12,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
......@@ -63,14 +64,14 @@ func gmsBasePodSpec() corev1.PodSpec {
func TestApplyGPUMemoryService_EmptyContainers(t *testing.T) {
ps := corev1.PodSpec{}
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one container")
}
func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
main := ps.Containers[0]
......@@ -82,53 +83,56 @@ func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) {
// Should have DRA claim
require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsDRAClaimName, main.Resources.Claims[0].Name)
assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
// Should have shared volume mount
var hasSharedMount bool
for _, vm := range main.VolumeMounts {
if vm.Name == gmsSharedVolumeName && vm.MountPath == gmsSharedMountPath {
if vm.Name == gmsruntime.SharedVolumeName && vm.MountPath == gmsruntime.SharedMountPath {
hasSharedMount = true
}
}
assert.True(t, hasSharedMount, "main container should have gms-shared volume mount")
// Should have TMPDIR
// Should have TMPDIR and GMS_SOCKET_DIR
envMap := envToMap(main.Env)
assert.Equal(t, gmsSharedMountPath, envMap["TMPDIR"])
assert.Equal(t, gmsruntime.SharedMountPath, envMap["TMPDIR"])
assert.Equal(t, gmsruntime.SharedMountPath, envMap["GMS_SOCKET_DIR"])
}
func TestApplyGPUMemoryService_GMSSidecarInjected(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
require.Len(t, ps.InitContainers, 1)
gms := ps.InitContainers[0]
assert.Equal(t, "gms-weights", gms.Name)
assert.Equal(t, gmsruntime.ServerContainerName, gms.Name)
assert.Equal(t, "test-image:latest", gms.Image)
assert.Equal(t, []string{"bash", "-c"}, gms.Command)
assert.Contains(t, gms.Args[0], "gpu_memory_service --device")
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gms.Command)
assert.NotNil(t, gms.RestartPolicy)
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gms.RestartPolicy)
require.NotNil(t, gms.StartupProbe)
assert.Equal(t, int32(1), gms.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), gms.StartupProbe.FailureThreshold)
// GMS sidecar should have DRA claim
// GMS sidecar should have DRA claim copied from main
require.Len(t, gms.Resources.Claims, 1)
assert.Equal(t, gmsDRAClaimName, gms.Resources.Claims[0].Name)
assert.Equal(t, gmsruntime.DRAClaimName, gms.Resources.Claims[0].Name)
// GMS sidecar should have TMPDIR
gmsEnv := envToMap(gms.Env)
assert.Equal(t, gmsSharedMountPath, gmsEnv["TMPDIR"])
assert.Equal(t, gmsruntime.SharedMountPath, gmsEnv["TMPDIR"])
}
func TestApplyGPUMemoryService_SharedVolume(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
var found bool
for _, v := range ps.Volumes {
if v.Name == gmsSharedVolumeName {
if v.Name == gmsruntime.SharedVolumeName {
assert.NotNil(t, v.EmptyDir)
found = true
}
......@@ -138,7 +142,7 @@ func TestApplyGPUMemoryService_SharedVolume(t *testing.T) {
func TestApplyGPUMemoryService_GPUToleration(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
var found bool
......@@ -153,17 +157,17 @@ func TestApplyGPUMemoryService_GPUToleration(t *testing.T) {
func TestApplyGPUMemoryService_DRAResourceClaim(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
require.Len(t, ps.ResourceClaims, 1)
assert.Equal(t, gmsDRAClaimName, ps.ResourceClaims[0].Name)
assert.Equal(t, gmsruntime.DRAClaimName, ps.ResourceClaims[0].Name)
assert.Equal(t, "myapp-worker-gpu", *ps.ResourceClaims[0].ResourceClaimTemplateName)
}
func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
main := ps.Containers[0]
......@@ -174,38 +178,32 @@ func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) {
func TestApplyGPUMemoryService_SingleContainer(t *testing.T) {
ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
// Should still have exactly 1 regular container (no duplication)
assert.Len(t, ps.Containers, 1)
assert.Equal(t, "main", ps.Containers[0].Name)
}
// --- GMS sidecar helpers ---
func TestApplyGPUMemoryService_ResolvesMainByName(t *testing.T) {
ps := gmsBasePodSpec()
// Prepend a sidecar so main is NOT Containers[0]
sidecar := corev1.Container{Name: "sidecar", Image: "sidecar:latest"}
ps.Containers = append([]corev1.Container{sidecar}, ps.Containers...)
require.Equal(t, "sidecar", ps.Containers[0].Name)
func TestGmsWrapperScript_TwoTagsPerDevice(t *testing.T) {
script := gmsWrapperScript(3)
assert.Contains(t, script, "for dev in 0 1 2")
assert.Contains(t, script, "--tag weights")
assert.Contains(t, script, "--tag kv_cache")
assert.Contains(t, script, "trap 'kill 0")
assert.Contains(t, script, "wait -n")
}
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
func TestGmsReadyCheckCommand_TwoSocketsPerGPU(t *testing.T) {
cmd := gmsReadyCheckCommand(2)
assert.Equal(t, "sh", cmd[0])
assert.Equal(t, "-c", cmd[1])
assert.Contains(t, cmd[2], "gms_*.sock")
// 2 GPUs * 2 tags = 4 sockets
assert.Contains(t, cmd[2], "-ge 4")
}
// Sidecar should be untouched
assert.Equal(t, "sidecar", ps.Containers[0].Name)
assert.Empty(t, ps.Containers[0].Resources.Claims)
func TestGmsReadyCheckCommand_SingleGPU(t *testing.T) {
cmd := gmsReadyCheckCommand(1)
// 1 GPU * 2 tags = 2 sockets
assert.Contains(t, cmd[2], "-ge 2")
// Main should have DRA claim
main := ps.Containers[1]
assert.Equal(t, "main", main.Name)
require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
}
// --- GenerateGMSResourceClaimTemplate ---
......@@ -268,13 +266,13 @@ func TestGMSResourceClaimTemplateName(t *testing.T) {
// --- isGMSEnabled ---
func TestIsGMSEnabled(t *testing.T) {
assert.True(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
assert.True(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
}))
assert.False(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: false},
}))
assert.False(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
}
// --- getGPUCount ---
......
......@@ -1183,9 +1183,9 @@ func GenerateBasePodSpec(
}
// Inject GMS sidecar with DRA shared GPU access when GPU memory service is enabled.
if isGMSEnabled(component) {
if IsGMSEnabled(component) {
claimTemplateName := GMSResourceClaimTemplateName(parentGraphDeploymentName, serviceName)
if err := applyGPUMemoryService(&podSpec, component, claimTemplateName); err != nil {
if err := ApplyGPUMemoryService(&podSpec, component, claimTemplateName); err != nil {
return nil, fmt.Errorf("failed to apply GPU memory service: %w", err)
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment