"examples/multimodal/vscode:/vscode.git/clone" did not exist on "93208162753986f9449d3671d6a263dfc4f4381e"
Unverified Commit b2f7f220 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix(gms): rewrite GMS checkpoint/restore operator support (#8194)


Co-authored-by: default avatarDmitry Tokarev <dtokarev@nvidia.com>
parent 2d86b81d
...@@ -23,7 +23,7 @@ import ( ...@@ -23,7 +23,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms" gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -183,50 +183,6 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) { ...@@ -183,50 +183,6 @@ func TestCreateOrGetAutoCheckpointSetsDefaultArtifactVersion(t *testing.T) {
// --- InjectCheckpointIntoPodSpec tests --- // --- InjectCheckpointIntoPodSpec tests ---
func TestEnsurePodInfoVolumeMergesExistingDownwardAPIItems(t *testing.T) {
podSpec := &corev1.PodSpec{
Volumes: []corev1.Volume{{
Name: consts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
{
Path: "pod_name",
FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"},
},
{
Path: "custom",
FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.labels['custom']"},
},
},
},
},
}},
}
EnsurePodInfoVolume(podSpec)
require.Len(t, podSpec.Volumes, 1)
require.NotNil(t, podSpec.Volumes[0].DownwardAPI)
fields := map[string]string{}
for _, item := range podSpec.Volumes[0].DownwardAPI.Items {
if item.FieldRef != nil {
fields[item.Path] = item.FieldRef.FieldPath
}
}
assert.Equal(t, consts.PodInfoFieldPodName, fields["pod_name"])
assert.Equal(t, consts.PodInfoFieldPodUID, fields["pod_uid"])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields["pod_namespace"])
assert.Equal(t, "metadata.labels['custom']", fields["custom"])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])
}
func TestInjectCheckpointIntoPodSpec(t *testing.T) { func TestInjectCheckpointIntoPodSpec(t *testing.T) {
t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) { t.Run("ready checkpoint injects podinfo and overrides command", func(t *testing.T) {
podSpec := testPodSpec() podSpec := testPodSpec()
...@@ -263,65 +219,21 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -263,65 +219,21 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName]) assert.Equal(t, consts.PodInfoMountPath, mountPaths[consts.PodInfoVolumeName])
}) })
t.Run("ready checkpoint augments existing podinfo volume", func(t *testing.T) {
podSpec := testPodSpec()
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: consts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: []corev1.DownwardAPIVolumeFile{
{Path: "pod_name", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodName}},
{Path: "pod_uid", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodUID}},
{Path: "pod_namespace", FieldRef: &corev1.ObjectFieldSelector{FieldPath: consts.PodInfoFieldPodNamespace}},
},
},
},
})
info := &CheckpointInfo{Enabled: true, Ready: true, Identity: ptr.To(testIdentity())}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
var podInfoVolume *corev1.Volume
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name == consts.PodInfoVolumeName {
podInfoVolume = &podSpec.Volumes[i]
break
}
}
require.NotNil(t, podInfoVolume)
require.NotNil(t, podInfoVolume.DownwardAPI)
fields := map[string]string{}
for _, item := range podInfoVolume.DownwardAPI.Items {
if item.FieldRef != nil {
fields[item.Path] = item.FieldRef.FieldPath
}
}
assert.Equal(t, consts.PodInfoFieldPodName, fields["pod_name"])
assert.Equal(t, consts.PodInfoFieldPodUID, fields["pod_uid"])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields["pod_namespace"])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoNamespace+"']", fields[consts.PodInfoFileDynNamespace])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoWorkerHash+"']", fields[consts.PodInfoFileDynNamespaceWorkerSuffix])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoComponentType+"']", fields[consts.PodInfoFileDynComponent])
assert.Equal(t, "metadata.labels['"+consts.KubeLabelDynamoGraphDeploymentName+"']", fields[consts.PodInfoFileDynParentDGDName])
assert.Equal(t, consts.PodInfoFieldPodNamespace, fields[consts.PodInfoFileDynParentDGDNamespace])
})
t.Run("ready checkpoint targets the container named main", func(t *testing.T) { t.Run("ready checkpoint targets the container named main", func(t *testing.T) {
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
Containers: []corev1.Container{ Containers: []corev1.Container{
{Name: "main", Image: "main:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}}, {Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
{Name: consts.MainContainerName, Image: "main:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
}, },
} }
info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash} info := &CheckpointInfo{Enabled: true, Ready: true, Hash: testHash}
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build() reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info)) require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
assert.Equal(t, []string{"sidecar"}, podSpec.Containers[0].Command) assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[0].Command)
assert.Equal(t, []string{"run"}, podSpec.Containers[0].Args) assert.Nil(t, podSpec.Containers[0].Args)
assert.Equal(t, []string{"sleep", "infinity"}, podSpec.Containers[1].Command) assert.Equal(t, []string{"sidecar"}, podSpec.Containers[1].Command)
assert.Nil(t, podSpec.Containers[1].Args) assert.Equal(t, []string{"run"}, podSpec.Containers[1].Args)
}) })
t.Run("ready gms checkpoint injects restore sidecars and loader mount", func(t *testing.T) { t.Run("ready gms checkpoint injects restore sidecars and loader mount", func(t *testing.T) {
...@@ -331,22 +243,24 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -331,22 +243,24 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build() reader := fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build()
require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info)) require.NoError(t, InjectCheckpointIntoPodSpec(context.Background(), reader, testNamespace, podSpec, info))
gmsServer := findContainer(podSpec, gmsruntime.ServerContainerName) gmsServer := findContainer(podSpec, gms.ServerContainerName)
require.NotNil(t, gmsServer) require.NotNil(t, gmsServer)
loader := findContainer(podSpec, GMSLoaderContainer) loader := findContainer(podSpec, GMSLoaderContainer)
require.NotNil(t, loader) require.NotNil(t, loader)
// Restore: gms-server should be a regular container, not an init container // Restore: server and loader are init sidecars (restartPolicy=Always)
assert.Empty(t, podSpec.InitContainers, "restore pods should not have gms-server as init container") assert.NotNil(t, gmsServer.RestartPolicy, "restore gms-server should have RestartPolicy")
assert.Nil(t, gmsServer.RestartPolicy, "restore gms-server should not have RestartPolicy") assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gmsServer.RestartPolicy)
assert.Nil(t, gmsServer.StartupProbe, "restore gms-server should not have StartupProbe") assert.Nil(t, gmsServer.StartupProbe, "restore gms-server should not have StartupProbe")
assert.NotNil(t, loader.RestartPolicy, "restore gms-loader should have RestartPolicy")
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *loader.RestartPolicy)
mounts := map[string]string{} mounts := map[string]string{}
for _, mount := range loader.VolumeMounts { for _, mount := range loader.VolumeMounts {
mounts[mount.Name] = mount.MountPath mounts[mount.Name] = mount.MountPath
} }
assert.Equal(t, "/checkpoints", mounts[snapshotprotocol.CheckpointVolumeName]) assert.Equal(t, "/checkpoints", mounts[snapshotprotocol.CheckpointVolumeName])
assert.Equal(t, gmsruntime.SharedMountPath, mounts[gmsruntime.SharedVolumeName]) assert.Equal(t, gms.SharedMountPath, mounts[gms.SharedVolumeName])
env := map[string]string{} env := map[string]string{}
for _, item := range loader.Env { for _, item := range loader.Env {
...@@ -366,8 +280,7 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) { ...@@ -366,8 +280,7 @@ func TestInjectCheckpointIntoPodSpec(t *testing.T) {
errMsg string errMsg string
}{ }{
{"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"}, {"hash empty and identity nil", testPodSpec(), &CheckpointInfo{Enabled: true}, fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "identity is nil"},
{"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container found"}, {"no containers", &corev1.PodSpec{}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "no container named"},
{"main container missing", &corev1.PodSpec{Containers: []corev1.Container{{Name: "sidecar", Image: "img", Command: []string{"python3"}}}}, testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).WithObjects(testSnapshotAgentDaemonSet()).Build(), "main container not found"},
{"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"}, {"snapshot daemonset missing", testPodSpec(), testInfo(), fake.NewClientBuilder().WithScheme(testScheme()).Build(), "no snapshot-agent daemonset found"},
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
......
...@@ -6,15 +6,13 @@ ...@@ -6,15 +6,13 @@
package checkpoint package checkpoint
import ( import (
"context"
"fmt" "fmt"
"path/filepath" "path/filepath"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms" gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" "k8s.io/utils/ptr"
) )
const ( const (
...@@ -23,101 +21,80 @@ const ( ...@@ -23,101 +21,80 @@ const (
gmsCheckpointLoaderModule = "gpu_memory_service.cli.snapshot.loader" gmsCheckpointLoaderModule = "gpu_memory_service.cli.snapshot.loader"
gmsCheckpointSaverModule = "gpu_memory_service.cli.snapshot.saver" gmsCheckpointSaverModule = "gpu_memory_service.cli.snapshot.saver"
)
func ResolveGMSCheckpointStorage(
ctx context.Context,
reader ctrlclient.Reader,
namespace string,
checkpointID string,
artifactVersion string,
) (snapshotprotocol.Storage, error) {
if reader == nil {
return snapshotprotocol.Storage{}, fmt.Errorf("checkpoint client is required")
}
daemonSets := &appsv1.DaemonSetList{}
if err := reader.List(
ctx,
daemonSets,
ctrlclient.InNamespace(namespace),
ctrlclient.MatchingLabels{snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue},
); err != nil {
return snapshotprotocol.Storage{}, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
}
storage, err := snapshotprotocol.DiscoverStorageFromDaemonSets(namespace, daemonSets.Items) // envCheckpointDir is the environment variable name for the GMS
if err != nil { // checkpoint artifact directory on the snapshot PVC.
return snapshotprotocol.Storage{}, err envCheckpointDir = "GMS_CHECKPOINT_DIR"
} )
return snapshotprotocol.ResolveCheckpointStorage(checkpointID, artifactVersion, storage)
}
// BuildGMSRestoreSidecars prepares GMS infrastructure for a restore pod and // EnsureGMSRestoreSidecars adds GMS server + loader containers to the pod spec
// returns the additional containers the caller must append to podSpec.Containers. // for a checkpoint restore. The server runs as a regular container (not init)
// // because the CRIU-restored main process already has GPU memory mapped and
// The GMS server runs as a regular container (not init) because the CRIU-restored // all containers must start in parallel.
// main process already has GPU memory mapped and does not need sockets at func EnsureGMSRestoreSidecars(
// startup. The gms-loader polls for sockets internally via wait_for_weights_socket.
func BuildGMSRestoreSidecars(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
mainContainer *corev1.Container, mainContainer *corev1.Container,
storage snapshotprotocol.Storage, storage snapshotprotocol.Storage,
) []corev1.Container { ) {
if podSpec == nil || mainContainer == nil { if podSpec == nil || mainContainer == nil {
return nil return
} }
// Remove gms-server from initContainers if the DGD-level // The DGD path adds the GMS server as an init sidecar (blocks until
// applyGPUMemoryService already placed it there. For restore pods the // sockets are ready). For restore, move it to a regular container so
// server runs as a regular container so that all containers start in // all containers start in parallel.
// parallel — the restored main process does not need sockets at startup.
for i := range podSpec.InitContainers { for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == gmsruntime.ServerContainerName { if podSpec.InitContainers[i].Name == gms.ServerContainerName {
podSpec.InitContainers = append(podSpec.InitContainers[:i], podSpec.InitContainers[i+1:]...) podSpec.InitContainers = append(podSpec.InitContainers[:i], podSpec.InitContainers[i+1:]...)
break break
} }
} }
gms.EnsureSharedVolume(podSpec, mainContainer)
snapshotprotocol.InjectCheckpointVolume(podSpec, storage.PVCName)
server := gmsruntime.BuildServerContainer(podSpec, mainContainer) server := gms.Container(gms.ServerContainerName, gms.ServerModule, mainContainer.Image)
server.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
loader := gmsCheckpointLoaderContainer(mainContainer.Image) loader := gms.Container(GMSLoaderContainer, gmsCheckpointLoaderModule, mainContainer.Image)
copyGMSDeviceClaims(mainContainer, &loader)
ensureCheckpointVolume(podSpec, storage.PVCName)
loader.VolumeMounts = append(loader.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath}) loader.VolumeMounts = append(loader.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath})
loader.Env = append(loader.Env, corev1.EnvVar{Name: "GMS_CHECKPOINT_DIR", Value: resolveGMSArtifactDir(storage)}) loader.Env = append(loader.Env, corev1.EnvVar{Name: envCheckpointDir, Value: resolveGMSArtifactDir(storage)})
loader.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
return []corev1.Container{server, loader} podSpec.InitContainers = append(podSpec.InitContainers, server, loader)
} }
// BuildGMSCheckpointJobSidecars prepares GMS infrastructure for a checkpoint // EnsureGMSCheckpointJobSidecars adds GMS server (init) + saver containers
// job and returns the additional containers the caller must append to // to the pod spec for a checkpoint job.
// podSpec.Containers. func EnsureGMSCheckpointJobSidecars(
func BuildGMSCheckpointJobSidecars(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
mainContainer *corev1.Container, mainContainer *corev1.Container,
storage snapshotprotocol.Storage, storage snapshotprotocol.Storage,
) ([]corev1.Container, error) { ) error {
if podSpec == nil || mainContainer == nil { if podSpec == nil || mainContainer == nil {
return nil, nil return nil
} }
if len(mainContainer.Resources.Claims) == 0 { if len(mainContainer.Resources.Claims) == 0 {
return nil, fmt.Errorf("gms sidecars require main container resource claims") return fmt.Errorf("gms sidecars require main container resource claims (DRA must be enabled)")
} }
if storage.PVCName == "" || storage.BasePath == "" || storage.Location == "" { if storage.PVCName == "" || storage.BasePath == "" || storage.Location == "" {
return nil, fmt.Errorf("gms checkpoint jobs require resolved checkpoint storage") return fmt.Errorf("gms checkpoint jobs require resolved checkpoint storage")
} }
gmsruntime.EnsureServerSidecar(podSpec, mainContainer) gmsArtifactDir := resolveGMSArtifactDir(storage)
ensureGMSCheckpointControl(podSpec)
saver := gmsCheckpointSaverContainer(mainContainer.Image) gms.EnsureServerSidecar(podSpec, mainContainer)
copyGMSDeviceClaims(mainContainer, &saver) snapshotprotocol.InjectCheckpointVolume(podSpec, storage.PVCName)
ensureCheckpointVolume(podSpec, storage.PVCName)
saver.VolumeMounts = append(saver.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath})
saver.Env = append(saver.Env, corev1.EnvVar{Name: "GMS_CHECKPOINT_DIR", Value: resolveGMSArtifactDir(storage)})
return []corev1.Container{saver}, nil saver := gms.Container(GMSSaverContainer, gmsCheckpointSaverModule, mainContainer.Image)
saver.VolumeMounts = append(saver.VolumeMounts, corev1.VolumeMount{Name: snapshotprotocol.CheckpointVolumeName, MountPath: storage.BasePath})
saver.Env = append(saver.Env, corev1.EnvVar{Name: envCheckpointDir, Value: gmsArtifactDir})
// The saver is an init sidecar (restartPolicy=Always) so it doesn't
// affect pod Ready (only the worker's probe matters) and doesn't block
// Job completion. It saves, then sleeps until the pod terminates.
saver.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
podSpec.InitContainers = append(podSpec.InitContainers, saver)
return nil
} }
func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string { func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string {
...@@ -128,77 +105,3 @@ func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string { ...@@ -128,77 +105,3 @@ func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string {
checkpointID := filepath.Base(filepath.Dir(filepath.Dir(storage.Location))) checkpointID := filepath.Base(filepath.Dir(filepath.Dir(storage.Location)))
return filepath.Join(storage.BasePath, "gms", checkpointID, "versions", artifactVersion) return filepath.Join(storage.BasePath, "gms", checkpointID, "versions", artifactVersion)
} }
func gmsCheckpointLoaderContainer(image string) corev1.Container {
container := corev1.Container{
Name: GMSLoaderContainer,
Image: image,
Command: []string{"python3", "-m", gmsCheckpointLoaderModule},
Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: gmsruntime.SharedMountPath},
},
VolumeMounts: []corev1.VolumeMount{
{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath},
},
}
return container
}
func gmsCheckpointSaverContainer(image string) corev1.Container {
container := corev1.Container{
Name: GMSSaverContainer,
Image: image,
Command: []string{"python3", "-m", gmsCheckpointSaverModule},
Env: []corev1.EnvVar{
{Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}},
{Name: "POD_NAMESPACE", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.namespace"}}},
{Name: "TMPDIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: gmsruntime.SharedMountPath},
{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir},
},
VolumeMounts: []corev1.VolumeMount{
{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath},
{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir},
},
}
return container
}
// ensureGMSCheckpointControl adds the control volume and injects
// GMS_CONTROL_DIR into the GMS server container for checkpoint coordination.
func ensureGMSCheckpointControl(podSpec *corev1.PodSpec) {
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: gmsruntime.ControlVolumeName,
VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
})
server := gmsruntime.FindServerContainer(podSpec)
if server != nil {
server.VolumeMounts = append(server.VolumeMounts, corev1.VolumeMount{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir})
server.Env = append(server.Env, corev1.EnvVar{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir})
}
}
func copyGMSDeviceClaims(mainContainer *corev1.Container, container *corev1.Container) {
if mainContainer == nil || container == nil || len(mainContainer.Resources.Claims) == 0 {
return
}
container.Resources.Claims = append([]corev1.ResourceClaim{}, mainContainer.Resources.Claims...)
}
func ensureCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
if pvcName == "" {
return
}
for i := range podSpec.Volumes {
if podSpec.Volumes[i].Name == snapshotprotocol.CheckpointVolumeName {
return
}
}
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: snapshotprotocol.CheckpointVolumeName,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{ClaimName: pvcName},
},
})
}
...@@ -9,42 +9,17 @@ import ( ...@@ -9,42 +9,17 @@ import (
) )
func EnsurePodInfoVolume(podSpec *corev1.PodSpec) { func EnsurePodInfoVolume(podSpec *corev1.PodSpec) {
for i := range podSpec.Volumes { for _, volume := range podSpec.Volumes {
if podSpec.Volumes[i].Name != commonconsts.PodInfoVolumeName { if volume.Name == commonconsts.PodInfoVolumeName {
continue
}
if podSpec.Volumes[i].DownwardAPI == nil {
podSpec.Volumes[i].VolumeSource.DownwardAPI = &corev1.DownwardAPIVolumeSource{}
}
// Merge required items into existing downwardAPI volume.
source := podSpec.Volumes[i].DownwardAPI
pathToIndex := make(map[string]int, len(source.Items))
for j := range source.Items {
pathToIndex[source.Items[j].Path] = j
}
for _, item := range podInfoItems() {
if idx, ok := pathToIndex[item.Path]; ok {
source.Items[idx] = item
continue
}
source.Items = append(source.Items, item)
pathToIndex[item.Path] = len(source.Items) - 1
}
return return
} }
}
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{ podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: commonconsts.PodInfoVolumeName, Name: commonconsts.PodInfoVolumeName,
VolumeSource: corev1.VolumeSource{ VolumeSource: corev1.VolumeSource{
DownwardAPI: &corev1.DownwardAPIVolumeSource{ DownwardAPI: &corev1.DownwardAPIVolumeSource{
Items: podInfoItems(), Items: []corev1.DownwardAPIVolumeFile{
},
},
})
}
func podInfoItems() []corev1.DownwardAPIVolumeFile {
return []corev1.DownwardAPIVolumeFile{
{ {
Path: "pod_name", Path: "pod_name",
FieldRef: &corev1.ObjectFieldSelector{ FieldRef: &corev1.ObjectFieldSelector{
...@@ -93,7 +68,10 @@ func podInfoItems() []corev1.DownwardAPIVolumeFile { ...@@ -93,7 +68,10 @@ func podInfoItems() []corev1.DownwardAPIVolumeFile {
FieldPath: commonconsts.PodInfoFieldPodNamespace, FieldPath: commonconsts.PodInfoFieldPodNamespace,
}, },
}, },
} },
},
},
})
} }
func EnsurePodInfoMount(container *corev1.Container) { func EnsurePodInfoMount(container *corev1.Container) {
......
...@@ -38,6 +38,19 @@ func ApplyRestorePodMetadata(labels map[string]string, annotations map[string]st ...@@ -38,6 +38,19 @@ func ApplyRestorePodMetadata(labels map[string]string, annotations map[string]st
snapshotprotocol.ApplyRestoreTargetMetadata(labels, annotations, enabled, hash, artifactVersion) snapshotprotocol.ApplyRestoreTargetMetadata(labels, annotations, enabled, hash, artifactVersion)
} }
// resolveMainContainer finds the container named "main" in the pod spec.
// ExtraPodSpec.PodSpec.Containers can inject user containers before the main
// container (mergo merge happens before main is appended), so index 0 is
// not guaranteed to be the main container here.
func resolveMainContainer(podSpec *corev1.PodSpec) *corev1.Container {
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == commonconsts.MainContainerName {
return &podSpec.Containers[i]
}
}
return nil
}
func InjectCheckpointIntoPodSpec( func InjectCheckpointIntoPodSpec(
ctx context.Context, ctx context.Context,
reader ctrlclient.Reader, reader ctrlclient.Reader,
...@@ -62,18 +75,9 @@ func InjectCheckpointIntoPodSpec( ...@@ -62,18 +75,9 @@ func InjectCheckpointIntoPodSpec(
info.Hash = hash info.Hash = hash
} }
if len(podSpec.Containers) == 0 { mainContainer := resolveMainContainer(podSpec)
return fmt.Errorf("no container found to inject checkpoint config")
}
var mainContainer *corev1.Container
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == commonconsts.MainContainerName {
mainContainer = &podSpec.Containers[i]
break
}
}
if mainContainer == nil { if mainContainer == nil {
return fmt.Errorf("main container not found in pod spec") return fmt.Errorf("no container named %q found in pod spec", commonconsts.MainContainerName)
} }
if reader == nil { if reader == nil {
return fmt.Errorf("checkpoint client is required") return fmt.Errorf("checkpoint client is required")
...@@ -94,14 +98,8 @@ func InjectCheckpointIntoPodSpec( ...@@ -94,14 +98,8 @@ func InjectCheckpointIntoPodSpec(
EnsurePodInfoVolume(podSpec) EnsurePodInfoVolume(podSpec)
EnsurePodInfoMount(mainContainer) EnsurePodInfoMount(mainContainer)
// GMS restore sidecars (server + loader) are only needed when the checkpoint
// is ready and the pod will actually be CRIU-restored.
if info.Ready && info.GPUMemoryService != nil && info.GPUMemoryService.Enabled { if info.Ready && info.GPUMemoryService != nil && info.GPUMemoryService.Enabled {
if len(mainContainer.Resources.Claims) == 0 { storage, err := snapshotprotocol.DiscoverAndResolveStorage(
return fmt.Errorf("gms sidecars require main container resource claims")
}
storage, err := ResolveGMSCheckpointStorage(
ctx, ctx,
reader, reader,
namespace, namespace,
...@@ -111,8 +109,7 @@ func InjectCheckpointIntoPodSpec( ...@@ -111,8 +109,7 @@ func InjectCheckpointIntoPodSpec(
if err != nil { if err != nil {
return err return err
} }
gmsSidecars := BuildGMSRestoreSidecars(podSpec, mainContainer, storage) EnsureGMSRestoreSidecars(podSpec, mainContainer, storage)
podSpec.Containers = append(podSpec.Containers, gmsSidecars...)
} }
return nil return nil
......
...@@ -6,17 +6,18 @@ package controller ...@@ -6,17 +6,18 @@ package controller
import ( import (
"context" "context"
"fmt" "fmt"
"strings"
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery" "github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
) )
...@@ -81,10 +82,10 @@ func buildCheckpointJob( ...@@ -81,10 +82,10 @@ func buildCheckpointJob(
checkpoint.EnsurePodInfoVolume(&podTemplate.Spec) checkpoint.EnsurePodInfoVolume(&podTemplate.Spec)
mainContainer, err := snapshotprotocol.ResolveCheckpointWorkerContainer(&podTemplate.Spec) if len(podTemplate.Spec.Containers) == 0 {
if err != nil { return nil, fmt.Errorf("checkpoint job requires at least one container")
return nil, err
} }
mainContainer := &podTemplate.Spec.Containers[0]
mainContainer.Env = dynamo.MergeEnvs( mainContainer.Env = dynamo.MergeEnvs(
buildCheckpointWorkerDefaultEnv(ckpt, podTemplate), buildCheckpointWorkerDefaultEnv(ckpt, podTemplate),
mainContainer.Env, mainContainer.Env,
...@@ -105,12 +106,17 @@ func buildCheckpointJob( ...@@ -105,12 +106,17 @@ func buildCheckpointJob(
} }
mainContainer.LivenessProbe = nil mainContainer.LivenessProbe = nil
mainContainer.StartupProbe = nil mainContainer.StartupProbe = nil
// The snapshot agent sends SIGUSR1 to PID 1 of the main container after
checkpoint.EnsurePodInfoMount(mainContainer) checkpoint.EnsurePodInfoMount(mainContainer)
dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory) dynamo.ApplySharedMemoryVolumeAndMount(&podTemplate.Spec, mainContainer, ckpt.Spec.Job.SharedMemory)
var gmsSidecars []corev1.Container
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled { if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
storage, err := checkpoint.ResolveGMSCheckpointStorage( claimTemplateName := dra.ResourceClaimTemplateName("checkpoint-"+hash, "worker")
if err := dra.ApplyClaim(&podTemplate.Spec, claimTemplateName); err != nil {
return nil, fmt.Errorf("failed to apply DRA claim for GMS checkpoint: %w", err)
}
storage, err := snapshotprotocol.DiscoverAndResolveStorage(
ctx, ctx,
reader, reader,
ckpt.Namespace, ckpt.Namespace,
...@@ -120,12 +126,13 @@ func buildCheckpointJob( ...@@ -120,12 +126,13 @@ func buildCheckpointJob(
if err != nil { if err != nil {
return nil, err return nil, err
} }
gmsSidecars, err = checkpoint.BuildGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage) if err := checkpoint.EnsureGMSCheckpointJobSidecars(&podTemplate.Spec, mainContainer, storage); err != nil {
if err != nil {
return nil, err return nil, err
} }
// Re-acquire pointer: append in EnsureGMSCheckpointJobSidecars may
// have reallocated the Containers slice.
mainContainer = &podTemplate.Spec.Containers[0]
} }
podTemplate.Spec.Containers = append(podTemplate.Spec.Containers, gmsSidecars...)
activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds activeDeadlineSeconds := ckpt.Spec.Job.ActiveDeadlineSeconds
if activeDeadlineSeconds == nil { if activeDeadlineSeconds == nil {
...@@ -133,10 +140,29 @@ func buildCheckpointJob( ...@@ -133,10 +140,29 @@ func buildCheckpointJob(
activeDeadlineSeconds = &defaultDeadline activeDeadlineSeconds = &defaultDeadline
} }
wrapLaunchJob := false // Wrap with cuda-checkpoint --launch-job for multi-GPU jobs (TP*PP > 1).
if gpus, ok := mainContainer.Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]; ok { // Use checkpoint identity (not container limits) because DRA may have
wrapLaunchJob = gpus.Cmp(*resource.NewQuantity(1, resource.DecimalSI)) > 0 // already removed nvidia.com/gpu from the template.
tp := ckpt.Spec.Identity.TensorParallelSize
pp := ckpt.Spec.Identity.PipelineParallelSize
if tp == 0 {
tp = 1
} }
if pp == 0 {
pp = 1
}
wrapLaunchJob := tp*pp > 1
// For single-GPU jobs (no cuda-checkpoint wrapper), unwrap /bin/sh -c so
// the actual process is PID 1 and receives SIGUSR1 from the snapshot agent.
if !wrapLaunchJob && len(mainContainer.Command) >= 2 &&
mainContainer.Command[len(mainContainer.Command)-1] == "-c" &&
len(mainContainer.Args) == 1 {
parts := strings.Fields(mainContainer.Args[0])
mainContainer.Command = parts[:1]
mainContainer.Args = parts[1:]
}
ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds ttlSecondsAfterFinish := snapshotprotocol.DefaultCheckpointJobTTLSeconds
return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{ return snapshotprotocol.NewCheckpointJob(podTemplate, snapshotprotocol.CheckpointJobOptions{
......
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
coordinationv1 "k8s.io/api/coordination/v1" coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1" rbacv1 "k8s.io/api/rbac/v1"
resourcev1 "k8s.io/api/resource/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
...@@ -40,8 +41,10 @@ import ( ...@@ -40,8 +41,10 @@ import (
configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1" configv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/config/v1alpha1"
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery" "github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
) )
...@@ -62,6 +65,8 @@ func (r *CheckpointReconciler) GetRecorder() record.EventRecorder { ...@@ -62,6 +65,8 @@ func (r *CheckpointReconciler) GetRecorder() record.EventRecorder {
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/status,verbs=get;update;patch // +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/finalizers,verbs=update // +kubebuilder:rbac:groups=nvidia.com,resources=dynamocheckpoints/finalizers,verbs=update
// +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete // +kubebuilder:rbac:groups=batch,resources=jobs,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=resource.k8s.io,resources=resourceclaimtemplates,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=resource.k8s.io,resources=deviceclasses,verbs=get
// +kubebuilder:rbac:groups=coordination.k8s.io,resources=leases,verbs=get;list;watch // +kubebuilder:rbac:groups=coordination.k8s.io,resources=leases,verbs=get;list;watch
func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
...@@ -190,6 +195,32 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco ...@@ -190,6 +195,32 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err) return ctrl.Result{}, fmt.Errorf("failed to compute checkpoint identity hash: %w", err)
} }
} }
// Sync DRA ResourceClaimTemplate for GMS-enabled checkpoints.
if ckpt.Spec.GPUMemoryService != nil && ckpt.Spec.GPUMemoryService.Enabled {
if !r.RuntimeConfig.DRAEnabled {
return ctrl.Result{}, fmt.Errorf(
"GMS requires DRA (Dynamic Resource Allocation), but the resource.k8s.io API group is not available")
}
if len(ckpt.Spec.Job.PodTemplateSpec.Spec.Containers) == 0 {
return ctrl.Result{}, fmt.Errorf("checkpoint job requires at least one container for GMS")
}
gpuQty := ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources.Limits[corev1.ResourceName(consts.KubeResourceGPUNvidia)]
gpuCount := int(gpuQty.Value())
deviceClassName := ""
if ckpt.Spec.GPUMemoryService != nil {
deviceClassName = ckpt.Spec.GPUMemoryService.DeviceClassName
}
claimTemplateName := dra.ResourceClaimTemplateName("checkpoint-"+hash, "worker")
_, _, err := commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*resourcev1.ResourceClaimTemplate, bool, error) {
return dra.GenerateResourceClaimTemplate(ctx, r.Client, claimTemplateName, ckpt.Namespace, gpuCount, deviceClassName)
})
if err != nil {
logger.Error(err, "Failed to sync GMS ResourceClaimTemplate for checkpoint")
return ctrl.Result{}, fmt.Errorf("failed to sync GMS ResourceClaimTemplate for checkpoint: %w", err)
}
}
jobName := snapshotprotocol.GetCheckpointJobName( jobName := snapshotprotocol.GetCheckpointJobName(
hash, hash,
ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation], ckpt.Annotations[snapshotprotocol.CheckpointArtifactVersionAnnotation],
......
...@@ -26,7 +26,7 @@ import ( ...@@ -26,7 +26,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms" gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -256,6 +256,8 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -256,6 +256,8 @@ func TestBuildCheckpointJob(t *testing.T) {
assert.Equal(t, int32(0), *job.Spec.BackoffLimit) assert.Equal(t, int32(0), *job.Spec.BackoffLimit)
assert.Equal(t, int32(300), *job.Spec.TTLSecondsAfterFinished) assert.Equal(t, int32(300), *job.Spec.TTLSecondsAfterFinished)
// Multi-GPU: wrapping decision uses identity.TensorParallelSize, not container GPU limits.
ckpt.Spec.Identity.TensorParallelSize = 2
ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources = corev1.ResourceRequirements{ ckpt.Spec.Job.PodTemplateSpec.Spec.Containers[0].Resources = corev1.ResourceRequirements{
Limits: corev1.ResourceList{ Limits: corev1.ResourceList{
corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"), corev1.ResourceName("nvidia.com/gpu"): resource.MustParse("2"),
...@@ -267,16 +269,11 @@ func TestBuildCheckpointJob(t *testing.T) { ...@@ -267,16 +269,11 @@ func TestBuildCheckpointJob(t *testing.T) {
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args) assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, job.Spec.Template.Spec.Containers[0].Args)
} }
func TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst(t *testing.T) { func TestBuildCheckpointJobWrapsWithCudaCheckpointForMultiGPU(t *testing.T) {
s := checkpointTestScheme() s := checkpointTestScheme()
ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending) ckpt := makeTestCheckpoint(nvidiacomv1alpha1.DynamoCheckpointPhasePending)
ckpt.Spec.Identity.TensorParallelSize = 2
ckpt.Spec.Job.PodTemplateSpec.Spec.Containers = []corev1.Container{ ckpt.Spec.Job.PodTemplateSpec.Spec.Containers = []corev1.Container{
{
Name: "sidecar",
Image: "sidecar:latest",
Command: []string{"sleep"},
Args: []string{"infinity"},
},
{ {
Name: consts.MainContainerName, Name: consts.MainContainerName,
Image: "test-image:latest", Image: "test-image:latest",
...@@ -288,13 +285,19 @@ func TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst(t *testing.T) ...@@ -288,13 +285,19 @@ func TestBuildCheckpointJobTargetsMainContainerWhenSidecarIsFirst(t *testing.T)
}, },
}, },
}, },
{
Name: "sidecar",
Image: "sidecar:latest",
Command: []string{"sleep"},
Args: []string{"infinity"},
},
} }
r := makeCheckpointReconciler(s, ckpt) r := makeCheckpointReconciler(s, ckpt)
job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName) job, err := buildCheckpointJob(context.Background(), nil, r.Config, ckpt, defaultCheckpointJobName)
require.NoError(t, err) require.NoError(t, err)
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName) main := &job.Spec.Template.Spec.Containers[0]
assert.Equal(t, []string{"cuda-checkpoint"}, main.Command) assert.Equal(t, []string{"cuda-checkpoint"}, main.Command)
assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args) assert.Equal(t, []string{"--launch-job", "python3", "-m", "dynamo.vllm"}, main.Args)
require.NotNil(t, main.ReadinessProbe) require.NotNil(t, main.ReadinessProbe)
...@@ -362,22 +365,22 @@ func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) { ...@@ -362,22 +365,22 @@ func TestBuildCheckpointJobAddsGMSSidecars(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName) main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, consts.MainContainerName)
weightsServer := requireCheckpointContainer(t, job.Spec.Template.Spec.InitContainers, gmsruntime.ServerContainerName) weightsServer := requireCheckpointContainer(t, job.Spec.Template.Spec.InitContainers, gms.ServerContainerName)
saver := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, checkpoint.GMSSaverContainer) saver := requireCheckpointContainer(t, job.Spec.Template.Spec.InitContainers, checkpoint.GMSSaverContainer)
volNames := map[string]bool{} volNames := map[string]bool{}
for _, v := range job.Spec.Template.Spec.Volumes { for _, v := range job.Spec.Template.Spec.Volumes {
volNames[v.Name] = true volNames[v.Name] = true
} }
assert.True(t, volNames[gmsruntime.SharedVolumeName]) assert.True(t, volNames[gms.SharedVolumeName])
assert.True(t, volNames[gmsruntime.ControlVolumeName]) assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName]) assert.True(t, volNames[snapshotprotocol.CheckpointVolumeName])
mainMounts := map[string]string{} mainMounts := map[string]string{}
for _, m := range main.VolumeMounts { for _, m := range main.VolumeMounts {
mainMounts[m.Name] = m.MountPath mainMounts[m.Name] = m.MountPath
} }
assert.Equal(t, gmsruntime.SharedMountPath, mainMounts[gmsruntime.SharedVolumeName]) assert.Equal(t, gms.SharedMountPath, mainMounts[gms.SharedVolumeName])
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, weightsServer.Command) assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, weightsServer.Command)
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *weightsServer.RestartPolicy) assert.Equal(t, corev1.ContainerRestartPolicyAlways, *weightsServer.RestartPolicy)
......
...@@ -40,6 +40,7 @@ import ( ...@@ -40,6 +40,7 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/internal/common" "github.com/ai-dynamo/dynamo/deploy/operator/internal/common"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/observability" "github.com/ai-dynamo/dynamo/deploy/operator/internal/observability"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
...@@ -190,9 +191,11 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req ...@@ -190,9 +191,11 @@ func (r *DynamoComponentDeploymentReconciler) Reconcile(ctx context.Context, req
if serviceName == "" { if serviceName == "" {
serviceName = dynamoComponentDeployment.Name serviceName = dynamoComponentDeployment.Name
} }
claimTemplateName := dynamo.GMSResourceClaimTemplateName(dynamoComponentDeployment.GetParentGraphDeploymentName(), serviceName) spec := &dynamoComponentDeployment.Spec.DynamoComponentDeploymentSharedSpec
gpuCount, deviceClassName := dra.ExtractGPUParams(spec.GPUMemoryService, spec.Resources)
claimTemplateName := dra.ResourceClaimTemplateName(dynamoComponentDeployment.GetParentGraphDeploymentName(), serviceName)
_, _, err = commonController.SyncResource(ctx, r, dynamoComponentDeployment, func(ctx context.Context) (*resourcev1.ResourceClaimTemplate, bool, error) { _, _, err = commonController.SyncResource(ctx, r, dynamoComponentDeployment, func(ctx context.Context) (*resourcev1.ResourceClaimTemplate, bool, error) {
return dynamo.GenerateGMSResourceClaimTemplate(ctx, r.Client, claimTemplateName, dynamoComponentDeployment.Namespace, &dynamoComponentDeployment.Spec.DynamoComponentDeploymentSharedSpec) return dra.GenerateResourceClaimTemplate(ctx, r.Client, claimTemplateName, dynamoComponentDeployment.Namespace, gpuCount, deviceClassName)
}) })
if err != nil { if err != nil {
return ctrl.Result{}, fmt.Errorf("failed to sync GMS ResourceClaimTemplate: %w", err) return ctrl.Result{}, fmt.Errorf("failed to sync GMS ResourceClaimTemplate: %w", err)
......
...@@ -29,7 +29,7 @@ import ( ...@@ -29,7 +29,7 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms" gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"github.com/onsi/gomega" "github.com/onsi/gomega"
...@@ -1424,7 +1424,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1424,7 +1424,7 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
return nil return nil
} }
gmsServer := find(gmsruntime.ServerContainerName) gmsServer := find(gms.ServerContainerName)
require.NotNil(t, gmsServer) require.NotNil(t, gmsServer)
loader := find(checkpoint.GMSLoaderContainer) loader := find(checkpoint.GMSLoaderContainer)
require.NotNil(t, loader) require.NotNil(t, loader)
...@@ -1439,9 +1439,9 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1439,9 +1439,9 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
if got := gmsServer.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.server" { //nolint:goconst if got := gmsServer.Command; len(got) != 3 || got[0] != "python3" || got[1] != "-m" || got[2] != "gpu_memory_service.cli.server" { //nolint:goconst
t.Fatalf("expected weights server to run python module, got %#v", got) t.Fatalf("expected weights server to run python module, got %#v", got)
} }
// Restore: gms-server should be a regular container, not an init container // Restore: gms-server and loader are init sidecars (restartPolicy=Always)
if gmsServer.RestartPolicy != nil { if gmsServer.RestartPolicy == nil || *gmsServer.RestartPolicy != corev1.ContainerRestartPolicyAlways {
t.Fatalf("expected restore gms-server to have no RestartPolicy (regular container), got %#v", gmsServer.RestartPolicy) t.Fatalf("expected restore gms-server to have RestartPolicy=Always, got %#v", gmsServer.RestartPolicy)
} }
if gmsServer.StartupProbe != nil { if gmsServer.StartupProbe != nil {
t.Fatalf("expected restore gms-server to have no StartupProbe") t.Fatalf("expected restore gms-server to have no StartupProbe")
...@@ -1487,8 +1487,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe ...@@ -1487,8 +1487,11 @@ func TestDynamoComponentDeploymentReconciler_generatePodTemplateSpec_RestoreLabe
t.Fatalf("generatePodTemplateSpec failed: %v", err) t.Fatalf("generatePodTemplateSpec failed: %v", err)
} }
// User's extra sidecar should remain in Containers, unchanged.
// GMS loader is now an init sidecar, so the user's container stays
// at Containers[0] and main at Containers[1].
if got := podTemplateSpec.Spec.Containers[0]; got.Name != "gms-loader" || len(got.Command) != 1 || got.Command[0] != "python3" { if got := podTemplateSpec.Spec.Containers[0]; got.Name != "gms-loader" || len(got.Command) != 1 || got.Command[0] != "python3" {
t.Fatalf("expected sidecar container to remain unchanged, got %#v", got) t.Fatalf("expected user sidecar container to remain unchanged, got %#v", got)
} }
if got := podTemplateSpec.Spec.Containers[1]; got.Name != commonconsts.MainContainerName || len(got.Command) != 2 || got.Command[0] != "sleep" || got.Command[1] != "infinity" { if got := podTemplateSpec.Spec.Containers[1]; got.Name != commonconsts.MainContainerName || len(got.Command) != 2 || got.Command[0] != "sleep" || got.Command[1] != "infinity" {
t.Fatalf("expected main container to be rewritten for restore, got %#v", got) t.Fatalf("expected main container to be rewritten for restore, got %#v", got)
......
...@@ -54,6 +54,7 @@ import ( ...@@ -54,6 +54,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
commoncontroller "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commoncontroller "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo/epp" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo/epp"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/observability" "github.com/ai-dynamo/dynamo/deploy/operator/internal/observability"
...@@ -664,10 +665,10 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co ...@@ -664,10 +665,10 @@ func (r *DynamoGraphDeploymentReconciler) reconcileGroveResources(ctx context.Co
// Sync ResourceClaimTemplates for GMS-enabled components before creating pods. // Sync ResourceClaimTemplates for GMS-enabled components before creating pods.
if r.RuntimeConfig.DRAEnabled { if r.RuntimeConfig.DRAEnabled {
for serviceName, component := range dynamoDeployment.Spec.Services { for serviceName, component := range dynamoDeployment.Spec.Services {
svcComponent := component gpuCount, deviceClassName := dra.ExtractGPUParams(component.GPUMemoryService, component.Resources)
claimTemplateName := dynamo.GMSResourceClaimTemplateName(dynamoDeployment.Name, serviceName) claimTemplateName := dra.ResourceClaimTemplateName(dynamoDeployment.Name, serviceName)
_, _, err := commoncontroller.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*resourcev1.ResourceClaimTemplate, bool, error) { _, _, err := commoncontroller.SyncResource(ctx, r, dynamoDeployment, func(ctx context.Context) (*resourcev1.ResourceClaimTemplate, bool, error) {
return dynamo.GenerateGMSResourceClaimTemplate(ctx, r.Client, claimTemplateName, dynamoDeployment.Namespace, svcComponent) return dra.GenerateResourceClaimTemplate(ctx, r.Client, claimTemplateName, dynamoDeployment.Namespace, gpuCount, deviceClassName)
}) })
if err != nil { if err != nil {
logger.Error(err, "failed to sync GMS ResourceClaimTemplate", "service", serviceName) logger.Error(err, "failed to sync GMS ResourceClaimTemplate", "service", serviceName)
...@@ -1419,10 +1420,14 @@ func (r *DynamoGraphDeploymentReconciler) buildCheckpointJobPodTemplate( ...@@ -1419,10 +1420,14 @@ func (r *DynamoGraphDeploymentReconciler) buildCheckpointJobPodTemplate(
return corev1.PodTemplateSpec{}, err return corev1.PodTemplateSpec{}, err
} }
// Create a copy of the component spec without checkpoint config // Create a copy of the component spec stripped of features that buildCheckpointJob
// The checkpoint job is CREATING the checkpoint, not restoring from one // or the checkpoint controller handle independently. GenerateBasePodSpec would
// otherwise apply DGD-specific transforms (DRA claims, GMS server sidecar,
// frontend sidecar) that conflict with the checkpoint path's own setup.
componentForJob := component.DeepCopy() componentForJob := component.DeepCopy()
componentForJob.Checkpoint = nil componentForJob.Checkpoint = nil
componentForJob.GPUMemoryService = nil
componentForJob.FrontendSidecar = nil
// Ensure DYN_NAMESPACE is set for checkpoint job using the same logic as regular pods // Ensure DYN_NAMESPACE is set for checkpoint job using the same logic as regular pods
// This is required for service discovery and distributed coordination // This is required for service discovery and distributed coordination
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
package dynamo package dra
import ( import (
"context" "context"
...@@ -13,7 +13,6 @@ import ( ...@@ -13,7 +13,6 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
resourcev1 "k8s.io/api/resource/v1" resourcev1 "k8s.io/api/resource/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
...@@ -23,147 +22,91 @@ import ( ...@@ -23,147 +22,91 @@ import (
) )
const ( const (
// ClaimName is the pod-level DRA ResourceClaim name for shared GPU access.
ClaimName = "intrapod-shared-gpu"
defaultDeviceClassName = "gpu.nvidia.com" defaultDeviceClassName = "gpu.nvidia.com"
) )
// IsGMSEnabled reports whether GPU Memory Service is requested for the component. // ApplyClaim replaces the first container's nvidia.com/gpu resources with a
func IsGMSEnabled(component *v1alpha1.DynamoComponentDeploymentSharedSpec) bool { // shared DRA ResourceClaim. Every container that references this claim name
return component.GPUMemoryService != nil && component.GPUMemoryService.Enabled // will share the same physical GPUs. The function is idempotent — calling it
} // on a pod that already has the claim is a no-op.
func ApplyClaim(podSpec *corev1.PodSpec, claimTemplateName string) error {
// getGPUCount extracts the GPU count from the component resource spec.
func getGPUCount(component *v1alpha1.DynamoComponentDeploymentSharedSpec) (int, error) {
if component.Resources == nil {
return 0, fmt.Errorf("resources must be specified when GPU memory service is enabled")
}
gpuStr := ""
if component.Resources.Limits != nil && component.Resources.Limits.GPU != "" {
gpuStr = component.Resources.Limits.GPU
} else if component.Resources.Requests != nil && component.Resources.Requests.GPU != "" {
gpuStr = component.Resources.Requests.GPU
}
if gpuStr == "" {
return 0, fmt.Errorf("GPU count must be specified when GPU memory service is enabled")
}
count, err := strconv.Atoi(gpuStr)
if err != nil {
return 0, fmt.Errorf("invalid GPU count %q: %w", gpuStr, err)
}
if count <= 0 {
return 0, fmt.Errorf("GPU count must be greater than 0 when GPU memory service is enabled")
}
return count, nil
}
// getDeviceClassName returns the DRA DeviceClass name for the component.
// It reads from GPUMemoryServiceSpec.DeviceClassName, falling back to the default.
func getDeviceClassName(component *v1alpha1.DynamoComponentDeploymentSharedSpec) string {
if component.GPUMemoryService != nil && component.GPUMemoryService.DeviceClassName != "" {
return component.GPUMemoryService.DeviceClassName
}
return defaultDeviceClassName
}
// resolveMainContainer finds the container named "main" in the pod spec.
// Falls back to Containers[0] when there is no container named "main"
// (e.g. failover pods with engine-0/engine-1 naming).
func resolveMainContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
if len(podSpec.Containers) == 0 { if len(podSpec.Containers) == 0 {
return nil, fmt.Errorf("pod spec must have at least one container for GPU memory service") return fmt.Errorf("pod spec must have at least one container for DRA claim")
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == commonconsts.MainContainerName {
return &podSpec.Containers[i], nil
}
} }
return &podSpec.Containers[0], nil
}
// ApplyGPUMemoryService transforms a pod spec to include GMS server sidecars // Skip if the pod-level claim already exists (idempotent).
// with DRA shared GPU access. The main container's GPU resources are replaced for i := range podSpec.ResourceClaims {
// with a DRA ResourceClaim. if podSpec.ResourceClaims[i].Name == ClaimName {
func ApplyGPUMemoryService( return nil
podSpec *corev1.PodSpec,
component *v1alpha1.DynamoComponentDeploymentSharedSpec,
claimTemplateName string,
) error {
gpuCount, err := getGPUCount(component)
if err != nil {
return err
} }
_ = gpuCount // GPU count is used for DRA claim template; sidecar discovers devices via pynvml
mainContainer, err := resolveMainContainer(podSpec)
if err != nil {
return err
} }
// Replace GPU resources with DRA claim on main container // Replace nvidia.com/gpu with the shared DRA claim.
removeGPUResources(mainContainer) gpuResource := corev1.ResourceName(commonconsts.KubeResourceGPUNvidia)
mainContainer.Resources.Claims = append(mainContainer.Resources.Claims, corev1.ResourceClaim{ delete(podSpec.Containers[0].Resources.Limits, gpuResource)
Name: gmsruntime.DRAClaimName, delete(podSpec.Containers[0].Resources.Requests, gpuResource)
podSpec.Containers[0].Resources.Claims = append(podSpec.Containers[0].Resources.Claims, corev1.ResourceClaim{
Name: ClaimName,
}) })
// Add GMS server sidecar, shared volume, and socket env vars. // GPU nodes are typically tainted with nvidia.com/gpu=NoSchedule. DRA
// The sidecar gets DRA claims copied from main automatically. // bypasses the device-plugin toleration injection, so add it explicitly.
gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
// GPU nodes are typically tainted with nvidia.com/gpu=NoSchedule. With
// traditional scheduling the device-plugin injects the matching toleration,
// but DRA bypasses that path. Re-add the toleration explicitly so the pod
// can schedule on GPU nodes.
podSpec.Tolerations = append(podSpec.Tolerations, corev1.Toleration{ podSpec.Tolerations = append(podSpec.Tolerations, corev1.Toleration{
Key: commonconsts.KubeResourceGPUNvidia, Key: commonconsts.KubeResourceGPUNvidia,
Operator: corev1.TolerationOpExists, Operator: corev1.TolerationOpExists,
Effect: corev1.TaintEffectNoSchedule, Effect: corev1.TaintEffectNoSchedule,
}) })
// Add pod-level DRA resource claim referencing the ResourceClaimTemplate
podSpec.ResourceClaims = append(podSpec.ResourceClaims, corev1.PodResourceClaim{ podSpec.ResourceClaims = append(podSpec.ResourceClaims, corev1.PodResourceClaim{
Name: gmsruntime.DRAClaimName, Name: ClaimName,
ResourceClaimTemplateName: &claimTemplateName, ResourceClaimTemplateName: &claimTemplateName,
}) })
return nil return nil
} }
// removeGPUResources strips nvidia.com/gpu from container resource limits and requests. // ResourceClaimTemplateName returns the deterministic name for the
// GPU allocation is handled by DRA when GMS is enabled. // ResourceClaimTemplate associated with a component.
func removeGPUResources(container *corev1.Container) { func ResourceClaimTemplateName(parentName, serviceName string) string {
gpuResource := corev1.ResourceName(commonconsts.KubeResourceGPUNvidia) return fmt.Sprintf("%s-%s-gpu", parentName, strings.ToLower(serviceName))
if container.Resources.Limits != nil {
delete(container.Resources.Limits, gpuResource)
}
if container.Resources.Requests != nil {
delete(container.Resources.Requests, gpuResource)
}
} }
// GMSResourceClaimTemplateName returns the deterministic name for the // ExtractGPUParams extracts the GPU count and device class name from API types
// ResourceClaimTemplate associated with a GMS-enabled component. // shared by DGD components and DynamoCheckpoint specs. Returns gpuCount=0 when
func GMSResourceClaimTemplateName(parentName, serviceName string) string { // GMS is not enabled, which tells GenerateResourceClaimTemplate to delete.
return fmt.Sprintf("%s-%s-gpu", parentName, strings.ToLower(serviceName)) func ExtractGPUParams(gmsSpec *v1alpha1.GPUMemoryServiceSpec, resources *v1alpha1.Resources) (gpuCount int, deviceClassName string) {
if gmsSpec == nil || !gmsSpec.Enabled {
return 0, ""
}
deviceClassName = gmsSpec.DeviceClassName
if resources != nil {
gpuStr := ""
if resources.Limits != nil {
gpuStr = resources.Limits.GPU
}
if gpuStr == "" && resources.Requests != nil {
gpuStr = resources.Requests.GPU
}
gpuCount, _ = strconv.Atoi(gpuStr)
}
return gpuCount, deviceClassName
} }
// GenerateGMSResourceClaimTemplate builds the ResourceClaimTemplate that // GenerateResourceClaimTemplate builds the ResourceClaimTemplate that provides
// provides shared GPU access to all containers in a GMS-enabled pod via DRA. // shared GPU access to all containers in a pod via DRA.
//
// claimTemplateName is the deterministic name for the template; callers should
// compute it via GMSResourceClaimTemplateName.
// //
// When GMS is not enabled for the component, it returns the template skeleton // When gpuCount <= 0 it returns the template skeleton with toDelete=true so
// with toDelete=true so that SyncResource cleans up any previously created template. // that SyncResource cleans up any previously created template. Pass cl=nil to
// // skip the DeviceClass existence check.
// The cl parameter is used to verify the DeviceClass exists before creating func GenerateResourceClaimTemplate(
// the template. Pass nil to skip the DeviceClass check.
func GenerateGMSResourceClaimTemplate(
ctx context.Context, ctx context.Context,
cl client.Client, cl client.Client,
claimTemplateName, namespace string, claimTemplateName, namespace string,
component *v1alpha1.DynamoComponentDeploymentSharedSpec, gpuCount int,
deviceClassName string,
) (*resourcev1.ResourceClaimTemplate, bool, error) { ) (*resourcev1.ResourceClaimTemplate, bool, error) {
template := &resourcev1.ResourceClaimTemplate{ template := &resourcev1.ResourceClaimTemplate{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
...@@ -172,18 +115,14 @@ func GenerateGMSResourceClaimTemplate( ...@@ -172,18 +115,14 @@ func GenerateGMSResourceClaimTemplate(
}, },
} }
if !IsGMSEnabled(component) { if gpuCount <= 0 {
return template, true, nil return template, true, nil
} }
gpuCount, err := getGPUCount(component) if deviceClassName == "" {
if err != nil { deviceClassName = defaultDeviceClassName
return nil, false, fmt.Errorf("failed to get GPU count for ResourceClaimTemplate: %w", err)
} }
deviceClassName := getDeviceClassName(component)
// Verify the DeviceClass exists before creating the template
if cl != nil { if cl != nil {
dc := &resourcev1.DeviceClass{} dc := &resourcev1.DeviceClass{}
if err := cl.Get(ctx, types.NamespacedName{Name: deviceClassName}, dc); err != nil { if err := cl.Get(ctx, types.NamespacedName{Name: deviceClassName}, dc); err != nil {
......
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package dra
import (
"context"
"strconv"
"testing"
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/intstr"
)
func basePodSpec() corev1.PodSpec {
httpPort := intstr.FromString("system")
return corev1.PodSpec{
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Command: []string{"python3", "-m", "dynamo.vllm"},
Env: []corev1.EnvVar{
{Name: "DYN_SYSTEM_PORT", Value: "9090"},
},
Ports: []corev1.ContainerPort{
{Name: "system", ContainerPort: 9090, Protocol: corev1.ProtocolTCP},
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{Path: "/health", Port: httpPort},
},
},
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceName(commonconsts.KubeResourceGPUNvidia): resource.MustParse("2"),
},
},
}},
}
}
func TestApplyClaim_EmptyContainers(t *testing.T) {
ps := corev1.PodSpec{}
err := ApplyClaim(&ps, "myapp-worker-gpu")
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one container")
}
func TestApplyClaim_ReplacesGPUWithDRAClaim(t *testing.T) {
ps := basePodSpec()
err := ApplyClaim(&ps, "myapp-worker-gpu")
require.NoError(t, err)
main := ps.Containers[0]
gpuResource := corev1.ResourceName(commonconsts.KubeResourceGPUNvidia)
_, hasGPU := main.Resources.Limits[gpuResource]
assert.False(t, hasGPU)
require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, ClaimName, main.Resources.Claims[0].Name)
require.Len(t, ps.ResourceClaims, 1)
assert.Equal(t, ClaimName, ps.ResourceClaims[0].Name)
assert.Equal(t, "myapp-worker-gpu", *ps.ResourceClaims[0].ResourceClaimTemplateName)
var hasToleration bool
for _, tol := range ps.Tolerations {
if tol.Key == commonconsts.KubeResourceGPUNvidia && tol.Effect == corev1.TaintEffectNoSchedule {
hasToleration = true
}
}
assert.True(t, hasToleration)
assert.Empty(t, ps.InitContainers)
}
func TestApplyClaim_AlwaysTargetsFirstContainer(t *testing.T) {
ps := basePodSpec()
ps.Containers = append(ps.Containers, corev1.Container{Name: "sidecar", Image: "sidecar:latest"})
err := ApplyClaim(&ps, "myapp-worker-gpu")
require.NoError(t, err)
require.Len(t, ps.Containers[0].Resources.Claims, 1)
assert.Equal(t, ClaimName, ps.Containers[0].Resources.Claims[0].Name)
assert.Empty(t, ps.Containers[1].Resources.Claims)
}
func TestGenerateResourceClaimTemplate_Enabled(t *testing.T) {
tmpl, toDelete, err := GenerateResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", 4, "")
require.NoError(t, err)
assert.False(t, toDelete)
assert.Equal(t, "myapp-worker-gpu", tmpl.Name)
require.Len(t, tmpl.Spec.Spec.Devices.Requests, 1)
req := tmpl.Spec.Spec.Devices.Requests[0]
assert.Equal(t, defaultDeviceClassName, req.Exactly.DeviceClassName)
assert.Equal(t, int64(4), req.Exactly.Count)
}
func TestGenerateResourceClaimTemplate_CustomDeviceClass(t *testing.T) {
tmpl, _, err := GenerateResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", 2, "gpu.intel.com/xe")
require.NoError(t, err)
assert.Equal(t, "gpu.intel.com/xe", tmpl.Spec.Spec.Devices.Requests[0].Exactly.DeviceClassName)
}
func TestGenerateResourceClaimTemplate_DisabledReturnsDelete(t *testing.T) {
tmpl, toDelete, err := GenerateResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", 0, "")
require.NoError(t, err)
assert.True(t, toDelete)
assert.Equal(t, "myapp-worker-gpu", tmpl.Name)
}
func TestResourceClaimTemplateName(t *testing.T) {
assert.Equal(t, "myapp-worker-gpu", ResourceClaimTemplateName("myapp", "Worker"))
assert.Equal(t, "app-vllmdecodeworker-gpu", ResourceClaimTemplateName("app", "VllmDecodeWorker"))
}
func TestExtractGPUParams(t *testing.T) {
count, dc := ExtractGPUParams(nil, nil)
assert.Equal(t, 0, count)
assert.Equal(t, "", dc)
count, dc = ExtractGPUParams(
&v1alpha1.GPUMemoryServiceSpec{Enabled: true},
&v1alpha1.Resources{Limits: &v1alpha1.ResourceItem{GPU: strconv.Itoa(4)}},
)
assert.Equal(t, 4, count)
assert.Equal(t, "", dc)
count, dc = ExtractGPUParams(
&v1alpha1.GPUMemoryServiceSpec{Enabled: true, DeviceClassName: "gpu.intel.com/xe"},
&v1alpha1.Resources{Requests: &v1alpha1.ResourceItem{GPU: "2"}},
)
assert.Equal(t, 2, count)
assert.Equal(t, "gpu.intel.com/xe", dc)
}
...@@ -12,7 +12,8 @@ import ( ...@@ -12,7 +12,8 @@ import (
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -36,7 +37,7 @@ func failoverPodSpec() corev1.PodSpec { ...@@ -36,7 +37,7 @@ func failoverPodSpec() corev1.PodSpec {
{Name: "DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS", Value: "true"}, {Name: "DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS", Value: "true"},
{Name: "DYN_HEALTH_CHECK_ENABLED", Value: "true"}, {Name: "DYN_HEALTH_CHECK_ENABLED", Value: "true"},
{Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"}, {Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"},
{Name: "TMPDIR", Value: gmsruntime.SharedMountPath}, {Name: "TMPDIR", Value: gms.SharedMountPath},
}, },
Ports: []corev1.ContainerPort{ Ports: []corev1.ContainerPort{
{Name: "system", ContainerPort: 9090, Protocol: corev1.ProtocolTCP}, {Name: "system", ContainerPort: 9090, Protocol: corev1.ProtocolTCP},
...@@ -57,10 +58,10 @@ func failoverPodSpec() corev1.PodSpec { ...@@ -57,10 +58,10 @@ func failoverPodSpec() corev1.PodSpec {
}, },
}, },
Resources: corev1.ResourceRequirements{ Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: gmsruntime.DRAClaimName}}, Claims: []corev1.ResourceClaim{{Name: dra.ClaimName}},
}, },
VolumeMounts: []corev1.VolumeMount{ VolumeMounts: []corev1.VolumeMount{
{Name: gmsruntime.SharedVolumeName, MountPath: gmsruntime.SharedMountPath}, {Name: gms.SharedVolumeName, MountPath: gms.SharedMountPath},
}, },
}, },
{ {
...@@ -165,7 +166,7 @@ func TestBuildFailoverPod_PreservesDRAClaim(t *testing.T) { ...@@ -165,7 +166,7 @@ func TestBuildFailoverPod_PreservesDRAClaim(t *testing.T) {
for i := range 2 { for i := range 2 {
engine := ps.Containers[i] engine := ps.Containers[i]
require.Len(t, engine.Resources.Claims, 1, "engine-%d should retain DRA claim", i) require.Len(t, engine.Resources.Claims, 1, "engine-%d should retain DRA claim", i)
assert.Equal(t, gmsruntime.DRAClaimName, engine.Resources.Claims[0].Name) assert.Equal(t, dra.ClaimName, engine.Resources.Claims[0].Name)
} }
} }
...@@ -214,3 +215,11 @@ func TestIsFailoverEnabled(t *testing.T) { ...@@ -214,3 +215,11 @@ func TestIsFailoverEnabled(t *testing.T) {
})) }))
assert.False(t, isFailoverEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{})) assert.False(t, isFailoverEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
} }
func envToMap(envs []corev1.EnvVar) map[string]string {
m := make(map[string]string, len(envs))
for _, e := range envs {
m[e.Name] = e.Value
}
return m
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package dynamo
import (
"context"
"strconv"
"testing"
"github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
"k8s.io/apimachinery/pkg/util/intstr"
)
func gmsComponent(gpuCount int) *v1alpha1.DynamoComponentDeploymentSharedSpec {
return &v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: commonconsts.ComponentTypeWorker,
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
Resources: &v1alpha1.Resources{
Limits: &v1alpha1.ResourceItem{GPU: strconv.Itoa(gpuCount)},
},
}
}
func gmsBasePodSpec() corev1.PodSpec {
httpPort := intstr.FromString("system")
return corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "main",
Image: "test-image:latest",
Command: []string{"python3", "-m", "dynamo.vllm"},
Env: []corev1.EnvVar{
{Name: "DYN_SYSTEM_PORT", Value: "9090"},
{Name: commonconsts.DynamoDiscoveryBackendEnvVar, Value: "kubernetes"},
},
Ports: []corev1.ContainerPort{
{Name: "system", ContainerPort: 9090, Protocol: corev1.ProtocolTCP},
},
StartupProbe: &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
HTTPGet: &corev1.HTTPGetAction{Path: "/health", Port: httpPort},
},
},
Resources: corev1.ResourceRequirements{
Limits: corev1.ResourceList{
corev1.ResourceName(commonconsts.KubeResourceGPUNvidia): resource.MustParse("2"),
},
},
},
},
}
}
// --- applyGPUMemoryService ---
func TestApplyGPUMemoryService_EmptyContainers(t *testing.T) {
ps := corev1.PodSpec{}
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.Error(t, err)
assert.Contains(t, err.Error(), "at least one container")
}
func TestApplyGPUMemoryService_MainContainerTransformed(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
main := ps.Containers[0]
// GPU resources should be removed
gpuResource := corev1.ResourceName(commonconsts.KubeResourceGPUNvidia)
_, hasGPU := main.Resources.Limits[gpuResource]
assert.False(t, hasGPU, "main container should not have GPU limits")
// Should have DRA claim
require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
// Should have shared volume mount
var hasSharedMount bool
for _, vm := range main.VolumeMounts {
if vm.Name == gmsruntime.SharedVolumeName && vm.MountPath == gmsruntime.SharedMountPath {
hasSharedMount = true
}
}
assert.True(t, hasSharedMount, "main container should have gms-shared volume mount")
// Should have TMPDIR and GMS_SOCKET_DIR
envMap := envToMap(main.Env)
assert.Equal(t, gmsruntime.SharedMountPath, envMap["TMPDIR"])
assert.Equal(t, gmsruntime.SharedMountPath, envMap["GMS_SOCKET_DIR"])
}
func TestApplyGPUMemoryService_GMSSidecarInjected(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
require.Len(t, ps.InitContainers, 1)
gms := ps.InitContainers[0]
assert.Equal(t, gmsruntime.ServerContainerName, gms.Name)
assert.Equal(t, "test-image:latest", gms.Image)
assert.Equal(t, []string{"python3", "-m", "gpu_memory_service.cli.server"}, gms.Command)
assert.NotNil(t, gms.RestartPolicy)
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *gms.RestartPolicy)
require.NotNil(t, gms.StartupProbe)
assert.Equal(t, int32(1), gms.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), gms.StartupProbe.FailureThreshold)
// GMS sidecar should have DRA claim copied from main
require.Len(t, gms.Resources.Claims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, gms.Resources.Claims[0].Name)
// GMS sidecar should have TMPDIR
gmsEnv := envToMap(gms.Env)
assert.Equal(t, gmsruntime.SharedMountPath, gmsEnv["TMPDIR"])
}
func TestApplyGPUMemoryService_SharedVolume(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
var found bool
for _, v := range ps.Volumes {
if v.Name == gmsruntime.SharedVolumeName {
assert.NotNil(t, v.EmptyDir)
found = true
}
}
assert.True(t, found, "should have gms-shared volume")
}
func TestApplyGPUMemoryService_GPUToleration(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
var found bool
for _, tol := range ps.Tolerations {
if tol.Key == commonconsts.KubeResourceGPUNvidia && tol.Effect == corev1.TaintEffectNoSchedule {
assert.Equal(t, corev1.TolerationOpExists, tol.Operator)
found = true
}
}
assert.True(t, found, "should have nvidia.com/gpu NoSchedule toleration")
}
func TestApplyGPUMemoryService_DRAResourceClaim(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
require.Len(t, ps.ResourceClaims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, ps.ResourceClaims[0].Name)
assert.Equal(t, "myapp-worker-gpu", *ps.ResourceClaims[0].ResourceClaimTemplateName)
}
func TestApplyGPUMemoryService_PreservesExistingEnv(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
main := ps.Containers[0]
envMap := envToMap(main.Env)
assert.Equal(t, "kubernetes", envMap[commonconsts.DynamoDiscoveryBackendEnvVar])
assert.Equal(t, "9090", envMap["DYN_SYSTEM_PORT"])
}
func TestApplyGPUMemoryService_SingleContainer(t *testing.T) {
ps := gmsBasePodSpec()
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
assert.Len(t, ps.Containers, 1)
assert.Equal(t, "main", ps.Containers[0].Name)
}
func TestApplyGPUMemoryService_ResolvesMainByName(t *testing.T) {
ps := gmsBasePodSpec()
// Prepend a sidecar so main is NOT Containers[0]
sidecar := corev1.Container{Name: "sidecar", Image: "sidecar:latest"}
ps.Containers = append([]corev1.Container{sidecar}, ps.Containers...)
require.Equal(t, "sidecar", ps.Containers[0].Name)
err := ApplyGPUMemoryService(&ps, gmsComponent(2), "myapp-worker-gpu")
require.NoError(t, err)
// Sidecar should be untouched
assert.Equal(t, "sidecar", ps.Containers[0].Name)
assert.Empty(t, ps.Containers[0].Resources.Claims)
// Main should have DRA claim
main := ps.Containers[1]
assert.Equal(t, "main", main.Name)
require.Len(t, main.Resources.Claims, 1)
assert.Equal(t, gmsruntime.DRAClaimName, main.Resources.Claims[0].Name)
}
// --- GenerateGMSResourceClaimTemplate ---
func TestGenerateGMSResourceClaimTemplate_Enabled(t *testing.T) {
component := gmsComponent(4)
tmpl, toDelete, err := GenerateGMSResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", component)
require.NoError(t, err)
assert.False(t, toDelete)
assert.Equal(t, "myapp-worker-gpu", tmpl.Name)
assert.Equal(t, "default", tmpl.Namespace)
require.Len(t, tmpl.Spec.Spec.Devices.Requests, 1)
req := tmpl.Spec.Spec.Devices.Requests[0]
assert.Equal(t, "gpus", req.Name)
require.NotNil(t, req.Exactly)
assert.Equal(t, defaultDeviceClassName, req.Exactly.DeviceClassName)
assert.Equal(t, int64(4), req.Exactly.Count)
}
func TestGenerateGMSResourceClaimTemplate_CustomDeviceClass(t *testing.T) {
component := gmsComponent(2)
component.GPUMemoryService.DeviceClassName = "gpu.intel.com/xe"
tmpl, toDelete, err := GenerateGMSResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", component)
require.NoError(t, err)
assert.False(t, toDelete)
assert.Equal(t, "gpu.intel.com/xe", tmpl.Spec.Spec.Devices.Requests[0].Exactly.DeviceClassName)
}
func TestGenerateGMSResourceClaimTemplate_DisabledReturnsDelete(t *testing.T) {
component := &v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: commonconsts.ComponentTypeWorker,
}
tmpl, toDelete, err := GenerateGMSResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", component)
require.NoError(t, err)
assert.True(t, toDelete)
assert.Equal(t, "myapp-worker-gpu", tmpl.Name)
}
func TestGenerateGMSResourceClaimTemplate_NoGPUCountError(t *testing.T) {
component := &v1alpha1.DynamoComponentDeploymentSharedSpec{
ComponentType: commonconsts.ComponentTypeWorker,
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
}
_, _, err := GenerateGMSResourceClaimTemplate(context.Background(), nil, "myapp-worker-gpu", "default", component)
require.Error(t, err)
assert.Contains(t, err.Error(), "resources must be specified")
}
// --- GMSResourceClaimTemplateName ---
func TestGMSResourceClaimTemplateName(t *testing.T) {
assert.Equal(t, "myapp-worker-gpu", GMSResourceClaimTemplateName("myapp", "Worker"))
assert.Equal(t, "app-vllmdecodeworker-gpu", GMSResourceClaimTemplateName("app", "VllmDecodeWorker"))
}
// --- isGMSEnabled ---
func TestIsGMSEnabled(t *testing.T) {
assert.True(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
}))
assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: false},
}))
assert.False(t, IsGMSEnabled(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
}
// --- getGPUCount ---
func TestGetGPUCount(t *testing.T) {
tests := []struct {
name string
component *v1alpha1.DynamoComponentDeploymentSharedSpec
want int
wantErr bool
}{
{
name: "from limits",
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{Resources: &v1alpha1.Resources{Limits: &v1alpha1.ResourceItem{GPU: "4"}}},
want: 4,
},
{
name: "from requests",
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{Resources: &v1alpha1.Resources{Requests: &v1alpha1.ResourceItem{GPU: "2"}}},
want: 2,
},
{
name: "no resources",
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{},
wantErr: true,
},
{
name: "invalid GPU string",
component: &v1alpha1.DynamoComponentDeploymentSharedSpec{Resources: &v1alpha1.Resources{Limits: &v1alpha1.ResourceItem{GPU: "abc"}}},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getGPUCount(tt.component)
if tt.wantErr {
assert.Error(t, err)
} else {
require.NoError(t, err)
assert.Equal(t, tt.want, got)
}
})
}
}
// --- getDeviceClassName ---
func TestGetDeviceClassName(t *testing.T) {
assert.Equal(t, defaultDeviceClassName, getDeviceClassName(&v1alpha1.DynamoComponentDeploymentSharedSpec{}))
assert.Equal(t, defaultDeviceClassName, getDeviceClassName(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true},
}))
assert.Equal(t, "gpu.intel.com/xe", getDeviceClassName(&v1alpha1.DynamoComponentDeploymentSharedSpec{
GPUMemoryService: &v1alpha1.GPUMemoryServiceSpec{Enabled: true, DeviceClassName: "gpu.intel.com/xe"},
}))
}
// helpers
func envToMap(envs []corev1.EnvVar) map[string]string {
m := make(map[string]string, len(envs))
for _, e := range envs {
m[e.Name] = e.Value
}
return m
}
...@@ -38,6 +38,8 @@ import ( ...@@ -38,6 +38,8 @@ import (
commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" commonconsts "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery" "github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
gms "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1" grovev1alpha1 "github.com/ai-dynamo/grove/operator/api/core/v1alpha1"
"github.com/imdario/mergo" "github.com/imdario/mergo"
networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1" networkingv1beta1 "istio.io/client-go/pkg/apis/networking/v1beta1"
...@@ -1182,12 +1184,13 @@ func GenerateBasePodSpec( ...@@ -1182,12 +1184,13 @@ func GenerateBasePodSpec(
} }
} }
// Inject GMS sidecar with DRA shared GPU access when GPU memory service is enabled. // GMS: replace nvidia.com/gpu with a shared DRA claim and add the server sidecar.
if IsGMSEnabled(component) { if component.GPUMemoryService != nil && component.GPUMemoryService.Enabled {
claimTemplateName := GMSResourceClaimTemplateName(parentGraphDeploymentName, serviceName) claimTemplateName := dra.ResourceClaimTemplateName(parentGraphDeploymentName, serviceName)
if err := ApplyGPUMemoryService(&podSpec, component, claimTemplateName); err != nil { if err := dra.ApplyClaim(&podSpec, claimTemplateName); err != nil {
return nil, fmt.Errorf("failed to apply GPU memory service: %w", err) return nil, fmt.Errorf("failed to apply DRA claim for GMS: %w", err)
} }
gms.EnsureServerSidecar(&podSpec, &podSpec.Containers[0])
} }
// Clone main container into two engine containers (active + standby) for failover. // Clone main container into two engine containers (active + standby) for failover.
......
...@@ -3,11 +3,14 @@ ...@@ -3,11 +3,14 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
// Package gms provides GMS (GPU Memory Service) server container building
// for both steady-state DGD pods and checkpoint/restore flows.
package gms package gms
import ( import (
"path/filepath" "path/filepath"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr" "k8s.io/utils/ptr"
) )
...@@ -23,32 +26,25 @@ const ( ...@@ -23,32 +26,25 @@ const (
// SharedMountPath is the mount path for the shared GMS socket directory. // SharedMountPath is the mount path for the shared GMS socket directory.
SharedMountPath = "/shared" SharedMountPath = "/shared"
// DRAClaimName is the pod-level DRA ResourceClaim name used by both the // EnvSocketDir is the environment variable name for the GMS UDS socket directory.
// main container and GMS sidecars. EnvSocketDir = "GMS_SOCKET_DIR"
DRAClaimName = "shared-gpu"
// ControlVolumeName is the checkpoint-specific control volume name. // ServerModule is the Python module for the GMS server entry point.
ControlVolumeName = "gms-control" ServerModule = "gpu_memory_service.cli.server"
// ControlDir is the mount path for the checkpoint control volume.
ControlDir = "/tmp/gms-control"
readyFile = "gms-ready" readyFile = "gms-ready"
serverSidecarModule = "gpu_memory_service.cli.server"
) )
// EnsureServerSidecar adds the GMS server as a restartable init sidecar with a // EnsureServerSidecar adds the GMS server as a restartable init sidecar with a
// startup probe. Used for checkpoint jobs and steady-state pods where the main // startup probe. Idempotent — safe to call from both the DGD and checkpoint paths.
// container needs GMS sockets before starting.
func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Container) { func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
if podSpec == nil || mainContainer == nil { if podSpec == nil || mainContainer == nil {
return return
} }
ensureSharedVolume(podSpec, mainContainer) EnsureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image) sidecar := Container(ServerContainerName, ServerModule, mainContainer.Image)
sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways) sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
sidecar.StartupProbe = &corev1.Probe{ sidecar.StartupProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{ ProbeHandler: corev1.ProbeHandler{
...@@ -59,9 +55,6 @@ func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Containe ...@@ -59,9 +55,6 @@ func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Containe
PeriodSeconds: 1, PeriodSeconds: 1,
FailureThreshold: 300, // 1s * 300 = 5 min FailureThreshold: 300, // 1s * 300 = 5 min
} }
copyDeviceClaims(mainContainer, &sidecar)
// Idempotent — EnsureServerSidecar may be called by both the
// steady-state operator path and the checkpoint overlay.
for i := range podSpec.InitContainers { for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == sidecar.Name { if podSpec.InitContainers[i].Name == sidecar.Name {
return return
...@@ -70,41 +63,9 @@ func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Containe ...@@ -70,41 +63,9 @@ func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Containe
podSpec.InitContainers = append(podSpec.InitContainers, sidecar) podSpec.InitContainers = append(podSpec.InitContainers, sidecar)
} }
// BuildServerContainer prepares the shared GMS volume/env and returns a GMS // EnsureSharedVolume adds the GMS UDS socket volume, mount, and GMS_SOCKET_DIR
// server container suitable for use as a regular sidecar. The caller must // env var to the main container. Idempotent.
// append the returned container to podSpec.Containers. func EnsureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
//
// Used for restore pods where the main container is CRIU-restored and does not
// need GMS sockets at startup. The gms-loader polls for sockets internally.
func BuildServerContainer(podSpec *corev1.PodSpec, mainContainer *corev1.Container) corev1.Container {
ensureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image)
copyDeviceClaims(mainContainer, &sidecar)
return sidecar
}
// FindServerContainer returns a pointer to the GMS server container, checking
// both init containers and regular containers. Returns nil if not present.
func FindServerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
return nil
}
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == ServerContainerName {
return &podSpec.InitContainers[i]
}
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == ServerContainerName {
return &podSpec.Containers[i]
}
}
return nil
}
// ensureSharedVolume adds the shared GMS socket volume, mounts, and env vars.
// Idempotent — may be called by both steady-state and checkpoint paths.
func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
hasVolume := false hasVolume := false
for _, v := range podSpec.Volumes { for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName { if v.Name == SharedVolumeName {
...@@ -119,8 +80,6 @@ func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container ...@@ -119,8 +80,6 @@ func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container
}) })
} }
// Mount and env injection checked independently of volume existence —
// another code path may have added the volume without configuring main.
hasMount := false hasMount := false
for _, m := range mainContainer.VolumeMounts { for _, m := range mainContainer.VolumeMounts {
if m.Name == SharedVolumeName { if m.Name == SharedVolumeName {
...@@ -134,41 +93,31 @@ func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container ...@@ -134,41 +93,31 @@ func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container
hasEnv := false hasEnv := false
for _, e := range mainContainer.Env { for _, e := range mainContainer.Env {
if e.Name == "GMS_SOCKET_DIR" { if e.Name == EnvSocketDir {
hasEnv = true hasEnv = true
break break
} }
} }
if !hasEnv { if !hasEnv {
mainContainer.Env = append(mainContainer.Env, mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{Name: EnvSocketDir, Value: SharedMountPath})
corev1.EnvVar{Name: "TMPDIR", Value: SharedMountPath},
corev1.EnvVar{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
)
} }
} }
// serverContainer builds the base GMS server container without init-specific // Container builds a GMS container with the shared socket volume, env, and
// fields (RestartPolicy, StartupProbe). Callers add those as needed. // DRA claim. Used for the server, loader, and saver.
func serverContainer(image string) corev1.Container { func Container(name, module, image string) corev1.Container {
return corev1.Container{ return corev1.Container{
Name: ServerContainerName, Name: name,
Image: image, Image: image,
Command: []string{"python3", "-m", serverSidecarModule}, Command: []string{"python3", "-m", module},
Env: []corev1.EnvVar{ Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: SharedMountPath}, {Name: EnvSocketDir, Value: SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
}, },
VolumeMounts: []corev1.VolumeMount{ VolumeMounts: []corev1.VolumeMount{
{Name: SharedVolumeName, MountPath: SharedMountPath}, {Name: SharedVolumeName, MountPath: SharedMountPath},
}, },
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: dra.ClaimName}},
},
} }
} }
func copyDeviceClaims(src *corev1.Container, dst *corev1.Container) {
if src == nil || dst == nil || len(src.Resources.Claims) == 0 {
return
}
claims := make([]corev1.ResourceClaim, len(src.Resources.Claims))
copy(claims, src.Resources.Claims)
dst.Resources.Claims = append(dst.Resources.Claims, claims...)
}
/* /*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
*/ */
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -20,71 +21,29 @@ func TestEnsureServerSidecar(t *testing.T) { ...@@ -20,71 +21,29 @@ func TestEnsureServerSidecar(t *testing.T) {
Name: "main", Name: "main",
Image: "test-image:latest", Image: "test-image:latest",
Resources: corev1.ResourceRequirements{ Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}}, Claims: []corev1.ResourceClaim{{Name: dra.ClaimName}},
}, },
}}, }},
} }
EnsureServerSidecar(podSpec, &podSpec.Containers[0]) EnsureServerSidecar(podSpec, &podSpec.Containers[0])
require.Len(t, podSpec.Containers, 1)
require.Len(t, podSpec.InitContainers, 1) require.Len(t, podSpec.InitContainers, 1)
main := &podSpec.Containers[0]
server := &podSpec.InitContainers[0] server := &podSpec.InitContainers[0]
assert.Equal(t, ServerContainerName, server.Name) assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command) assert.Equal(t, []string{"python3", "-m", ServerModule}, server.Command)
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "GMS_SOCKET_DIR"))
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *server.RestartPolicy) assert.Equal(t, corev1.ContainerRestartPolicyAlways, *server.RestartPolicy)
require.NotNil(t, server.StartupProbe) require.NotNil(t, server.StartupProbe)
assert.Equal(t, []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)}, assert.Equal(t, []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
server.StartupProbe.Exec.Command) server.StartupProbe.Exec.Command)
assert.Equal(t, int32(1), server.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), server.StartupProbe.FailureThreshold)
// DRA claim copied from main // DRA claim on server
assert.Len(t, server.Resources.Claims, 1) assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name) assert.Equal(t, dra.ClaimName, server.Resources.Claims[0].Name)
}
func TestBuildServerContainer(t *testing.T) { // Shared volume and env on main
podSpec := &corev1.PodSpec{ assert.Equal(t, SharedMountPath, envValue(t, &podSpec.Containers[0], "GMS_SOCKET_DIR"))
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}},
},
}},
}
server := BuildServerContainer(podSpec, &podSpec.Containers[0])
// Should not be added to init containers
assert.Empty(t, podSpec.InitContainers)
assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command)
// No init-specific fields
assert.Nil(t, server.RestartPolicy)
assert.Nil(t, server.StartupProbe)
// DRA claim copied from main
assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name)
// Shared volume and env should be set on main
main := &podSpec.Containers[0]
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
// Shared volume should exist
var hasVolume bool var hasVolume bool
for _, v := range podSpec.Volumes { for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName { if v.Name == SharedVolumeName {
...@@ -94,54 +53,27 @@ func TestBuildServerContainer(t *testing.T) { ...@@ -94,54 +53,27 @@ func TestBuildServerContainer(t *testing.T) {
assert.True(t, hasVolume) assert.True(t, hasVolume)
} }
func TestEnsureServerSidecarDoesNotAddCheckpointControl(t *testing.T) { func TestEnsureServerSidecarIdempotent(t *testing.T) {
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}}, Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
} }
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
EnsureServerSidecar(podSpec, &podSpec.Containers[0]) EnsureServerSidecar(podSpec, &podSpec.Containers[0])
for _, volume := range podSpec.Volumes { assert.Len(t, podSpec.InitContainers, 1)
if volume.Name == ControlVolumeName {
t.Fatal("runtime shaping should not add checkpoint control volume")
}
}
server := FindServerContainer(podSpec)
require.NotNil(t, server)
for _, env := range server.Env {
if env.Name == "GMS_CONTROL_DIR" {
t.Fatal("server should not have checkpoint control env")
}
}
} }
func TestEnsureServerSidecarIdempotent(t *testing.T) { func TestEnsureServerSidecarDoesNotAddCheckpointControl(t *testing.T) {
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}}, Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
} }
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
EnsureServerSidecar(podSpec, &podSpec.Containers[0]) EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.Len(t, podSpec.InitContainers, 1)
volumeCount := 0
for _, v := range podSpec.Volumes { for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName { if v.Name == "gms-control" {
volumeCount++ t.Fatal("should not add checkpoint control volume")
} }
} }
assert.Equal(t, 1, volumeCount)
}
func TestFindServerContainer(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
assert.Nil(t, FindServerContainer(podSpec))
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.NotNil(t, FindServerContainer(podSpec))
assert.Equal(t, ServerContainerName, FindServerContainer(podSpec).Name)
} }
func envValue(t *testing.T, container *corev1.Container, name string) string { func envValue(t *testing.T, container *corev1.Container, name string) string {
......
...@@ -48,17 +48,13 @@ func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) { ...@@ -48,17 +48,13 @@ func podFromInformerObj(obj interface{}) (*corev1.Pod, bool) {
return pod, ok return pod, ok
} }
// resolveMainContainerName returns the name of the workload container, which
// is always Containers[0]. GMS sidecars are appended after the workload.
func resolveMainContainerName(pod *corev1.Pod) string { func resolveMainContainerName(pod *corev1.Pod) string {
containerName := "" if len(pod.Spec.Containers) == 0 {
for _, c := range pod.Spec.Containers { return ""
if c.Name == "main" {
return c.Name
} }
if containerName == "" { return pod.Spec.Containers[0].Name
containerName = c.Name
}
}
return containerName
} }
func isPodReady(pod *corev1.Pod) bool { func isPodReady(pod *corev1.Pod) bool {
......
...@@ -5,6 +5,7 @@ package protocol ...@@ -5,6 +5,7 @@ package protocol
import ( import (
"fmt" "fmt"
"strings"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
...@@ -56,10 +57,10 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt ...@@ -56,10 +57,10 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile) EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile)
} }
if opts.WrapLaunchJob { if opts.WrapLaunchJob {
container, err := ResolveCheckpointWorkerContainer(&podTemplate.Spec) if len(podTemplate.Spec.Containers) == 0 {
if err != nil { return nil, fmt.Errorf("checkpoint job requires at least one container")
return nil, err
} }
container := &podTemplate.Spec.Containers[0]
if len(container.Command) == 0 { if len(container.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled") return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled")
} }
...@@ -157,6 +158,17 @@ func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) { ...@@ -157,6 +158,17 @@ func EnsureLocalhostSeccompProfile(podSpec *corev1.PodSpec, profile string) {
} }
func wrapWithCudaCheckpointLaunchJob(command []string, args []string) ([]string, []string) { func wrapWithCudaCheckpointLaunchJob(command []string, args []string) ([]string, []string) {
// Unwrap "/bin/sh -c <single-string>" so cuda-checkpoint launches the
// actual process directly. Otherwise sh sits between cuda-checkpoint and
// the real process and swallows SIGUSR1.
if len(command) >= 2 && command[len(command)-1] == "-c" && len(args) == 1 {
shell := command[:len(command)-1] // e.g. ["/bin/sh"] — discarded
_ = shell
parts := strings.Fields(args[0])
command = parts[:1] // e.g. ["python3"]
args = parts[1:] // e.g. ["-m", "dynamo.vllm", "--model", ...]
}
wrappedArgs := make([]string, 0, len(command)+len(args)+1) wrappedArgs := make([]string, 0, len(command)+len(args)+1)
wrappedArgs = append(wrappedArgs, "--launch-job") wrappedArgs = append(wrappedArgs, "--launch-job")
wrappedArgs = append(wrappedArgs, command...) wrappedArgs = append(wrappedArgs, command...)
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package protocol
import (
"fmt"
corev1 "k8s.io/api/core/v1"
)
const checkpointWorkerContainerName = "main"
func ResolveCheckpointWorkerContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
if podSpec == nil || len(podSpec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0], nil
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == checkpointWorkerContainerName {
return &podSpec.Containers[i], nil
}
}
return nil, fmt.Errorf("checkpoint job requires a container named %q when multiple containers are present", checkpointWorkerContainerName)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment