Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
...@@ -67,6 +67,31 @@ spec: ...@@ -67,6 +67,31 @@ spec:
spec: spec:
description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint
properties: properties:
gpuMemoryService:
description: |-
GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
It is intentionally outside spec.identity, so it does not affect the
checkpoint identity hash or deduplication.
properties:
deviceClassName:
default: gpu.nvidia.com
description: DeviceClassName is the DRA DeviceClass to request GPUs from.
type: string
enabled:
description: |-
Enabled activates the GMS sidecar. GPU resources on the main container
are replaced with a DRA ResourceClaim for shared GPU access.
type: boolean
mode:
default: intraPod
description: Mode selects the GMS deployment topology.
enum:
- intraPod
- interPod
type: string
required:
- enabled
type: object
identity: identity:
description: Identity defines the inputs that determine checkpoint equivalence description: Identity defines the inputs that determine checkpoint equivalence
properties: properties:
......
...@@ -124,6 +124,12 @@ type DynamoCheckpointSpec struct { ...@@ -124,6 +124,12 @@ type DynamoCheckpointSpec struct {
// +kubebuilder:validation:Required // +kubebuilder:validation:Required
Identity DynamoCheckpointIdentity `json:"identity"` Identity DynamoCheckpointIdentity `json:"identity"`
// GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
// It is intentionally outside spec.identity, so it does not affect the
// checkpoint identity hash or deduplication.
// +optional
GPUMemoryService *GPUMemoryServiceSpec `json:"gpuMemoryService,omitempty"`
// Job defines the configuration for the checkpoint creation Job // Job defines the configuration for the checkpoint creation Job
// +kubebuilder:validation:Required // +kubebuilder:validation:Required
Job DynamoCheckpointJobConfig `json:"job"` Job DynamoCheckpointJobConfig `json:"job"`
......
...@@ -340,6 +340,11 @@ func (in *DynamoCheckpointList) DeepCopyObject() runtime.Object { ...@@ -340,6 +340,11 @@ func (in *DynamoCheckpointList) DeepCopyObject() runtime.Object {
func (in *DynamoCheckpointSpec) DeepCopyInto(out *DynamoCheckpointSpec) { func (in *DynamoCheckpointSpec) DeepCopyInto(out *DynamoCheckpointSpec) {
*out = *in *out = *in
in.Identity.DeepCopyInto(&out.Identity) in.Identity.DeepCopyInto(&out.Identity)
if in.GPUMemoryService != nil {
in, out := &in.GPUMemoryService, &out.GPUMemoryService
*out = new(GPUMemoryServiceSpec)
**out = **in
}
in.Job.DeepCopyInto(&out.Job) in.Job.DeepCopyInto(&out.Job)
} }
......
...@@ -67,6 +67,31 @@ spec: ...@@ -67,6 +67,31 @@ spec:
spec: spec:
description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint description: DynamoCheckpointSpec defines the desired state of DynamoCheckpoint
properties: properties:
gpuMemoryService:
description: |-
GPUMemoryService enables checkpoint-time GPU Memory Service wiring.
It is intentionally outside spec.identity, so it does not affect the
checkpoint identity hash or deduplication.
properties:
deviceClassName:
default: gpu.nvidia.com
description: DeviceClassName is the DRA DeviceClass to request GPUs from.
type: string
enabled:
description: |-
Enabled activates the GMS sidecar. GPU resources on the main container
are replaced with a DRA ResourceClaim for shared GPU access.
type: boolean
mode:
default: intraPod
description: Mode selects the GMS deployment topology.
enum:
- intraPod
- interPod
type: string
required:
- enabled
type: object
identity: identity:
description: Identity defines the inputs that determine checkpoint equivalence description: Identity defines the inputs that determine checkpoint equivalence
properties: properties:
......
...@@ -27,6 +27,10 @@ spec: ...@@ -27,6 +27,10 @@ spec:
dtype: "bfloat16" dtype: "bfloat16"
maxModelLen: 2048 maxModelLen: 2048
# Optional: enable GMS-specific checkpoint capture and restore helpers.
gpuMemoryService:
enabled: false
# Job configuration for checkpoint creation # Job configuration for checkpoint creation
job: job:
activeDeadlineSeconds: 3600 activeDeadlineSeconds: 3600
......
...@@ -23,6 +23,7 @@ import ( ...@@ -23,6 +23,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -159,7 +160,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te ...@@ -159,7 +160,7 @@ func TestCreateOrGetAutoCheckpointDeduplicatesConcurrentSameHashCheckpoint(t *te
}, },
} }
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{}) ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, identity, corev1.PodTemplateSpec{}, nil)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, friendly.Name, ckpt.Name) assert.Equal(t, friendly.Name, ckpt.Name)
...@@ -174,7 +175,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) { ...@@ -174,7 +175,7 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
s := testScheme() s := testScheme()
c := fake.NewClientBuilder().WithScheme(s).Build() c := fake.NewClientBuilder().WithScheme(s).Build()
ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{}) ckpt, err := CreateOrGetAutoCheckpoint(ctx, c, testNamespace, testIdentity(), corev1.PodTemplateSpec{}, nil)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, ckpt.Annotations) require.NotNil(t, ckpt.Annotations)
assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation]) assert.Equal(t, snapshotprotocol.DefaultCheckpointArtifactVersion, ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation])
...@@ -182,6 +183,50 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) { ...@@ -182,6 +183,50 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
// --- InjectCheckpointIntoPodSpec tests --- // --- InjectCheckpointIntoPodSpec tests ---
// TestEnsurePodInfoVolumeMergesExistingDownwardAPIItems starts from a pod spec
// that already carries a podinfo downwardAPI volume holding "pod_name" and a
// caller-defined "custom" item, runs EnsurePodInfoVolume, and checks that a
// single volume remains whose items map every expected path to its expected
// field reference, including the pre-existing "custom" entry.
func TestEnsurePodInfoVolumeMergesExistingDownwardAPIItems(t *testing.T) {
	existing := []corev1.DownwardAPIVolumeFile{
		{
			Path:     "pod_name",
			FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"},
		},
		{
			Path:     "custom",
			FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.labels['custom']"},
		},
	}
	podSpec := &corev1.PodSpec{
		Volumes: []corev1.Volume{{
			Name: consts.PodInfoVolumeName,
			VolumeSource: corev1.VolumeSource{
				DownwardAPI: &corev1.DownwardAPIVolumeSource{Items: existing},
			},
		}},
	}

	EnsurePodInfoVolume(podSpec)

	require.Len(t, podSpec.Volumes, 1)
	downward := podSpec.Volumes[0].DownwardAPI
	require.NotNil(t, downward)

	// Index the resulting items by path; entries without a FieldRef are ignored.
	fields := map[string]string{}
	for _, file := range downward.Items {
		if file.FieldRef == nil {
			continue
		}
		fields[file.Path] = file.FieldRef.FieldPath
	}

	labelRef := func(label string) string {
		return "metadata.labels['" + label + "']"
	}
	expectations := []struct {
		path      string
		fieldPath string
	}{
		{"pod_name", consts.PodInfoFieldPodName},
		{"pod_uid", consts.PodInfoFieldPodUID},
		{"pod_namespace", consts.PodInfoFieldPodNamespace},
		{"custom", "metadata.labels['custom']"},
		{consts.PodInfoFileDynNamespace, labelRef(consts.KubeLabelDynamoNamespace)},
		{consts.PodInfoFileDynNamespaceWorkerSuffix, labelRef(consts.KubeLabelDynamoWorkerHash)},
		{consts.PodInfoFileDynComponent, labelRef(consts.KubeLabelDynamoComponentType)},
		{consts.PodInfoFileDynParentDGDName, labelRef(consts.KubeLabelDynamoGraphDeploymentName)},
		{consts.PodInfoFileDynParentDGDNamespace, consts.PodInfoFieldPodNamespace},
	}
	for _, want := range expectations {
		assert.Equal(t, want.fieldPath, fields[want.path])
	}
}
func TestInjectCheckpointIntoPodSpec(t *testing.T) { func TestInjectCheckpointIntoPodSpec(t *testing.T) {
t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) { t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
podSpec := testPodSpec() podSpec := testPodSpec()
...@@ -218,6 +263,50 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -218,6 +263,50 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName]) assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
}) })
t.Run("ready checkpoint augments existing podinfo volume", func(t *testing.T) {
podSpec := testPodSpec()
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: consts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
{Path: "pod_name", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodName}},
{Path: "pod_uid", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodUID}},
{Path: "pod_namespace", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodNamespace}},
},
},
},
})
info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
var podInfoVolume *corev1.Volume
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name == consts.PodInfoVolumeName {
podInfoVolume = &podSpec.Volumes[i]
break
}
}
require.NotNil(t, podInfoVolume)
require.NotNil(t, podInfoVolume.DownwardAPI)
fields := map[string]string{}
for _, item := range podInfoVolume.DownwardAPI.Items {
if item.FieldRef != nil {
fields[item.Path] = item.FieldRef.FieldPath
}
}
assert.Equal(t, consts.PodInfoFieldPodName, fields["pod_name"])
assert.Equal(t, consts.PodInfoFieldPodUID, fields["pod_uid"])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields["pod_namespace"])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])
})
t.Run("ready checkpoint targets the container named main", func(t *testing.T) { t.Run("ready checkpoint targets the container named main", func(t *testing.T) {
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
Containers: []corev1.Container{ Containers: []corev1.Container{
...@@ -235,6 +324,39 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -235,6 +324,39 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
assert.Nil(t, podSpec.Containers[1].Args) assert.Nil(t, podSpec.Containers[1].Args)
}) })
t.Run("ready gms checkpoint injects restore sidecars and loader mount", func(t *testing.T) {
podSpec := testPodSpec()
podSpec.Containers[0].Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash, GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true}}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
gmsServer := findContainer(podSpec, gmsruntime.ServerContainerName)
require.NotNil(t, gmsServer)
loader := findContainer(podSpec, GMSLoaderContainer)
require.NotNil(t, loader)
// Restore: gms-server should be a regular container, not an init container
assert.Empty(t, podSpec.InitContainers, "restore pods should not have gms-server as init container")
assert.Nil(t, gmsServer.RestartPolicy, "restore gms-server should not have RestartPolicy")
assert.Nil(t, gmsServer.StartupProbe, "restore gms-server should not have StartupProbe")
mounts := map[string]string{}
for _, mount := range loader.VolumeMounts {
mounts[mount.Name] = mount.MountPath
}
assert.Equal(t, "/checkpoints", mounts[snapshotprotocol.CheckpointVolumeName])
assert.Equal(t, gmsruntime.SharedMountPath, mounts[gmsruntime.SharedVolumeName])
env := map[string]string{}
for _, item := range loader.Env {
env[item.Name] = item.Value
}
assert.Equal(t, "/checkpoints/gms/"+testHash+"/versions/1", env["GMS_CHECKPOINT_DIR"])
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gmsServer.Command)
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.snapshot.loader"}, loader.Command)
})
t.Run("error cases", func(t *testing.T) { t.Run("error cases", func(t *testing.T) {
for _, tc := range []struct { for _, tc := range []struct {
name string name string
...@@ -277,7 +399,10 @@ func TestResolveCheckpointForService(t *testing.T) { ...@@ -277,7 +399,10 @@ func TestResolveCheckpointForService(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{ ckpt := &nvidiacomv1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace}, ObjectMeta: metav1.ObjectMeta{Name: hash, Namespace: testNamespace},
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{Identity: testIdentity()}, Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
Identity: testIdentity(),
GPUMemoryService: &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true},
},
Status: nvidiacomv1alpha1.DynamoCheckpointStatus{ Status: nvidiacomv1alpha1.DynamoCheckpointStatus{
Phase: nvidiacomv1alpha1.DynamoCheckpointPhaseReady, Phase: nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
IdentityHash: hash, IdentityHash: hash,
...@@ -294,6 +419,8 @@ func TestResolveCheckpointForService(t *testing.T) { ...@@ -294,6 +419,8 @@ func TestResolveCheckpointForService(t *testing.T) {
assert.True(t, info.Ready) assert.True(t, info.Ready)
assert.Equal(t, hash, info.Hash) assert.Equal(t, hash, info.Hash)
assert.Equal(t, hash, info.CheckpointName) assert.Equal(t, hash, info.CheckpointName)
require.NotNil(t, info.GPUMemoryService)
assert.True(t, info.GPUMemoryService.Enabled)
}) })
t.Run("checkpointRef resolves not-ready CR", func(t *testing.T) { t.Run("checkpointRef resolves not-ready CR", func(t *testing.T) {
...@@ -412,3 +539,19 @@ func TestResolveCheckpointForService(t *testing.T) { ...@@ -412,3 +539,19 @@ func TestResolveCheckpointForService(t *testing.T) {
assert.ErrorContains(t, err, "no checkpointRef or identity") assert.ErrorContains(t, err, "no checkpointRef or identity")
}) })
} }
// findContainer is a test helper returning a pointer to the first container
// called name. Regular containers are searched before init containers; nil is
// returned when neither list holds a match.
func findContainer(podSpec *corev1.PodSpec, name string) *corev1.Container {
	for _, containers := range [][]corev1.Container{podSpec.Containers, podSpec.InitContainers} {
		for idx := range containers {
			if containers[idx].Name != name {
				continue
			}
			// The slice header is copied but the backing array is shared, so
			// this pointer still addresses the element inside podSpec.
			return &containers[idx]
		}
	}
	return nil
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package checkpoint
import (
"context"
"fmt"
"path/filepath"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)
const (
	// GMSLoaderContainer is the container name given to the GMS checkpoint
	// loader sidecar built for restore pods.
	GMSLoaderContainer = "gms-loader"
	// GMSSaverContainer is the container name given to the GMS checkpoint
	// saver sidecar built for checkpoint jobs.
	GMSSaverContainer = "gms-saver"
	// gmsCheckpointLoaderModule is the Python module the loader container
	// runs via "python3 -m".
	gmsCheckpointLoaderModule = "gpu_memory_service.cli.snapshot.loader"
	// gmsCheckpointSaverModule is the Python module the saver container
	// runs via "python3 -m".
	gmsCheckpointSaverModule = "gpu_memory_service.cli.snapshot.saver"
)
// ResolveGMSCheckpointStorage looks up the checkpoint storage for a checkpoint.
//
// It lists the DaemonSets in namespace that carry the snapshot-agent label,
// asks the snapshot protocol package to discover the storage configuration
// from them, and then resolves that storage for the given checkpointID and
// artifactVersion.
//
// An error is returned when reader is nil, when the list call fails, or when
// discovery or resolution fails; in the first three cases the returned
// Storage is the zero value.
func ResolveGMSCheckpointStorage(
	ctx context.Context,
	reader ctrlclient.Reader,
	namespace string,
	checkpointID string,
	artifactVersion string,
) (snapshotprotocol.Storage, error) {
	var none snapshotprotocol.Storage
	if reader == nil {
		return none, fmt.Errorf("checkpoint client is required")
	}

	agentSelector := ctrlclient.MatchingLabels{
		snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
	}
	var agents appsv1.DaemonSetList
	err := reader.List(ctx, &agents, ctrlclient.InNamespace(namespace), agentSelector)
	if err != nil {
		return none, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
	}

	discovered, err := snapshotprotocol.DiscoverStorageFromDaemonSets(namespace, agents.Items)
	if err != nil {
		return none, err
	}
	return snapshotprotocol.ResolveCheckpointStorage(checkpointID, artifactVersion, discovered)
}
// BuildGMSRestoreSidecars wires GMS into a restore pod and returns the extra
// containers (server, then loader) that the caller is expected to append to
// podSpec.Containers. It returns nil when podSpec or mainContainer is nil.
//
// The GMS server is emitted as a regular container rather than an init
// container: the CRIU-restored main process already has GPU memory mapped and
// does not need sockets at startup, and the gms-loader polls for sockets
// internally via wait_for_weights_socket.
//
// Side effects on podSpec: a gms-server entry is dropped from InitContainers
// if present, and the checkpoint PVC volume is added when storage.PVCName is
// set and no volume of that name exists yet.
func BuildGMSRestoreSidecars(
	podSpec *corev1.PodSpec,
	mainContainer *corev1.Container,
	storage snapshotprotocol.Storage,
) []corev1.Container {
	if podSpec == nil || mainContainer == nil {
		return nil
	}

	// The DGD-level applyGPUMemoryService may already have placed gms-server
	// among the init containers. For restore pods it must run as a regular
	// container so that all containers start in parallel, so locate the first
	// such entry and splice it out.
	serverInitIdx := -1
	for idx := range podSpec.InitContainers {
		if podSpec.InitContainers[idx].Name == gmsruntime.ServerContainerName {
			serverInitIdx = idx
			break
		}
	}
	if serverInitIdx >= 0 {
		podSpec.InitContainers = append(
			podSpec.InitContainers[:serverInitIdx],
			podSpec.InitContainers[serverInitIdx+1:]...,
		)
	}

	serverContainer := gmsruntime.BuildServerContainer(podSpec, mainContainer)

	loaderContainer := gmsCheckpointLoaderContainer(mainContainer.Image)
	copyGMSDeviceClaims(mainContainer, &loaderContainer)
	ensureCheckpointVolume(podSpec, storage.PVCName)

	checkpointMount := corev1.VolumeMount{
		Name:      snapshotprotocol.CheckpointVolumeName,
		MountPath: storage.BasePath,
	}
	checkpointDirEnv := corev1.EnvVar{
		Name:  "GMS_CHECKPOINT_DIR",
		Value: resolveGMSArtifactDir(storage),
	}
	loaderContainer.VolumeMounts = append(loaderContainer.VolumeMounts, checkpointMount)
	loaderContainer.Env = append(loaderContainer.Env, checkpointDirEnv)

	return []corev1.Container{serverContainer, loaderContainer}
}
// BuildGMSCheckpointJobSidecars wires GMS into a checkpoint job pod and
// returns the extra containers (the gms-saver) that the caller is expected to
// append to podSpec.Containers.
//
// It returns (nil, nil) when podSpec or mainContainer is nil. It returns an
// error when the main container has no resource claims, or when any of
// storage.PVCName, storage.BasePath or storage.Location is empty.
//
// Side effects on podSpec: the GMS server sidecar and the checkpoint control
// volume are set up, and the checkpoint PVC volume is added if missing.
func BuildGMSCheckpointJobSidecars(
	podSpec *corev1.PodSpec,
	mainContainer *corev1.Container,
	storage snapshotprotocol.Storage,
) ([]corev1.Container, error) {
	storageResolved := storage.PVCName != "" && storage.BasePath != "" && storage.Location != ""
	// Cases are evaluated top to bottom, so mainContainer is only
	// dereferenced after the nil check has been ruled out.
	switch {
	case podSpec == nil || mainContainer == nil:
		return nil, nil
	case len(mainContainer.Resources.Claims) == 0:
		return nil, fmt.Errorf("gms sidecars require main container resource claims")
	case !storageResolved:
		return nil, fmt.Errorf("gms checkpoint jobs require resolved checkpoint storage")
	}

	gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
	ensureGMSCheckpointControl(podSpec)

	saverContainer := gmsCheckpointSaverContainer(mainContainer.Image)
	copyGMSDeviceClaims(mainContainer, &saverContainer)
	ensureCheckpointVolume(podSpec, storage.PVCName)

	checkpointMount := corev1.VolumeMount{
		Name:      snapshotprotocol.CheckpointVolumeName,
		MountPath: storage.BasePath,
	}
	checkpointDirEnv := corev1.EnvVar{
		Name:  "GMS_CHECKPOINT_DIR",
		Value: resolveGMSArtifactDir(storage),
	}
	saverContainer.VolumeMounts = append(saverContainer.VolumeMounts, checkpointMount)
	saverContainer.Env = append(saverContainer.Env, checkpointDirEnv)

	return []corev1.Container{saverContainer}, nil
}
// resolveGMSArtifactDir derives the directory holding GMS checkpoint data:
// <BasePath>/gms/<checkpointID>/versions/<artifactVersion>.
//
// The artifact version is the last path element of storage.Location and the
// checkpoint ID is the element two levels above it. GMS data is kept under
// /checkpoints/gms/<hash>/versions/<version>, apart from the CRIU tree
// (/checkpoints/<hash>/versions/<version>), so the non-root saver can create
// directories at the PVC root.
func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string {
	versionsDir := filepath.Dir(storage.Location)
	checkpointRoot := filepath.Dir(versionsDir)

	version := filepath.Base(storage.Location)
	id := filepath.Base(checkpointRoot)

	return filepath.Join(storage.BasePath, "gms", id, "versions", version)
}
// gmsCheckpointLoaderContainer returns the gms-loader container definition
// for the given image. The container runs the loader Python module, mounts
// the GMS shared volume, and points both TMPDIR and GMS_SOCKET_DIR at the
// shared mount path.
func gmsCheckpointLoaderContainer(image string) corev1.Container {
	sharedDir := gmsruntime.SharedMountPath

	var loader corev1.Container
	loader.Name = GMSLoaderContainer
	loader.Image = image
	loader.Command = []string{"python3", "-m", gmsCheckpointLoaderModule}
	loader.Env = []corev1.EnvVar{
		{Name: "TMPDIR", Value: sharedDir},
		{Name: "GMS_SOCKET_DIR", Value: sharedDir},
	}
	loader.VolumeMounts = []corev1.VolumeMount{
		{Name: gmsruntime.SharedVolumeName, MountPath: sharedDir},
	}
	return loader
}
// gmsCheckpointSaverContainer returns the gms-saver container definition for
// the given image. The container runs the saver Python module, receives the
// pod name and namespace through field references, mounts the GMS shared and
// control volumes, and is told where both live via TMPDIR, GMS_SOCKET_DIR and
// GMS_CONTROL_DIR.
func gmsCheckpointSaverContainer(image string) corev1.Container {
	// fromField builds an env var populated from a pod field reference.
	fromField := func(name, fieldPath string) corev1.EnvVar {
		return corev1.EnvVar{
			Name: name,
			ValueFrom: &corev1.EnvVarSource{
				FieldRef: &corev1.ObjectFieldSelector{FieldPath: fieldPath},
			},
		}
	}

	return corev1.Container{
		Name:    GMSSaverContainer,
		Image:   image,
		Command: []string{"python3", "-m", gmsCheckpointSaverModule},
		Env: []corev1.EnvVar{
			fromField("POD_NAME", "metadata.name"),
			fromField("POD_NAMESPACE", "metadata.namespace"),
			{Name: "TMPDIR", Value: gmsruntime.SharedMountPath},
			{Name: "GMS_SOCKET_DIR", Value: gmsruntime.SharedMountPath},
			{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir},
		},
		VolumeMounts: []corev1.VolumeMount{
			{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath},
			{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir},
		},
	}
}
// ensureGMSCheckpointControl makes sure the pod has the emptyDir control
// volume used for checkpoint coordination and that the GMS server container,
// if one is found, mounts it at the control directory and has GMS_CONTROL_DIR
// set.
//
// The function is idempotent: a volume, mount, or env var that already exists
// under the expected name is left untouched rather than appended a second
// time. Duplicate volume names make a pod spec invalid, so this matters when
// the helper runs more than once on the same spec or when the incoming
// template already declares the control volume. A nil podSpec is a no-op.
func ensureGMSCheckpointControl(podSpec *corev1.PodSpec) {
	if podSpec == nil {
		return
	}

	hasVolume := false
	for i := range podSpec.Volumes {
		if podSpec.Volumes[i].Name == gmsruntime.ControlVolumeName {
			hasVolume = true
			break
		}
	}
	if !hasVolume {
		podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
			Name:         gmsruntime.ControlVolumeName,
			VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
		})
	}

	server := gmsruntime.FindServerContainer(podSpec)
	if server == nil {
		return
	}

	hasMount := false
	for i := range server.VolumeMounts {
		if server.VolumeMounts[i].Name == gmsruntime.ControlVolumeName {
			hasMount = true
			break
		}
	}
	if !hasMount {
		server.VolumeMounts = append(server.VolumeMounts, corev1.VolumeMount{
			Name:      gmsruntime.ControlVolumeName,
			MountPath: gmsruntime.ControlDir,
		})
	}

	hasEnv := false
	for i := range server.Env {
		if server.Env[i].Name == "GMS_CONTROL_DIR" {
			hasEnv = true
			break
		}
	}
	if !hasEnv {
		server.Env = append(server.Env, corev1.EnvVar{
			Name:  "GMS_CONTROL_DIR",
			Value: gmsruntime.ControlDir,
		})
	}
}
// copyGMSDeviceClaims gives container its own copy of the main container's
// resource claims. Nothing happens when either pointer is nil or when the main
// container has no claims; otherwise any claims already on container are
// replaced.
func copyGMSDeviceClaims(mainContainer *corev1.Container, container *corev1.Container) {
	if mainContainer == nil || container == nil {
		return
	}
	claims := mainContainer.Resources.Claims
	if len(claims) == 0 {
		return
	}
	// Copy into a fresh slice so later edits to one container's claims do not
	// show through on the other.
	duplicate := make([]corev1.ResourceClaim, len(claims))
	copy(duplicate, claims)
	container.Resources.Claims = duplicate
}
// ensureCheckpointVolume adds a PVC-backed checkpoint volume referencing
// pvcName to podSpec unless pvcName is empty or a volume with the checkpoint
// volume name is already present.
func ensureCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
	if pvcName == "" {
		return
	}
	for _, existing := range podSpec.Volumes {
		if existing.Name == snapshotprotocol.CheckpointVolumeName {
			return
		}
	}

	claim := &corev1.PersistentVolumeClaimVolumeSource{ClaimName: pvcName}
	checkpointVolume := corev1.Volume{
		Name:         snapshotprotocol.CheckpointVolumeName,
		VolumeSource: corev1.VolumeSource{PersistentVolumeClaim: claim},
	}
	podSpec.Volumes = append(podSpec.Volumes, checkpointVolume)
}
...@@ -9,71 +9,93 @@ import ( ...@@ -9,71 +9,93 @@ import (
) )
func EnsurePodInfoVolume(podSpec *corev1.PodSpec) { func EnsurePodInfoVolume(podSpec *corev1.PodSpec) {
for _, volume := range podSpec.Volumes { for i := range podSpec.Volumes {
if volume.Name == commonconsts.PodInfoVolumeName { if podSpec.Volumes[i].Name != commonconsts.PodInfoVolumeName {
return continue
}
if podSpec.Volumes[i].DownwardAPI == nil {
podSpec.Volumes[i].VolumeSource.DownwardAPI = &corev1.DownwardAPIVolumeSource{}
}
// Merge required items into existing downwardAPI volume.
source := podSpec.Volumes[i].DownwardAPI
pathToIndex := make(map[string]int, len(source.Items))
for j := range source.Items {
pathToIndex[source.Items[j].Path] = j
} }
for _, item := range podInfoItems() {
if idx, ok := pathToIndex[item.Path]; ok {
source.Items[idx] = item
continue
}
source.Items = append(source.Items, item)
pathToIndex[item.Path] = len(source.Items) - 1
}
return
} }
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{ podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: commonconsts.PodInfoVolumeName, Name: commonconsts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{ VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{ DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{ Items: podInfoItems(),
{
Path: "pod_name",
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: commonconsts.PodInfoFieldPodName,
},
},
{
Path: "pod_uid",
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: commonconsts.PodInfoFieldPodUID,
},
},
{
Path: "pod_namespace",
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: commonconsts.PodInfoFieldPodNamespace,
},
},
{
Path: commonconsts.PodInfoFileDynNamespace,
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.labels['" + commonconsts.KubeLabelDynamoNamespace + "']",
},
},
{
Path: commonconsts.PodInfoFileDynNamespaceWorkerSuffix,
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.labels['" + commonconsts.KubeLabelDynamoWorkerHash + "']",
},
},
{
Path: commonconsts.PodInfoFileDynComponent,
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.labels['" + commonconsts.KubeLabelDynamoComponentType + "']",
},
},
{
Path: commonconsts.PodInfoFileDynParentDGDName,
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.labels['" + commonconsts.KubeLabelDynamoGraphDeploymentName + "']",
},
},
{
Path: commonconsts.PodInfoFileDynParentDGDNamespace,
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: commonconsts.PodInfoFieldPodNamespace,
},
},
},
}, },
}, },
}) })
} }
// podInfoItems returns a freshly allocated list of the downwardAPI files that
// the podinfo volume must expose: the pod's name, UID and namespace, followed
// by the Dynamo namespace, worker hash, component type and parent graph
// deployment name (each read from a pod label) and the parent graph
// deployment namespace (read from the pod namespace field).
func podInfoItems() []corev1.DownwardAPIVolumeFile {
	// fieldFile builds one downwardAPI file backed by a field reference.
	fieldFile := func(path, fieldPath string) corev1.DownwardAPIVolumeFile {
		return corev1.DownwardAPIVolumeFile{
			Path:     path,
			FieldRef: &corev1.ObjectFieldSelector{FieldPath: fieldPath},
		}
	}
	// labelFile builds one downwardAPI file backed by a pod label.
	labelFile := func(path, label string) corev1.DownwardAPIVolumeFile {
		return fieldFile(path, "metadata.labels['"+label+"']")
	}

	return []corev1.DownwardAPIVolumeFile{
		fieldFile("pod_name", commonconsts.PodInfoFieldPodName),
		fieldFile("pod_uid", commonconsts.PodInfoFieldPodUID),
		fieldFile("pod_namespace", commonconsts.PodInfoFieldPodNamespace),
		labelFile(commonconsts.PodInfoFileDynNamespace, commonconsts.KubeLabelDynamoNamespace),
		labelFile(commonconsts.PodInfoFileDynNamespaceWorkerSuffix, commonconsts.KubeLabelDynamoWorkerHash),
		labelFile(commonconsts.PodInfoFileDynComponent, commonconsts.KubeLabelDynamoComponentType),
		labelFile(commonconsts.PodInfoFileDynParentDGDName, commonconsts.KubeLabelDynamoGraphDeploymentName),
		fieldFile(commonconsts.PodInfoFileDynParentDGDNamespace, commonconsts.PodInfoFieldPodNamespace),
	}
}
func EnsurePodInfoMount(container *corev1.Container) { func EnsurePodInfoMount(container *corev1.Container) {
for _, mount := range container.VolumeMounts { for _, mount := range container.VolumeMounts {
if mount.Name == commonconsts.PodInfoVolumeName { if mount.Name == commonconsts.PodInfoVolumeName {
......
...@@ -94,5 +94,26 @@ func InjectCheckpointIntoPodSpec( ...@@ -94,5 +94,26 @@ func InjectCheckpointIntoPodSpec(
EnsurePodInfoVolume(podSpec) EnsurePodInfoVolume(podSpec)
EnsurePodInfoMount(mainContainer) EnsurePodInfoMount(mainContainer)
// GMS restore sidecars (server + loader) are only needed when the checkpoint
// is ready and the pod will actually be CRIU-restored.
if info.Ready && info.GPUMemoryService != nil && info.GPUMemoryService.Enabled {
if len(mainContainer.Resources.Claims) == 0 {
return fmt.Errorf("gms sidecars require main container resource claims")
}
storage, err := ResolveGMSCheckpointStorage(
ctx,
reader,
namespace,
info.Hash,
info.ArtifactVersion,
)
if err != nil {
return err
}
gmsSidecars := BuildGMSRestoreSidecars(podSpec, mainContainer, storage)
podSpec.Containers = append(podSpec.Containers, gmsSidecars...)
}
return nil return nil
} }
...@@ -28,13 +28,14 @@ import ( ...@@ -28,13 +28,14 @@ import (
) )
type CheckpointInfo struct { type CheckpointInfo struct {
Enabled bool Enabled bool
Exists bool Exists bool
Identity *nvidiacomv1alpha1.DynamoCheckpointIdentity Identity *nvidiacomv1alpha1.DynamoCheckpointIdentity
Hash string GPUMemoryService *nvidiacomv1alpha1.GPUMemoryServiceSpec
ArtifactVersion string Hash string
CheckpointName string ArtifactVersion string
Ready bool CheckpointName string
Ready bool
} }
func checkpointInfoFromObject(ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (*CheckpointInfo, error) { func checkpointInfoFromObject(ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (*CheckpointInfo, error) {
...@@ -44,13 +45,14 @@ func checkpointInfoFromObject(ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (*Checkp ...@@ -44,13 +45,14 @@ func checkpointInfoFromObject(ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (*Checkp
} }
return &CheckpointInfo{ return &CheckpointInfo{
Enabled: true, Enabled: true,
Exists: true, Exists: true,
Identity: &ckpt.Spec.Identity, Identity: &ckpt.Spec.Identity,
Hash: hash, GPUMemoryService: ckpt.Spec.GPUMemoryService,
ArtifactVersion: checkpointArtifactVersion(ckpt), Hash: hash,
CheckpointName: ckpt.Name, ArtifactVersion: checkpointArtifactVersion(ckpt),
Ready: ckpt.Status.Phase == nvidiacomv1alpha1.DynamoCheckpointPhaseReady, CheckpointName: ckpt.Name,
Ready: ckpt.Status.Phase == nvidiacomv1alpha1.DynamoCheckpointPhaseReady,
}, nil }, nil
} }
......
...@@ -107,6 +107,7 @@ func CreateOrGetAutoCheckpoint( ...@@ -107,6 +107,7 @@ func CreateOrGetAutoCheckpoint(
namespace string, namespace string,
identity nvidiacomv1alpha1.DynamoCheckpointIdentity, identity nvidiacomv1alpha1.DynamoCheckpointIdentity,
podTemplate corev1.PodTemplateSpec, podTemplate corev1.PodTemplateSpec,
gpuMemoryService *nvidiacomv1alpha1.GPUMemoryServiceSpec,
) (*nvidiacomv1alpha1.DynamoCheckpoint, error) { ) (*nvidiacomv1alpha1.DynamoCheckpoint, error) {
hash, err := ComputeIdentityHash(identity) hash, err := ComputeIdentityHash(identity)
if err != nil { if err != nil {
...@@ -125,7 +126,8 @@ func CreateOrGetAutoCheckpoint( ...@@ -125,7 +126,8 @@ func CreateOrGetAutoCheckpoint(
}, },
}, },
Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{ Spec: nvidiacomv1alpha1.DynamoCheckpointSpec{
Identity: identity, Identity: identity,
GPUMemoryService: gpuMemoryService,
Job: nvidiacomv1alpha1.DynamoCheckpointJobConfig{ Job: nvidiacomv1alpha1.DynamoCheckpointJobConfig{
PodTemplateSpec: podTemplate, PodTemplateSpec: podTemplate,
}, },
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
package controller package controller
import ( import (
"context"
"fmt" "fmt"
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
...@@ -16,6 +17,7 @@ import ( ...@@ -16,6 +17,7 @@ import (
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
) )
func buildCheckpointWorkerDefaultEnv( func buildCheckpointWorkerDefaultEnv(
...@@ -51,6 +53,8 @@ func buildCheckpointWorkerDefaultEnv( ...@@ -51,6 +53,8 @@ func buildCheckpointWorkerDefaultEnv(
} }
func buildCheckpointJob( func buildCheckpointJob(
ctx context.Context,
reader ctrlclient.Reader,
config *configv1alpha1.OperatorConfiguration, config *configv1alpha1.OperatorConfiguration,
ckpt *nvidiacomv1alpha1.DynamoCheckpoint, ckpt *nvidiacomv1alpha1.DynamoCheckpoint,
jobName string, jobName string,
...@@ -77,31 +81,51 @@ func buildCheckpointJob( ...@@ -77,31 +81,51 @@ func buildCheckpointJob(
checkpoint.EnsurePodInfoVolume(&podTemplate.Spec) checkpoint.EnsurePodInfoVolume(&podTemplate.Spec)
if len(podTemplate.Spec.Containers) > 0 { mainContainer, err := snapshotprotocol.ResolveCheckpointWorkerContainer(&podTemplate.Spec)
mainContainer := &podTemplate.Spec.Containers[0] if err != nil {
mainContainer.Env = dynamo.MergeEnvs( return nil, err
buildCheckpointWorkerDefaultEnv(ckpt, podTemplate), }
mainContainer.Env, mainContainer.Env = dynamo.MergeEnvs(
) buildCheckpointWorkerDefaultEnv(ckpt, podTemplate),
dynamo.AddStandardEnvVars(mainContainer, config) mainContainer.Env,
mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{ )
Name: consts.EnvReadyForCheckpointFile, dynamo.AddStandardEnvVars(mainContainer, config)
Value: config.Checkpoint.ReadyForCheckpointFilePath, mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{
}) Name: consts.EnvReadyForCheckpointFile,
mainContainer.ReadinessProbe = &corev1.Probe{ Value: config.Checkpoint.ReadyForCheckpointFilePath,
ProbeHandler: corev1.ProbeHandler{ })
Exec: &corev1.ExecAction{ mainContainer.ReadinessProbe = &corev1.Probe{
Command: []string{"cat", config.Checkpoint.ReadyForCheckpointFilePath}, ProbeHandler: corev1.ProbeHandler{
}, Exec: &corev1.ExecAction{
Command: []string{"cat", config.Checkpoint.ReadyForCheckpointFilePath},
}, },
InitialDelaySeconds: 15, },
PeriodSeconds: 2, InitialDelaySeconds: 15,
PeriodSeconds: 2,
}
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
var gmsSidecars []corev1.Container
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
storage, err := checkpoint.ResolveGMSCheckpointStorage(
ctx,
reader,
ckpt.Namespace,
hash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
)
if err != nil {
return nil, err
}
gmsSidecars, err = checkpoint.BuildGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage)
if err != nil {
return nil, err
} }
mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil
checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
} }
podTemplate.Spec.Containers = append(podTemplate.Spec.Containers, gmsSidecars...)
activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds
if activeDeadlineSeconds == nil { if activeDeadlineSeconds == nil {
...@@ -110,10 +134,8 @@ func buildCheckpointJob( ...@@ -110,10 +134,8 @@ func buildCheckpointJob(
} }
wrapLaunchJob := false wrapLaunchJob := false
if len(podTemplate.Spec.Containers) != 0 { if gpus, ok := mainContainer.Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]; ok {
if gpus, ok := podTemplate.Spec.Containers[0].Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]; ok { wrapLaunchJob = gpus.Cmp(*resource.NewQuantity(1, resource.DecimalSI)) > 0
wrapLaunchJob = gpus.Cmp(*resource.NewQuantity(1, resource.DecimalSI)) > 0
}
} }
ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds
......
...@@ -197,7 +197,7 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco ...@@ -197,7 +197,7 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
// Use SyncResource to create/update the checkpoint Job // Use SyncResource to create/update the checkpoint Job
modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) { modified, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*batchv1.Job, bool, error) {
job, err := buildCheckpointJob(r.Config, ckpt, jobName) job, err := buildCheckpointJob(ctx, r.Client, r.Config, ckpt, jobName)
return job, false, err return job, false, err
}) })
if err != nil { if err != nil {
......
...@@ -26,9 +26,11 @@ import ( ...@@ -26,9 +26,11 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1" coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -65,6 +67,7 @@ var defaultCheckpointJobName = snapshotprotocol.GetCheckpointJobName(testHash, s ...@@ -65,6 +67,7 @@ var defaultCheckpointJobName = snapshotprotocol.GetCheckpointJobName(testHash, s
func checkpointTestScheme() *runtime.Scheme { func checkpointTestScheme() *runtime.Scheme {
s := runtime.NewScheme() s := runtime.NewScheme()
_ = nvidiacomv1alpha1.AddToScheme(s) _ = nvidiacomv1alpha1.AddToScheme(s)
_ = appsv1.AddToScheme(s)
_ = corev1.AddToScheme(s) _ = corev1.AddToScheme(s)
_ = batchv1.AddToScheme(s) _ = batchv1.AddToScheme(s)
_ = coordinationv1.AddToScheme(s) _ = coordinationv1.AddToScheme(s)
...@@ -130,6 +133,17 @@ func makeCheckpointLease(name string, renewTime time.Time, durationSeconds int32 ...@@ -130,6 +133,17 @@ func makeCheckpointLease(name string, renewTime time.Time, durationSeconds int32
} }
} }
// requireCheckpointContainer looks up the container called name in containers
// and returns a pointer to that slice element. When no container matches, the
// test is failed immediately via t.Fatalf.
func requireCheckpointContainer(t *testing.T, containers []corev1.Container, name string) *corev1.Container {
	t.Helper()
	for idx := range containers {
		candidate := &containers[idx]
		if candidate.Name != name {
			continue
		}
		return candidate
	}
	t.Fatalf("container %q not found", name)
	return nil
}
func TestBuildCheckpointJob(t *testing.T) { func TestBuildCheckpointJob(t *testing.T) {
s := checkpointTestScheme() s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending) ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
...@@ -139,7 +153,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -139,7 +153,7 @@ func TestBuildCheckpointJob(t *testing.T) {
} }
r := makeCheckpointReconciler(s, ckpt) r := makeCheckpointReconciler(s, ckpt)
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
podSpec := job.Spec.Template.Spec podSpec := job.Spec.Template.Spec
main := podSpec.Containers[0] main := podSpec.Containers[0]
...@@ -236,7 +250,7 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -236,7 +250,7 @@ func TestBuildCheckpointJob(t *testing.T) {
backoff := int32(5) backoff := int32(5)
ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline ckpt.Spec.Job.ActiveDeadlineSeconds = &deadline
ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs. ckpt.Spec.Job.BackoffLimit = &backoff //nolint:staticcheck // Compatibility test: deprecated field must remain ignored by checkpoint Jobs.
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err = buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds) assert.Equal(t, int64(7200), *job.Spec.ActiveDeadlineSeconds)
assert.Equal(t, int32(0), *job.Spec.BackoffLimit) assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
...@@ -247,12 +261,142 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -247,12 +261,142 @@ func TestBuildCheckpointJob(t *testing.T) {
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"), corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
}, },
} }
job, err = buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err = buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command) assert.Equal(t, []string{"cuda-checkpoint"}, job.Spec.Template.Spec.Containers[0].Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args) assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
} }
// TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst builds a
// checkpoint Job from a pod template whose first container is an unrelated
// sidecar and whose second container is named consts.MainContainerName.
// It asserts that the checkpoint rewrite (cuda-checkpoint command, readiness
// probe on the ready file, cleared liveness/startup probes, ready-file env var)
// lands on the main container, and that the sidecar keeps its command and args
// and receives neither probes nor the ready-file env var.
func TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst(t *testing.T) {
	testScheme := checkpointTestScheme()
	checkpointCR := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)

	sidecarContainer := corev1.Container{
		Name:    "sidecar",
		Image:   "sidecar:latest",
		Command: []string{"sleep"},
		Args:    []string{"infinity"},
	}
	// The main container requests two GPUs; the assertions below expect its
	// original command to be moved behind "--launch-job".
	workerContainer := corev1.Container{
		Name:    consts.MainContainerName,
		Image:   "test-image:latest",
		Command: []string{"python3", "-m", "dynamo.vllm"},
		Env:     []corev1.EnvVar{{Name: "HF_TOKEN", Value: "secret"}},
		Resources: corev1.ResourceRequirements{
			Limits: corev1.ResourceList{
				corev1.ResourceName(consts.KubeResourceGPUNvidia): resource.MustParse("2"),
			},
		},
	}
	checkpointCR.Spec.Job.PodTemplateSpec.Spec.Containers = []corev1.Container{
		sidecarContainer,
		workerContainer,
	}

	reconciler := makeCheckpointReconciler(testScheme, checkpointCR)
	builtJob, buildErr := buildCheckpointJob(context.Background(), nil, reconciler.Config, checkpointCR, defaultCheckpointJobName)
	require.NoError(t, buildErr)

	jobContainers := builtJob.Spec.Template.Spec.Containers

	// Main container: rewritten for checkpointing.
	worker := requireCheckpointContainer(t, jobContainers, consts.MainContainerName)
	assert.Equal(t, []string{"cuda-checkpoint"}, worker.Command)
	assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, worker.Args)
	require.NotNil(t, worker.ReadinessProbe)
	assert.Equal(t, []string{"cat", "/tmp/ready-for-checkpoint"}, worker.ReadinessProbe.Exec.Command)
	assert.Nil(t, worker.LivenessProbe)
	assert.Nil(t, worker.StartupProbe)

	workerEnv := make(map[string]string)
	for i := range worker.Env {
		workerEnv[worker.Env[i].Name] = worker.Env[i].Value
	}
	assert.Equal(t, "/tmp/ready-for-checkpoint", workerEnv[consts.EnvReadyForCheckpointFile])
	assert.Equal(t, "secret", workerEnv["HF_TOKEN"])

	// Sidecar: must come through untouched.
	untouched := requireCheckpointContainer(t, jobContainers, "sidecar")
	assert.Equal(t, []string{"sleep"}, untouched.Command)
	assert.Equal(t, []string{"infinity"}, untouched.Args)
	assert.Nil(t, untouched.ReadinessProbe)
	assert.Nil(t, untouched.LivenessProbe)
	assert.Nil(t, untouched.StartupProbe)
	for i := range untouched.Env {
		assert.NotEqual(t, consts.EnvReadyForCheckpointFile, untouched.Env[i].Name)
	}
}
// TestBuildCheckpointJobAddsGMSSidecars builds a checkpoint Job for a
// DynamoCheckpoint with spec.gpuMemoryService enabled and checks the GMS wiring:
//   - the GMS server is present as an init container with RestartPolicy Always
//     and a startup probe,
//   - the GMS saver is present as a regular container with the checkpoint
//     volume mounted at /checkpoints and GMS_CHECKPOINT_DIR pointing at the
//     versioned directory for the test hash,
//   - the shared, control and checkpoint volumes exist on the pod, and the main
//     container mounts the shared volume.
func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) {
	s := checkpointTestScheme()
	ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
	ckpt.Spec.GPUMemoryService = &nvidiacomv1alpha1.GPUMemoryServiceSpec{Enabled: true}
	ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}

	// Fake snapshot-agent DaemonSet exposing a PVC-backed volume mounted at
	// /checkpoints. NOTE(review): presumably this is what
	// checkpoint.ResolveGMSCheckpointStorage reads to locate checkpoint
	// storage; confirm against that helper.
	snapshotAgentDaemonSet := &appsv1.DaemonSet{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "snapshot-agent",
			Namespace: testNamespace,
			Labels: map[string]string{
				snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
			},
		},
		Spec: appsv1.DaemonSetSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					Containers: []corev1.Container{{
						Name: snapshotprotocol.SnapshotAgentContainerName,
						VolumeMounts: []corev1.VolumeMount{{
							Name:      snapshotprotocol.SnapshotAgentVolumeName,
							MountPath: "/checkpoints",
						}},
					}},
					Volumes: []corev1.Volume{{
						Name: snapshotprotocol.SnapshotAgentVolumeName,
						VolumeSource: corev1.VolumeSource{
							PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
								ClaimName: "snapshot-pvc",
							},
						},
					}},
				},
			},
		},
	}
	reader := fake.NewClientBuilder().WithScheme(s).WithObjects(snapshotAgentDaemonSet).Build()
	r := makeCheckpointReconciler(s, ckpt)

	job, err := buildCheckpointJob(context.Background(), reader, r.Config, ckpt, defaultCheckpointJobName)
	require.NoError(t, err)

	main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName)
	weightsServer := requireCheckpointContainer(t, job.Spec.Template.Spec.InitContainers, gmsruntime.ServerContainerName)
	saver := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, checkpoint.GMSSaverContainer)

	volNames := map[string]bool{}
	for _, v := range job.Spec.Template.Spec.Volumes {
		volNames[v.Name] = true
	}
	assert.True(t, volNames[gmsruntime.SharedVolumeName])
	assert.True(t, volNames[gmsruntime.ControlVolumeName])
	assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])

	mainMounts := map[string]string{}
	for _, m := range main.VolumeMounts {
		mainMounts[m.Name] = m.MountPath
	}
	assert.Equal(t, gmsruntime.SharedMountPath, mainMounts[gmsruntime.SharedVolumeName])

	assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, weightsServer.Command)
	// Guard the dereference below: a missing RestartPolicy should fail the
	// test with a readable message instead of panicking on a nil pointer.
	require.NotNil(t, weightsServer.RestartPolicy)
	assert.Equal(t, corev1.ContainerRestartPolicyAlways, *weightsServer.RestartPolicy)
	require.NotNil(t, weightsServer.StartupProbe)

	assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.snapshot.saver"}, saver.Command)
	saverMounts := map[string]string{}
	for _, m := range saver.VolumeMounts {
		saverMounts[m.Name] = m.MountPath
	}
	assert.Equal(t, "/checkpoints", saverMounts[snapshotprotocol.CheckpointVolumeName])

	saverEnv := map[string]string{}
	for _, env := range saver.Env {
		saverEnv[env.Name] = env.Value
	}
	assert.Equal(t, "/checkpoints/gms/"+testHash+"/versions/1", saverEnv["GMS_CHECKPOINT_DIR"])
}
func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) { func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
s := checkpointTestScheme() s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending) ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
...@@ -272,7 +416,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) { ...@@ -272,7 +416,7 @@ func TestBuildCheckpointJobInjectsStandardEnvVars(t *testing.T) {
customShmSize := resource.MustParse("16Gi") customShmSize := resource.MustParse("16Gi")
ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize} ckpt.Spec.Job.SharedMemory = &nvidiacomv1alpha1.SharedMemorySpec{Size: customShmSize}
job, err := buildCheckpointJob(r.Config, ckpt, defaultCheckpointJobName) job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
foundCustomShmVolume := false foundCustomShmVolume := false
for _, v := range job.Spec.Template.Spec.Volumes { for _, v := range job.Spec.Template.Spec.Volumes {
......
...@@ -29,10 +29,12 @@ import ( ...@@ -29,10 +29,12 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/onsi/gomega" "github.com/onsi/gomega"
"github.com/onsi/gomega/format" "github.com/onsi/gomega/format"
"github.com/stretchr/testify/require"
istioNetworking "istio.io/api/networking/v1beta1" istioNetworking "istio.io/api/networking/v1beta1"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
appsv1 "k8s.io/api/apps/v1" appsv1 "k8s.io/api/apps/v1"
...@@ -1248,7 +1250,7 @@ func TestDynamoComponentDeploymentReconciler_createOrUpdateOrDeleteDeployments_R ...@@ -1248,7 +1250,7 @@ func TestDynamoComponentDeploymentReconciler_createOrUpdateOrDeleteDeployments_R
g.Expect(deployment3).NotTo(gomega.BeNil()) g.Expect(deployment3).NotTo(gomega.BeNil())
} }
func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabels(t *testing.T) { func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabels(t *testing.T) { //nolint:gocyclo
s := scheme.Scheme s := scheme.Scheme
if err := v1alpha1.AddToScheme(s); err != nil { if err := v1alpha1.AddToScheme(s); err != nil {
t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
...@@ -1376,6 +1378,129 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1376,6 +1378,129 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
} }
}) })
t.Run("ready gms checkpoint injects gms restore sidecars", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity)
if err != nil {
t.Fatalf("ComputeIdentityHash failed: %v", err)
}
dcd := makeDCD(checkpointName)
dcd.Spec.ExtraPodSpec.MainContainer.Resources.Claims = []corev1.ResourceClaim{{Name: "gpu"}}
ckpt := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: checkpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{
Identity: identity,
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
},
Status: v1alpha1.DynamoCheckpointStatus{
Phase: v1alpha1.DynamoCheckpointPhaseReady,
},
}
r := makeReconciler(dcd, ckpt)
podTemplateSpec, err := r.generatePodTemplateSpec(
context.Background(),
generateResourceOption{dynamoComponentDeployment: dcd},
dynamo.RoleMain,
)
if err != nil {
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
find := func(name string) *corev1.Container {
for i := range podTemplateSpec.Spec.Containers {
if podTemplateSpec.Spec.Containers[i].Name == name {
return &podTemplateSpec.Spec.Containers[i]
}
}
for i := range podTemplateSpec.Spec.InitContainers {
if podTemplateSpec.Spec.InitContainers[i].Name == name {
return &podTemplateSpec.Spec.InitContainers[i]
}
}
return nil
}
gmsServer := find(gmsruntime.ServerContainerName)
require.NotNil(t, gmsServer)
loader := find(checkpoint.GMSLoaderContainer)
require.NotNil(t, loader)
mounts := map[string]string{}
for _, mount := range loader.VolumeMounts {
mounts[mount.Name] = mount.MountPath
}
if got := mounts[snapshotprotocol.CheckpointVolumeName]; got != "/checkpoints" {
t.Fatalf("expected gms loader checkpoint mount at /checkpoints, got %q", got)
}
if got := gmsServer.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.server" { //nolint:goconst
t.Fatalf("expected weights server to run python module, got %#v", got)
}
// Restore: gms-server should be a regular container, not an init container
if gmsServer.RestartPolicy != nil {
t.Fatalf("expected restore gms-server to have no RestartPolicy (regular container), got %#v", gmsServer.RestartPolicy)
}
if gmsServer.StartupProbe != nil {
t.Fatalf("expected restore gms-server to have no StartupProbe")
}
if got := loader.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.snapshot.loader" {
t.Fatalf("expected loader to run python module, got %#v", got)
}
})
t.Run("ready checkpoint rewrites only main when extra sidecars are present", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity)
if err != nil {
t.Fatalf("ComputeIdentityHash failed: %v", err)
}
dcd := makeDCD(checkpointName)
dcd.Spec.ExtraPodSpec.PodSpec = &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "gms-loader",
Image: "sidecar:latest",
Command: []string{"python3"},
Args: []string{"-m", "sidecar"},
}},
}
ckpt := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{
Name: checkpointName,
Namespace: "default",
},
Spec: v1alpha1.DynamoCheckpointSpec{Identity: identity},
Status: v1alpha1.DynamoCheckpointStatus{
Phase: v1alpha1.DynamoCheckpointPhaseReady,
},
}
r := makeReconciler(dcd, ckpt)
podTemplateSpec, err := r.generatePodTemplateSpec(
context.Background(),
generateResourceOption{dynamoComponentDeployment: dcd},
dynamo.RoleMain,
)
if err != nil {
t.Fatalf("generatePodTemplateSpec failed: %v", err)
}
if got := podTemplateSpec.Spec.Containers[0]; got.Name != "gms-loader" || len(got.Command) != 1 || got.Command[0] != "python3" {
t.Fatalf("expected sidecar container to remain unchanged, got %#v", got)
}
if got := podTemplateSpec.Spec.Containers[1]; got.Name != commonconsts.MainContainerName || len(got.Command) != 2 || got.Command[0] != "sleep" || got.Command[1] != "infinity" {
t.Fatalf("expected main container to be rewritten for restore, got %#v", got)
}
if podTemplateSpec.Spec.Containers[1].Args != nil {
t.Fatalf("expected main container args to be cleared, got %#v", podTemplateSpec.Spec.Containers[1].Args)
}
if got := podTemplateSpec.Labels[snapshotprotocol.RestoreTargetLabel]; got != commonconsts.KubeLabelValueTrue {
t.Fatalf("expected %s label to be true, got %q", snapshotprotocol.RestoreTargetLabel, got)
}
})
t.Run("operator reasserts restore identity labels after metadata merge", func(t *testing.T) { t.Run("operator reasserts restore identity labels after metadata merge", func(t *testing.T) {
identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"} identity := v1alpha1.DynamoCheckpointIdentity{Model: "test-model", BackendFramework: "vllm"}
checkpointName, err := checkpoint.ComputeIdentityHash(identity) checkpointName, err := checkpoint.ComputeIdentityHash(identity)
......
...@@ -1380,18 +1380,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR( ...@@ -1380,18 +1380,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
return nil, fmt.Errorf("checkpoint identity is required for Auto mode") return nil, fmt.Errorf("checkpoint identity is required for Auto mode")
} }
identity := component.Checkpoint.Identity checkpointIdentity := *component.Checkpoint.Identity.DeepCopy()
checkpointIdentity := nvidiacomv1alpha1.DynamoCheckpointIdentity{
Model: identity.Model,
BackendFramework: identity.BackendFramework,
DynamoVersion: identity.DynamoVersion,
TensorParallelSize: identity.TensorParallelSize,
PipelineParallelSize: identity.PipelineParallelSize,
Dtype: identity.Dtype,
MaxModelLen: identity.MaxModelLen,
ExtraParameters: identity.ExtraParameters,
}
// Capture config is not part of the checkpoint identity. Once a checkpoint object exists for a // Capture config is not part of the checkpoint identity. Once a checkpoint object exists for a
// hash, later reconcilers must reuse it instead of racing to overwrite the capture pod template. // hash, later reconcilers must reuse it instead of racing to overwrite the capture pod template.
...@@ -1399,7 +1388,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR( ...@@ -1399,7 +1388,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
dynamoDeployment, dynamoDeployment,
component, component,
serviceName, serviceName,
identity.BackendFramework, checkpointIdentity.BackendFramework,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to build checkpoint job pod template: %w", err) return nil, fmt.Errorf("failed to build checkpoint job pod template: %w", err)
...@@ -1411,6 +1400,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR( ...@@ -1411,6 +1400,7 @@ func (r *DynamoGraphDeploymentReconciler) createCheckpointCR(
dynamoDeployment.Namespace, dynamoDeployment.Namespace,
checkpointIdentity, checkpointIdentity,
podTemplate, podTemplate,
component.GPUMemoryService,
) )
} }
......
...@@ -456,7 +456,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips ...@@ -456,7 +456,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
referenced := &v1alpha1.DynamoCheckpoint{ referenced := &v1alpha1.DynamoCheckpoint{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
Name: "friendly-checkpoint", Name: friendlyCheckpointName,
Namespace: "default", Namespace: "default",
}, },
Spec: v1alpha1.DynamoCheckpointSpec{ Spec: v1alpha1.DynamoCheckpointSpec{
...@@ -526,7 +526,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips ...@@ -526,7 +526,7 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
if info.Hash != hash { if info.Hash != hash {
t.Fatalf("checkpoint hash = %s, want %s", info.Hash, hash) t.Fatalf("checkpoint hash = %s, want %s", info.Hash, hash)
} }
if checkpointStatuses["worker"].CheckpointName != "friendly-checkpoint" { if checkpointStatuses["worker"].CheckpointName != friendlyCheckpointName {
t.Fatalf("checkpoint status name = %s, want friendly-checkpoint", checkpointStatuses["worker"].CheckpointName) t.Fatalf("checkpoint status name = %s, want friendly-checkpoint", checkpointStatuses["worker"].CheckpointName)
} }
...@@ -537,11 +537,96 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips ...@@ -537,11 +537,96 @@ func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefSkips
if len(checkpoints.Items) != 1 { if len(checkpoints.Items) != 1 {
t.Fatalf("expected only the referenced checkpoint to exist, found %d", len(checkpoints.Items)) t.Fatalf("expected only the referenced checkpoint to exist, found %d", len(checkpoints.Items))
} }
if checkpoints.Items[0].Name != "friendly-checkpoint" { if checkpoints.Items[0].Name != friendlyCheckpointName {
t.Fatalf("unexpected checkpoint %s", checkpoints.Items[0].Name) t.Fatalf("unexpected checkpoint %s", checkpoints.Items[0].Name)
} }
} }
// TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefUsesReadyReferencedCR
// seeds the fake client with a DynamoCheckpoint that is already in the Ready
// phase, points a worker service at it via CheckpointRef, and checks that
// reconcileCheckpoints reports the referenced checkpoint as existing and ready,
// with the expected identity hash and checkpoint name in the service status.
func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_checkpointRefUsesReadyReferencedCR(t *testing.T) {
	if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil {
		t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
	}

	checkpointIdentity := v1alpha1.DynamoCheckpointIdentity{
		Model:            "meta-llama/Llama-2-7b-hf",
		BackendFramework: "vllm",
	}
	wantHash, err := checkpoint.ComputeIdentityHash(checkpointIdentity)
	if err != nil {
		t.Fatalf("Failed to compute checkpoint hash: %v", err)
	}

	// Pre-existing checkpoint CR that the service refers to by name.
	readyCheckpoint := &v1alpha1.DynamoCheckpoint{
		ObjectMeta: metav1.ObjectMeta{
			Name:      friendlyCheckpointName,
			Namespace: "default",
		},
		Spec: v1alpha1.DynamoCheckpointSpec{
			Identity: checkpointIdentity,
		},
		Status: v1alpha1.DynamoCheckpointStatus{
			Phase:        v1alpha1.DynamoCheckpointPhaseReady,
			IdentityHash: wantHash,
		},
	}

	fakeClient := fake.NewClientBuilder().
		WithScheme(scheme.Scheme).
		WithObjects(readyCheckpoint).
		WithStatusSubresource(readyCheckpoint).
		Build()
	reconciler := &DynamoGraphDeploymentReconciler{
		Client:   fakeClient,
		Config:   &configv1alpha1.OperatorConfiguration{},
		Recorder: record.NewFakeRecorder(10),
	}

	checkpointRef := friendlyCheckpointName
	graphDeployment := &v1alpha1.DynamoGraphDeployment{
		ObjectMeta: metav1.ObjectMeta{
			Name:      "test-dgd",
			Namespace: "default",
		},
		Spec: v1alpha1.DynamoGraphDeploymentSpec{
			Services: map[string]*v1alpha1.DynamoComponentDeploymentSharedSpec{
				"worker": {
					ComponentType: string(commonconsts.ComponentTypeWorker),
					Checkpoint: &v1alpha1.ServiceCheckpointConfig{
						Enabled:       true,
						Mode:          v1alpha1.CheckpointModeAuto,
						CheckpointRef: &checkpointRef,
					},
				},
			},
		},
	}

	statuses, infos, err := reconciler.reconcileCheckpoints(context.Background(), graphDeployment)
	if err != nil {
		t.Fatalf("reconcileCheckpoints() error = %v", err)
	}

	// Checks run in order; the first failing case aborts the test.
	workerInfo, found := infos["worker"]
	switch {
	case !found:
		t.Fatalf("expected checkpoint info for worker service")
	case !workerInfo.Ready:
		t.Fatalf("expected referenced checkpoint to be ready")
	case !workerInfo.Exists:
		t.Fatalf("expected referenced checkpoint to exist")
	case workerInfo.Hash != wantHash:
		t.Fatalf("checkpoint hash = %s, want %s", workerInfo.Hash, wantHash)
	}

	if gotName := statuses["worker"].CheckpointName; gotName != friendlyCheckpointName {
		t.Fatalf("checkpoint status name = %s, want friendly-checkpoint", gotName)
	}
	if !statuses["worker"].Ready {
		t.Fatalf("expected checkpoint status to be ready")
	}
}
func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_autoModeWaitsForExistingCreatingCheckpoint(t *testing.T) { func TestDynamoGraphDeploymentReconciler_reconcileCheckpoints_autoModeWaitsForExistingCreatingCheckpoint(t *testing.T) {
if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil { if err := v1alpha1.AddToScheme(scheme.Scheme); err != nil {
t.Fatalf("Failed to add v1alpha1 to scheme: %v", err) t.Fatalf("Failed to add v1alpha1 to scheme: %v", err)
......
...@@ -10,30 +10,24 @@ import ( ...@@ -10,30 +10,24 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
"time"
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
resourcev1 "k8s.io/api/resource/v1" resourcev1 "k8s.io/api/resource/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/utils/ptr"
"sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client"
) )
const ( const (
gmsSharedVolumeName = "gms-shared" defaultDeviceClassName = "gpu.nvidia.com"
gmsSharedMountPath = "/shared"
gmsDRAClaimName = "shared-gpu"
defaultDeviceClassName = "gpu.nvidia.com"
gmsProcessesPerGPU = 2
gmsStartupProbeTimeout = 2 * time.Minute
gmsStartupProbePeriodSec = 2
) )
func isGMSEnabled(component *v1alpha1.DynamoComponentDeploymentSharedSpec) bool { // IsGMSEnabled reports whether GPU Memory Service is requested for the component.
func IsGMSEnabled(component *v1alpha1.DynamoComponentDeploymentSharedSpec) bool {
return component.GPUMemoryService != nil && component.GPUMemoryService.Enabled return component.GPUMemoryService != nil && component.GPUMemoryService.Enabled
} }
...@@ -58,6 +52,9 @@ func getGPUCount(component *v1alpha1.DynamoComponentDeploymentSharedSpec) (int, ...@@ -58,6 +52,9 @@ func getGPUCount(component *v1alpha1.DynamoComponentDeploymentSharedSpec) (int,
if err != nil { if err != nil {
return 0, fmt.Errorf("invalid GPU count %q: %w", gpuStr, err) return 0, fmt.Errorf("invalid GPU count %q: %w", gpuStr, err)
} }
if count <= 0 {
return 0, fmt.Errorf("GPU count must be greater than 0 when GPU memory service is enabled")
}
return count, nil return count, nil
} }
...@@ -70,49 +67,49 @@ func getDeviceClassName(component *v1alpha1.DynamoComponentDeploymentSharedSpec) ...@@ -70,49 +67,49 @@ func getDeviceClassName(component *v1alpha1.DynamoComponentDeploymentSharedSpec)
return defaultDeviceClassName return defaultDeviceClassName
} }
// applyGPUMemoryService transforms a pod spec to include a GMS sidecar with // resolveMainContainer finds the container named "main" in the pod spec.
// DRA shared GPU access. The main container's GPU resources are replaced with // Falls back to Containers[0] when there is no container named "main"
// a DRA ResourceClaim, and a GMS init container is added. // (e.g. failover pods with engine-0/engine-1 naming).
// func resolveMainContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
// claimTemplateName is the name of the ResourceClaimTemplate that will provide if len(podSpec.Containers) == 0 {
// shared GPU access; callers should compute it via GMSResourceClaimTemplateName. return nil, fmt.Errorf("pod spec must have at least one container for GPU memory service")
func applyGPUMemoryService( }
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == commonconsts.MainContainerName {
return &podSpec.Containers[i], nil
}
}
return &podSpec.Containers[0], nil
}
// ApplyGPUMemoryService transforms a pod spec to include GMS server sidecars
// with DRA shared GPU access. The main container's GPU resources are replaced
// with a DRA ResourceClaim.
func ApplyGPUMemoryService(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
component *v1alpha1.DynamoComponentDeploymentSharedSpec, component *v1alpha1.DynamoComponentDeploymentSharedSpec,
claimTemplateName string, claimTemplateName string,
) error { ) error {
if len(podSpec.Containers) == 0 {
return fmt.Errorf("pod spec must have at least one container for GPU memory service")
}
gpuCount, err := getGPUCount(component) gpuCount, err := getGPUCount(component)
if err != nil { if err != nil {
return err return err
} }
_ = gpuCount // GPU count is used for DRA claim template; sidecar discovers devices via pynvml
mainContainer := &podSpec.Containers[0] mainContainer, err := resolveMainContainer(podSpec)
if err != nil {
return err
}
// Replace GPU resources with DRA claim on main container // Replace GPU resources with DRA claim on main container
removeGPUResources(mainContainer) removeGPUResources(mainContainer)
mainContainer.Resources.Claims = append(mainContainer.Resources.Claims, corev1.ResourceClaim{ mainContainer.Resources.Claims = append(mainContainer.Resources.Claims, corev1.ResourceClaim{
Name: gmsDRAClaimName, Name: gmsruntime.DRAClaimName,
})
// Add shared volume mount and TMPDIR to main container
mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, corev1.VolumeMount{
Name: gmsSharedVolumeName,
MountPath: gmsSharedMountPath,
})
mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{
Name: "TMPDIR", Value: gmsSharedMountPath,
}) })
// Add GMS sidecar // Add GMS server sidecar, shared volume, and socket env vars.
gmsSidecar := buildGMSSidecar(mainContainer.Image, gpuCount) // The sidecar gets DRA claims copied from main automatically.
podSpec.InitContainers = append(podSpec.InitContainers, gmsSidecar) gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
// Add shared volume
podSpec.Volumes = append(podSpec.Volumes, gmsSharedVolume())
// GPU nodes are typically tainted with nvidia.com/gpu=NoSchedule. With // GPU nodes are typically tainted with nvidia.com/gpu=NoSchedule. With
// traditional scheduling the device-plugin injects the matching toleration, // traditional scheduling the device-plugin injects the matching toleration,
...@@ -126,7 +123,7 @@ func applyGPUMemoryService( ...@@ -126,7 +123,7 @@ func applyGPUMemoryService(
// Add pod-level DRA resource claim referencing the ResourceClaimTemplate // Add pod-level DRA resource claim referencing the ResourceClaimTemplate
podSpec.ResourceClaims = append(podSpec.ResourceClaims, corev1.PodResourceClaim{ podSpec.ResourceClaims = append(podSpec.ResourceClaims, corev1.PodResourceClaim{
Name: gmsDRAClaimName, Name: gmsruntime.DRAClaimName,
ResourceClaimTemplateName: &claimTemplateName, ResourceClaimTemplateName: &claimTemplateName,
}) })
...@@ -145,85 +142,6 @@ func removeGPUResources(container *corev1.Container) { ...@@ -145,85 +142,6 @@ func removeGPUResources(container *corev1.Container) {
} }
} }
// buildGMSSidecar assembles the "gms-weights" container that runs the GMS
// server processes next to the main workload.
//
// The container is given restartPolicy Always and is meant to be placed in the
// pod's init containers, i.e. it is a sidecar-style init container. It reuses
// the supplied image, runs the wrapper script produced by gmsWrapperScript
// under "bash -c", and mounts the shared volume with TMPDIR pointing at it so
// the GMS sockets are created where the startup probe looks for them.
//
// image is the container image to run; gpuCount is the number of GPU devices
// the wrapper script and the startup probe are sized for.
func buildGMSSidecar(image string, gpuCount int) corev1.Container {
	sharedMount := corev1.VolumeMount{
		Name:      gmsSharedVolumeName,
		MountPath: gmsSharedMountPath,
	}

	// The probe is retried every gmsStartupProbePeriodSec seconds; the failure
	// threshold is chosen so the total budget equals gmsStartupProbeTimeout.
	startup := &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			Exec: &corev1.ExecAction{Command: gmsReadyCheckCommand(gpuCount)},
		},
		PeriodSeconds:    int32(gmsStartupProbePeriodSec),
		FailureThreshold: int32(gmsStartupProbeTimeout/time.Second) / int32(gmsStartupProbePeriodSec),
	}

	sidecar := corev1.Container{
		Name:         "gms-weights",
		Image:        image,
		Command:      []string{"bash", "-c"},
		Args:         []string{gmsWrapperScript(gpuCount)},
		Env:          []corev1.EnvVar{{Name: "TMPDIR", Value: gmsSharedMountPath}},
		VolumeMounts: []corev1.VolumeMount{sharedMount},
		StartupProbe: startup,
	}
	sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
	// The sidecar reaches the GPUs through the same DRA claim name as the main container.
	sidecar.Resources.Claims = []corev1.ResourceClaim{{Name: gmsDRAClaimName}}
	return sidecar
}
// gmsWrapperScript renders the bash script executed by the GMS sidecar.
//
// For every device index in [0, gpuCount) the script starts two background
// gpu_memory_service processes, one tagged "weights" and one tagged
// "kv_cache". It then blocks in "wait -n" until any one of them exits, prints
// a shutdown message, and the EXIT trap runs "kill 0" to signal the remaining
// children.
//
// gpuCount is the number of device indices to launch servers for; a count of
// zero yields a script with an empty device list.
func gmsWrapperScript(gpuCount int) string {
	const scriptTemplate = `trap 'kill 0 2>/dev/null || true' EXIT
for dev in %s; do
python3 -m gpu_memory_service --device "$dev" --tag weights &
echo "Started GMS device=$dev tag=weights pid=$!"
python3 -m gpu_memory_service --device "$dev" --tag kv_cache &
echo "Started GMS device=$dev tag=kv_cache pid=$!"
done
wait -n
echo "A GMS subprocess exited, shutting down"`

	// Device indices are rendered as a space-separated list for the bash loop.
	indices := make([]string, 0, gpuCount)
	for dev := 0; dev < gpuCount; dev++ {
		indices = append(indices, strconv.Itoa(dev))
	}
	return fmt.Sprintf(scriptTemplate, strings.Join(indices, " "))
}
// gmsReadyCheckCommand returns the exec probe command that verifies the
// expected number of GMS UDS sockets exist on the shared volume.
// With 2-tag GMS (weights + kv_cache), there are 2 sockets per GPU.
func gmsReadyCheckCommand(gpuCount int) []string {
return []string{
"sh", "-c",
fmt.Sprintf("test $(ls %s/gms_*.sock 2>/dev/null | wc -l) -ge %d", gmsSharedMountPath, gpuCount*gmsProcessesPerGPU),
}
}
// gmsSharedVolume returns the pod volume named gmsSharedVolumeName, backed by
// an emptyDir with default settings. The GMS sidecar mounts it at
// gmsSharedMountPath.
func gmsSharedVolume() corev1.Volume {
	source := corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}}
	return corev1.Volume{Name: gmsSharedVolumeName, VolumeSource: source}
}
// GMSResourceClaimTemplateName returns the deterministic name for the // GMSResourceClaimTemplateName returns the deterministic name for the
// ResourceClaimTemplate associated with a GMS-enabled component. // ResourceClaimTemplate associated with a GMS-enabled component.
func GMSResourceClaimTemplateName(parentName, serviceName string) string { func GMSResourceClaimTemplateName(parentName, serviceName string) string {
...@@ -254,7 +172,7 @@ func GenerateGMSResourceClaimTemplate( ...@@ -254,7 +172,7 @@ func GenerateGMSResourceClaimTemplate(
}, },
} }
if !isGMSEnabled(component) { if !IsGMSEnabled(component) {
return template, true, nil return template, true, nil
} }
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -63,14 +64,14 @@ func gmsBasePodSpec() corev1.PodSpec { ...@@ -63,14 +64,14 @@ func gmsBasePodSpec() corev1.PodSpec {
func TestApplyGPUMemoryService_EmptyContainers(t *testing.T) { func TestApplyGPUMemoryService_EmptyContainers(t *testing.T) {
ps := corev1.PodSpec{} ps := corev1.PodSpec{}
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.Error(t, err) require.Error(t, err)
assert.Contains(t, err.Error(), "at least one container") assert.Contains(t, err.Error(), "at least one container")
} }
func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) { func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
main := ps.Containers[0] main := ps.Containers[0]
...@@ -82,53 +83,56 @@ func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) { ...@@ -82,53 +83,56 @@ func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) {
// Should have DRA claim // Should have DRA claim
require.Len(t, main.Resources.Claims, 1) require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsDRAClaimName, main.Resources.Claims[0].Name) assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
// Should have shared volume mount // Should have shared volume mount
var hasSharedMount bool var hasSharedMount bool
for _, vm := range main.VolumeMounts { for _, vm := range main.VolumeMounts {
if vm.Name == gmsSharedVolumeName && vm.MountPath == gmsSharedMountPath { if vm.Name == gmsruntime.SharedVolumeName && vm.MountPath == gmsruntime.SharedMountPath {
hasSharedMount = true hasSharedMount = true
} }
} }
assert.True(t, hasSharedMount, "main container should have gms-shared volume mount") assert.True(t, hasSharedMount, "main container should have gms-shared volume mount")
// Should have TMPDIR // Should have TMPDIR and GMS_SOCKET_DIR
envMap := envToMap(main.Env) envMap := envToMap(main.Env)
assert.Equal(t, gmsSharedMountPath, envMap["TMPDIR"]) assert.Equal(t, gmsruntime.SharedMountPath, envMap["TMPDIR"])
assert.Equal(t, gmsruntime.SharedMountPath, envMap["GMS_SOCKET_DIR"])
} }
func TestApplyGPUMemoryService_GMSSidecarInjected(t *testing.T) { func TestApplyGPUMemoryService_GMSSidecarInjected(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ps.InitContainers, 1) require.Len(t, ps.InitContainers, 1)
gms := ps.InitContainers[0] gms := ps.InitContainers[0]
assert.Equal(t, "gms-weights", gms.Name) assert.Equal(t, gmsruntime.ServerContainerName, gms.Name)
assert.Equal(t, "test-image:latest", gms.Image) assert.Equal(t, "test-image:latest", gms.Image)
assert.Equal(t, []string{"bash", "-c"}, gms.Command) assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gms.Command)
assert.Contains(t, gms.Args[0], "gpu_memory_service --device")
assert.NotNil(t, gms.RestartPolicy) assert.NotNil(t, gms.RestartPolicy)
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gms.RestartPolicy) assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gms.RestartPolicy)
require.NotNil(t, gms.StartupProbe)
assert.Equal(t, int32(1), gms.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), gms.StartupProbe.FailureThreshold)
// GMS sidecar should have DRA claim // GMS sidecar should have DRA claim copied from main
require.Len(t, gms.Resources.Claims, 1) require.Len(t, gms.Resources.Claims, 1)
assert.Equal(t, gmsDRAClaimName, gms.Resources.Claims[0].Name) assert.Equal(t, gmsruntime.DRAClaimName, gms.Resources.Claims[0].Name)
// GMS sidecar should have TMPDIR // GMS sidecar should have TMPDIR
gmsEnv := envToMap(gms.Env) gmsEnv := envToMap(gms.Env)
assert.Equal(t, gmsSharedMountPath, gmsEnv["TMPDIR"]) assert.Equal(t, gmsruntime.SharedMountPath, gmsEnv["TMPDIR"])
} }
func TestApplyGPUMemoryService_SharedVolume(t *testing.T) { func TestApplyGPUMemoryService_SharedVolume(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
var found bool var found bool
for _, v := range ps.Volumes { for _, v := range ps.Volumes {
if v.Name == gmsSharedVolumeName { if v.Name == gmsruntime.SharedVolumeName {
assert.NotNil(t, v.EmptyDir) assert.NotNil(t, v.EmptyDir)
found = true found = true
} }
...@@ -138,7 +142,7 @@ func TestApplyGPUMemoryService_SharedVolume(t *testing.T) { ...@@ -138,7 +142,7 @@ func TestApplyGPUMemoryService_SharedVolume(t *testing.T) {
func TestApplyGPUMemoryService_GPUToleration(t *testing.T) { func TestApplyGPUMemoryService_GPUToleration(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
var found bool var found bool
...@@ -153,17 +157,17 @@ func TestApplyGPUMemoryService_GPUToleration(t *testing.T) { ...@@ -153,17 +157,17 @@ func TestApplyGPUMemoryService_GPUToleration(t *testing.T) {
func TestApplyGPUMemoryService_DRAResourceClaim(t *testing.T) { func TestApplyGPUMemoryService_DRAResourceClaim(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
require.Len(t, ps.ResourceClaims, 1) require.Len(t, ps.ResourceClaims, 1)
assert.Equal(t, gmsDRAClaimName, ps.ResourceClaims[0].Name) assert.Equal(t, gmsruntime.DRAClaimName, ps.ResourceClaims[0].Name)
assert.Equal(t, "myapp-worker-gpu", *ps.ResourceClaims[0].ResourceClaimTemplateName) assert.Equal(t, "myapp-worker-gpu", *ps.ResourceClaims[0].ResourceClaimTemplateName)
} }
func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) { func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
main := ps.Containers[0] main := ps.Containers[0]
...@@ -174,38 +178,32 @@ func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) { ...@@ -174,38 +178,32 @@ func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) {
func TestApplyGPUMemoryService_SingleContainer(t *testing.T) { func TestApplyGPUMemoryService_SingleContainer(t *testing.T) {
ps := gmsBasePodSpec() ps := gmsBasePodSpec()
err := applyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu") err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err) require.NoError(t, err)
// Should still have exactly 1 regular container (no duplication)
assert.Len(t, ps.Containers, 1) assert.Len(t, ps.Containers, 1)
assert.Equal(t, "main", ps.Containers[0].Name) assert.Equal(t, "main", ps.Containers[0].Name)
} }
// --- GMS sidecar helpers --- func TestApplyGPUMemoryService_ResolvesMainByName(t *testing.T) {
ps := gmsBasePodSpec()
// Prepend a sidecar so main is NOT Containers[0]
sidecar := corev1.Container{Name: "sidecar", Image: "sidecar:latest"}
ps.Containers = append([]corev1.Container{sidecar}, ps.Containers...)
require.Equal(t, "sidecar", ps.Containers[0].Name)
func TestGmsWrapperScript_TwoTagsPerDevice(t *testing.T) { err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
script := gmsWrapperScript(3) require.NoError(t, err)
assert.Contains(t, script, "for dev in 0 1 2")
assert.Contains(t, script, "--tag weights")
assert.Contains(t, script, "--tag kv_cache")
assert.Contains(t, script, "trap 'kill 0")
assert.Contains(t, script, "wait -n")
}
func TestGmsReadyCheckCommand_TwoSocketsPerGPU(t *testing.T) { // Sidecar should be untouched
cmd := gmsReadyCheckCommand(2) assert.Equal(t, "sidecar", ps.Containers[0].Name)
assert.Equal(t, "sh", cmd[0]) assert.Empty(t, ps.Containers[0].Resources.Claims)
assert.Equal(t, "-c", cmd[1])
assert.Contains(t, cmd[2], "gms_*.sock")
// 2 GPUs * 2 tags = 4 sockets
assert.Contains(t, cmd[2], "-ge 4")
}
func TestGmsReadyCheckCommand_SingleGPU(t *testing.T) { // Main should have DRA claim
cmd := gmsReadyCheckCommand(1) main := ps.Containers[1]
// 1 GPU * 2 tags = 2 sockets assert.Equal(t, "main", main.Name)
assert.Contains(t, cmd[2], "-ge 2") require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
} }
// --- GenerateGMSResourceClaimTemplate --- // --- GenerateGMSResourceClaimTemplate ---
...@@ -268,13 +266,13 @@ func TestGMSResourceClaimTemplateName(t *testing.T) { ...@@ -268,13 +266,13 @@ func TestGMSResourceClaimTemplateName(t *testing.T) {
// --- isGMSEnabled --- // --- isGMSEnabled ---
func TestIsGMSEnabled(t *testing.T) { func TestIsGMSEnabled(t *testing.T) {
assert.True(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{ assert.True(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true}, GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
})) }))
assert.False(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{ assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: false}, GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: false},
})) }))
assert.False(t, isGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{})) assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
} }
// --- getGPUCount --- // --- getGPUCount ---
......
...@@ -1183,9 +1183,9 @@ func GenerateBasePodSpec( ...@@ -1183,9 +1183,9 @@ func GenerateBasePodSpec(
} }
// Inject GMS sidecar with DRA shared GPU access when GPU memory service is enabled. // Inject GMS sidecar with DRA shared GPU access when GPU memory service is enabled.
if isGMSEnabled(component) { if IsGMSEnabled(component) {
claimTemplateName := GMSResourceClaimTemplateName(parentGraphDeploymentName, serviceName) claimTemplateName := GMSResourceClaimTemplateName(parentGraphDeploymentName, serviceName)
if err := applyGPUMemoryService(&podSpec, component, claimTemplateName); err != nil { if err := ApplyGPUMemoryService(&podSpec, component, claimTemplateName); err != nil {
return nil, fmt.Errorf("failed to apply GPU memory service: %w", err) return nil, fmt.Errorf("failed to apply GPU memory service: %w", err)
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment