Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package gms
import (
"path/filepath"
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"
)
const (
// ServerContainerName is the name of the GMS server init sidecar.
ServerContainerName = "gms-server"
// SharedVolumeName is the emptyDir volume shared between the GMS server
// sidecar and the main workload container for UDS sockets.
SharedVolumeName = "gms-shared"
// SharedMountPath is the mount path for the shared GMS socket directory.
SharedMountPath = "/shared"
// DRAClaimName is the pod-level DRA ResourceClaim name used by both the
// main container and GMS sidecars.
DRAClaimName = "shared-gpu"
// ControlVolumeName is the checkpoint-specific control volume name.
ControlVolumeName = "gms-control"
// ControlDir is the mount path for the checkpoint control volume.
ControlDir = "/tmp/gms-control"
readyFile = "gms-ready"
serverSidecarModule = "gpu_memory_service.cli.server"
)
// EnsureServerSidecar adds the GMS server as a restartable init sidecar with a
// startup probe. Used for checkpoint jobs and steady-state pods where the main
// container needs GMS sockets before starting.
func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
if podSpec == nil || mainContainer == nil {
return
}
ensureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image)
sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
sidecar.StartupProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
},
},
PeriodSeconds: 1,
FailureThreshold: 300, // 1s * 300 = 5 min
}
copyDeviceClaims(mainContainer, &sidecar)
// Idempotent — EnsureServerSidecar may be called by both the
// steady-state operator path and the checkpoint overlay.
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == sidecar.Name {
return
}
}
podSpec.InitContainers = append(podSpec.InitContainers, sidecar)
}
// BuildServerContainer prepares the shared GMS volume/env and returns a GMS
// server container suitable for use as a regular sidecar. The caller must
// append the returned container to podSpec.Containers.
//
// Used for restore pods where the main container is CRIU-restored and does not
// need GMS sockets at startup. The gms-loader polls for sockets internally.
func BuildServerContainer(podSpec *corev1.PodSpec, mainContainer *corev1.Container) corev1.Container {
ensureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image)
copyDeviceClaims(mainContainer, &sidecar)
return sidecar
}
// FindServerContainer returns a pointer to the GMS server container, checking
// both init containers and regular containers. Returns nil if not present.
func FindServerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
return nil
}
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == ServerContainerName {
return &podSpec.InitContainers[i]
}
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == ServerContainerName {
return &podSpec.Containers[i]
}
}
return nil
}
// ensureSharedVolume adds the shared GMS socket volume, mounts, and env vars.
// Idempotent — may be called by both steady-state and checkpoint paths.
func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
hasVolume := false
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
hasVolume = true
break
}
}
if !hasVolume {
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: SharedVolumeName,
VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
})
}
// Mount and env injection checked independently of volume existence —
// another code path may have added the volume without configuring main.
hasMount := false
for _, m := range mainContainer.VolumeMounts {
if m.Name == SharedVolumeName {
hasMount = true
break
}
}
if !hasMount {
mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, corev1.VolumeMount{Name: SharedVolumeName, MountPath: SharedMountPath})
}
hasEnv := false
for _, e := range mainContainer.Env {
if e.Name == "GMS_SOCKET_DIR" {
hasEnv = true
break
}
}
if !hasEnv {
mainContainer.Env = append(mainContainer.Env,
corev1.EnvVar{Name: "TMPDIR", Value: SharedMountPath},
corev1.EnvVar{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
)
}
}
// serverContainer builds the base GMS server container without init-specific
// fields (RestartPolicy, StartupProbe). Callers add those as needed.
func serverContainer(image string) corev1.Container {
return corev1.Container{
Name: ServerContainerName,
Image: image,
Command: []string{"python3", "-m", serverSidecarModule},
Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
},
VolumeMounts: []corev1.VolumeMount{
{Name: SharedVolumeName, MountPath: SharedMountPath},
},
}
}
func copyDeviceClaims(src *corev1.Container, dst *corev1.Container) {
if src == nil || dst == nil || len(src.Resources.Claims) == 0 {
return
}
claims := make([]corev1.ResourceClaim, len(src.Resources.Claims))
copy(claims, src.Resources.Claims)
dst.Resources.Claims = append(dst.Resources.Claims, claims...)
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package gms
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
)
func TestEnsureServerSidecar(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}},
},
}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
require.Len(t, podSpec.Containers, 1)
require.Len(t, podSpec.InitContainers, 1)
main := &podSpec.Containers[0]
server := &podSpec.InitContainers[0]
assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command)
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "GMS_SOCKET_DIR"))
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *server.RestartPolicy)
require.NotNil(t, server.StartupProbe)
assert.Equal(t, []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
server.StartupProbe.Exec.Command)
assert.Equal(t, int32(1), server.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), server.StartupProbe.FailureThreshold)
// DRA claim copied from main
assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name)
}
func TestBuildServerContainer(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}},
},
}},
}
server := BuildServerContainer(podSpec, &podSpec.Containers[0])
// Should not be added to init containers
assert.Empty(t, podSpec.InitContainers)
assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command)
// No init-specific fields
assert.Nil(t, server.RestartPolicy)
assert.Nil(t, server.StartupProbe)
// DRA claim copied from main
assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name)
// Shared volume and env should be set on main
main := &podSpec.Containers[0]
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
// Shared volume should exist
var hasVolume bool
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
hasVolume = true
}
}
assert.True(t, hasVolume)
}
func TestEnsureServerSidecarDoesNotAddCheckpointControl(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
for _, volume := range podSpec.Volumes {
if volume.Name == ControlVolumeName {
t.Fatal("runtime shaping should not add checkpoint control volume")
}
}
server := FindServerContainer(podSpec)
require.NotNil(t, server)
for _, env := range server.Env {
if env.Name == "GMS_CONTROL_DIR" {
t.Fatal("server should not have checkpoint control env")
}
}
}
func TestEnsureServerSidecarIdempotent(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.Len(t, podSpec.InitContainers, 1)
volumeCount := 0
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
volumeCount++
}
}
assert.Equal(t, 1, volumeCount)
}
func TestFindServerContainer(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
assert.Nil(t, FindServerContainer(podSpec))
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.NotNil(t, FindServerContainer(podSpec))
assert.Equal(t, ServerContainerName, FindServerContainer(podSpec).Name)
}
func envValue(t *testing.T, container *corev1.Container, name string) string {
t.Helper()
for _, env := range container.Env {
if env.Name == name {
return env.Value
}
}
t.Fatalf("env %s not found", name)
return ""
}
...@@ -95,14 +95,30 @@ func loadPod(manifestPath string) (*corev1.Pod, error) { ...@@ -95,14 +95,30 @@ func loadPod(manifestPath string) (*corev1.Pod, error) {
if kind := strings.TrimSpace(pod.Kind); kind != "" && kind != "Pod" { if kind := strings.TrimSpace(pod.Kind); kind != "" && kind != "Pod" {
return nil, fmt.Errorf("manifest %s is kind %q, expected Pod", manifestPath, kind) return nil, fmt.Errorf("manifest %s is kind %q, expected Pod", manifestPath, kind)
} }
if len(pod.Spec.Containers) != 1 { if len(pod.Spec.Containers) == 0 {
return nil, fmt.Errorf( return nil, fmt.Errorf(
"manifest %s has %d containers; snapshotctl requires exactly one worker container", "manifest %s has no worker containers; snapshotctl requires at least one worker container",
manifestPath, manifestPath,
len(pod.Spec.Containers),
) )
} }
if strings.TrimSpace(pod.Spec.Containers[0].Image) == "" { workerContainer := &pod.Spec.Containers[0]
if len(pod.Spec.Containers) > 1 {
workerContainer = nil
for index := range pod.Spec.Containers {
if pod.Spec.Containers[index].Name == "main" {
workerContainer = &pod.Spec.Containers[index]
break
}
}
if workerContainer == nil {
return nil, fmt.Errorf(
"manifest %s has %d containers; snapshotctl requires a worker container named main",
manifestPath,
len(pod.Spec.Containers),
)
}
}
if strings.TrimSpace(workerContainer.Image) == "" {
return nil, fmt.Errorf("manifest %s: worker container image is required", manifestPath) return nil, fmt.Errorf("manifest %s: worker container image is required", manifestPath)
} }
if strings.TrimSpace(pod.Name) == "" { if strings.TrimSpace(pod.Name) == "" {
......
...@@ -56,15 +56,16 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt ...@@ -56,15 +56,16 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile) EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile)
} }
if opts.WrapLaunchJob { if opts.WrapLaunchJob {
if len(podTemplate.Spec.Containers) == 0 { container, err := ResolveCheckpointWorkerContainer(&podTemplate.Spec)
return nil, fmt.Errorf("checkpoint job requires one worker container") if err != nil {
return nil, err
} }
if len(podTemplate.Spec.Containers[0].Command) == 0 { if len(container.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled") return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled")
} }
podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args = wrapWithCudaCheckpointLaunchJob( container.Command, container.Args = wrapWithCudaCheckpointLaunchJob(
podTemplate.Spec.Containers[0].Command, container.Command,
podTemplate.Spec.Containers[0].Args, container.Args,
) )
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package protocol
import (
"fmt"
corev1 "k8s.io/api/core/v1"
)
const checkpointWorkerContainerName = "main"
func ResolveCheckpointWorkerContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
if podSpec == nil || len(podSpec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0], nil
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == checkpointWorkerContainerName {
return &podSpec.Containers[i], nil
}
}
return nil, fmt.Errorf("checkpoint job requires a container named %q when multiple containers are present", checkpointWorkerContainerName)
}
...@@ -11,6 +11,17 @@ import ( ...@@ -11,6 +11,17 @@ import (
"k8s.io/utils/ptr" "k8s.io/utils/ptr"
) )
func requireCheckpointContainer(t *testing.T, containers []corev1.Container, name string) *corev1.Container {
t.Helper()
for i := range containers {
if containers[i].Name == name {
return &containers[i]
}
}
t.Fatalf("container %q not found", name)
return nil
}
func TestNewCheckpointJob(t *testing.T) { func TestNewCheckpointJob(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{ job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
...@@ -87,6 +98,97 @@ func TestNewCheckpointJob(t *testing.T) { ...@@ -87,6 +98,97 @@ func TestNewCheckpointJob(t *testing.T) {
} }
} }
func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "main", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err != nil {
t.Fatalf("expected checkpoint job, got error: %v", err)
}
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "main")
if len(main.Command) != 1 || main.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected main container to be wrapped, got %#v", main.Command)
}
expectedArgs := []string{"--launch-job", "python3", "-m", "dynamo.vllm", "--model", "Qwen"}
if len(main.Args) != len(expectedArgs) {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
}
for i := range expectedArgs {
if main.Args[i] != expectedArgs[i] {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
}
}
sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar")
if len(sidecar.Command) != 1 || sidecar.Command[0] != "sleep" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", sidecar.Command)
}
if len(sidecar.Args) != 1 || sidecar.Args[0] != "infinity" {
t.Fatalf("expected sidecar args to remain unchanged, got %#v", sidecar.Args)
}
}
func TestNewCheckpointJobAllowsSingleNonMainContainer(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "worker",
Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"},
}},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err != nil {
t.Fatalf("expected checkpoint job, got error: %v", err)
}
container := &job.Spec.Template.Spec.Containers[0]
if len(container.Command) != 1 || container.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected single container to be wrapped, got %#v", container.Command)
}
}
func TestNewCheckpointJobRejectsMultiContainerPodWithoutMain(t *testing.T) {
_, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err == nil || err.Error() != "checkpoint job requires a container named \"main\" when multiple containers are present" {
t.Fatalf("expected missing main container error, got %v", err)
}
}
func TestGetCheckpointJobName(t *testing.T) { func TestGetCheckpointJobName(t *testing.T) {
name := GetCheckpointJobName("abc123def4567890", "2") name := GetCheckpointJobName("abc123def4567890", "2")
if name != "checkpoint-job-abc123def4567890-2" { if name != "checkpoint-job-abc123def4567890-2" {
......
...@@ -39,12 +39,31 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod { ...@@ -39,12 +39,31 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod {
pod.Annotations = map[string]string{} pod.Annotations = map[string]string{}
} }
ApplyRestoreTargetMetadata(pod.Labels, pod.Annotations, true, opts.CheckpointID, opts.ArtifactVersion) ApplyRestoreTargetMetadata(pod.Labels, pod.Annotations, true, opts.CheckpointID, opts.ArtifactVersion)
PrepareRestorePodSpec(&pod.Spec, &pod.Spec.Containers[0], opts.Storage, opts.SeccompProfile, true) container := resolveWorkerContainer(&pod.Spec)
if container == nil {
return nil
}
PrepareRestorePodSpec(&pod.Spec, container, opts.Storage, opts.SeccompProfile, true)
pod.Namespace = opts.Namespace pod.Namespace = opts.Namespace
pod.Spec.RestartPolicy = corev1.RestartPolicyNever pod.Spec.RestartPolicy = corev1.RestartPolicyNever
return pod return pod
} }
func resolveWorkerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
return nil
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0]
}
for index := range podSpec.Containers {
if podSpec.Containers[index].Name == "main" {
return &podSpec.Containers[index]
}
}
return nil
}
func PrepareRestorePodSpec( func PrepareRestorePodSpec(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
container *corev1.Container, container *corev1.Container,
...@@ -92,10 +111,10 @@ func ValidateRestorePodSpec( ...@@ -92,10 +111,10 @@ func ValidateRestorePodSpec(
if podSpec == nil { if podSpec == nil {
return fmt.Errorf("pod spec is nil") return fmt.Errorf("pod spec is nil")
} }
if len(podSpec.Containers) != 1 { container := resolveWorkerContainer(podSpec)
return fmt.Errorf("restore target must have exactly one container, got %d", len(podSpec.Containers)) if container == nil {
return fmt.Errorf("restore target must include a worker container named main")
} }
container := &podSpec.Containers[0]
if storage.PVCName != "" { if storage.PVCName != "" {
hasVolume := false hasVolume := false
for _, volume := range podSpec.Volumes { for _, volume := range podSpec.Volumes {
......
...@@ -187,6 +187,38 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T) ...@@ -187,6 +187,38 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T)
} }
} }
func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) {
restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "worker"},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
{Name: "main", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
},
},
}, PodOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Storage: Storage{
Type: StorageTypePVC,
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
},
SeccompProfile: DefaultSeccompLocalhostProfile,
})
if got := restorePod.Spec.Containers[0].Command; len(got) != 1 || got[0] != "sidecar" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
}
if got := restorePod.Spec.Containers[1].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
t.Fatalf("expected main container placeholder command, got %#v", got)
}
if restorePod.Spec.Containers[1].Args != nil {
t.Fatalf("expected main container args to be cleared: %#v", restorePod.Spec.Containers[1].Args)
}
}
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) { func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) {
podSpec := corev1.PodSpec{} podSpec := corev1.PodSpec{}
readinessProbe := &corev1.Probe{ readinessProbe := &corev1.Probe{
...@@ -279,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) { ...@@ -279,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) {
} }
} }
func TestValidateRestorePodSpecRequiresExactlyOneContainer(t *testing.T) { func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testing.T) {
profile := DefaultSeccompLocalhostProfile profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{ podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{ SecurityContext: &corev1.PodSecurityContext{
...@@ -314,8 +346,48 @@ func TestValidateRestorePodSpecRequiresExactlyOneContainer(t *testing.T) { ...@@ -314,8 +346,48 @@ func TestValidateRestorePodSpecRequiresExactlyOneContainer(t *testing.T) {
BasePath: "/checkpoints", BasePath: "/checkpoints",
} }
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must have exactly one container, got 2" { if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must include a worker container named main" {
t.Fatalf("expected multi-container restore target to be rejected, got %v", err) t.Fatalf("expected multi-container restore target without main to be rejected, got %v", err)
}
}
func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{
SeccompProfile: &corev1.SeccompProfile{
Type: corev1.SeccompProfileTypeLocalhost,
LocalhostProfile: &profile,
},
},
Volumes: []corev1.Volume{{
Name: CheckpointVolumeName,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "snapshot-pvc",
},
},
}},
Containers: []corev1.Container{
{Name: "sidecar"},
{
Name: "main",
VolumeMounts: []corev1.VolumeMount{{
Name: CheckpointVolumeName,
MountPath: "/checkpoints",
}},
},
},
}
storage := Storage{
Type: StorageTypePVC,
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
}
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected main container with sidecars to validate, got %v", err)
} }
} }
......
...@@ -302,6 +302,7 @@ _Appears in:_ ...@@ -302,6 +302,7 @@ _Appears in:_
| Field | Description | Default | Validation | | Field | Description | Default | Validation |
| --- | --- | --- | --- | | --- | --- | --- | --- |
| `identity` _[DynamoCheckpointIdentity](#dynamocheckpointidentity)_ | Identity defines the inputs that determine checkpoint equivalence | | Required: \{\} <br /> | | `identity` _[DynamoCheckpointIdentity](#dynamocheckpointidentity)_ | Identity defines the inputs that determine checkpoint equivalence | | Required: \{\} <br /> |
| `gpuMemoryService` _[GPUMemoryServiceSpec](#gpumemoryservicespec)_ | GPUMemoryService enables checkpoint-time GPU Memory Service wiring.<br />It is intentionally outside spec.identity, so it does not affect the<br />checkpoint identity hash or deduplication. | | Optional: \{\} <br /> |
| `job` _[DynamoCheckpointJobConfig](#dynamocheckpointjobconfig)_ | Job defines the configuration for the checkpoint creation Job | | Required: \{\} <br /> | | `job` _[DynamoCheckpointJobConfig](#dynamocheckpointjobconfig)_ | Job defines the configuration for the checkpoint creation Job | | Required: \{\} <br /> |
...@@ -852,6 +853,7 @@ via DRA (Dynamic Resource Allocation). The sidecar runs two GMS processes per GP ...@@ -852,6 +853,7 @@ via DRA (Dynamic Resource Allocation). The sidecar runs two GMS processes per GP
_Appears in:_ _Appears in:_
- [DynamoCheckpointSpec](#dynamocheckpointspec)
- [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec) - [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec)
- [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec) - [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec)
......
...@@ -142,6 +142,16 @@ spec: ...@@ -142,6 +142,16 @@ spec:
... ...
``` ```
If this checkpoint should capture and restore GPU Memory Service helpers, set:
```yaml
spec:
gpuMemoryService:
enabled: true
```
`spec.gpuMemoryService` is outside `spec.identity`, so it does not change the checkpoint identity hash.
For a full working example, see [deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml](https://github.com/ai-dynamo/dynamo/blob/main/deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml). For a full working example, see [deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml](https://github.com/ai-dynamo/dynamo/blob/main/deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml).
Apply it: Apply it:
...@@ -262,6 +272,8 @@ spec: ...@@ -262,6 +272,8 @@ spec:
... ...
``` ```
Auto mode only hashes `checkpoint.identity`. If you need GMS-specific checkpoint behavior, configure it on the `DynamoCheckpoint` object with `spec.gpuMemoryService.enabled`.
Useful inspection commands: Useful inspection commands:
```bash ```bash
......
...@@ -4,7 +4,13 @@ ...@@ -4,7 +4,13 @@
"""CLI for GPU Memory Service.""" """CLI for GPU Memory Service."""
from gpu_memory_service.cli.args import Config, parse_args from gpu_memory_service.cli.args import Config, parse_args
from gpu_memory_service.cli.runner import main
def main():
from gpu_memory_service.cli.runner import main as runner_main
return runner_main()
__all__ = [ __all__ = [
"Config", "Config",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS server entry point.
Launches two GMS server processes per GPU (one for weights, one for kv_cache).
Writes a ready file once all expected UDS sockets are present. Monitors an
optional checkpoint stop file and shuts down cleanly when it appears.
"""
from __future__ import annotations
import logging
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from gpu_memory_service.common.utils import get_socket_path
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_TAGS = ("weights", "kv_cache")
_READY_FILE = "gms-ready"
def _ready_file_path() -> Path:
return Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _optional_checkpoint_stop_file() -> Path | None:
control_dir = os.environ.get("GMS_CONTROL_DIR")
if not control_dir:
return None
return Path(control_dir) / "checkpoint-done"
def main() -> None:
ready_file = _ready_file_path()
ready_file.unlink(missing_ok=True)
devices = _list_devices()
processes = []
for device in devices:
for tag in _TAGS:
proc = subprocess.Popen(
[
sys.executable,
"-m",
"gpu_memory_service",
"--device",
str(device),
"--tag",
tag,
]
)
logger.info("Started GMS device=%d tag=%s pid=%d", device, tag, proc.pid)
processes.append(proc)
def shutdown() -> None:
for process in processes:
if process.poll() is None:
process.terminate()
def terminate(*_args) -> None:
shutdown()
raise SystemExit(0)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, terminate)
stop_file = _optional_checkpoint_stop_file()
ready_written = False
while True:
stop_requested = stop_file is not None and stop_file.exists()
if stop_requested:
logger.info("checkpoint stop requested; shutting down GMS servers")
shutdown()
if not ready_written:
sockets_ready = all(
os.path.exists(get_socket_path(device, tag))
for device in devices
for tag in _TAGS
)
if sockets_ready:
ready_file.write_text("ready", encoding="utf-8")
ready_written = True
running = False
for process in processes:
exit_code = process.poll()
if exit_code is None:
running = True
continue
if stop_requested:
continue
shutdown()
raise SystemExit(exit_code)
if not running:
return
time.sleep(1)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS checkpoint loader entry point.
Waits for the GMS server UDS socket on each device, then loads saved GMS
state from a checkpoint directory into the running GMS servers. Devices
are loaded in parallel to saturate PVC bandwidth.
"""
from __future__ import annotations
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_DEFAULT_MAX_WORKERS = 8
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
input_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir)
t0 = time.monotonic()
client = GMSStorageClient(
socket_path=get_socket_path(device),
device=device,
)
client.load_to_gms(
input_dir,
max_workers=max_workers,
clear_existing=True,
)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint loaded: device=%d elapsed=%.2fs", device, elapsed)
def main() -> None:
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_LOAD_WORKERS", str(_DEFAULT_MAX_WORKERS)))
devices = _list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_load_device, checkpoint_dir, dev, max_workers): dev
for dev in devices
}
for future in as_completed(futures):
dev = futures[future]
future.result()
logger.info("Device %d load complete", dev)
elapsed = time.monotonic() - t0
logger.info("All %d devices loaded in %.2fs", len(devices), elapsed)
while True:
time.sleep(3600)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS checkpoint saver entry point.
Waits for the checkpoint pod to reach Ready=True, then saves GMS state from
each device in parallel. Writes a stop file to signal the GMS server to shut
down after save completes.
"""
from __future__ import annotations
import json
import logging
import os
import ssl
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_SERVICE_ACCOUNT_TOKEN = Path("/var/run/secrets/kubernetes.io/serviceaccount/token")
_SERVICE_ACCOUNT_CA = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _checkpoint_pod_ready(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
if str(status.get("phase", "")).strip() != "Running":
return False
for condition in status.get("conditions") or []:
if (
condition.get("type") == "Ready"
and str(condition.get("status", "")).strip() == "True"
):
return True
return False
def _main_terminated(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
for container in status.get("containerStatuses") or []:
if container.get("name") != "main":
continue
return bool((container.get("state") or {}).get("terminated"))
return False
def main() -> None:
service_token = _SERVICE_ACCOUNT_TOKEN.read_text(encoding="utf-8").strip()
ssl_context = ssl.create_default_context(cafile=_SERVICE_ACCOUNT_CA)
pod_api_url = (
"https://"
+ os.environ["KUBERNETES_SERVICE_HOST"]
+ ":"
+ os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS", "443")
+ f"/api/v1/namespaces/{os.environ['POD_NAMESPACE']}/pods/{os.environ['POD_NAME']}"
)
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
def checkpoint_pod() -> dict[str, Any]:
request = urllib.request.Request(
pod_api_url,
headers={"Authorization": f"Bearer {service_token}"},
)
with urllib.request.urlopen(
request,
context=ssl_context,
timeout=5,
) as response:
return json.load(response)
logger.info("Waiting for checkpoint pod Ready=True before GMS save")
while True:
try:
pod = checkpoint_pod()
except (urllib.error.URLError, TimeoutError, ssl.SSLError, OSError):
time.sleep(1)
continue
if _checkpoint_pod_ready(pod):
break
if _main_terminated(pod):
raise SystemExit("main container terminated before GMS save could start")
time.sleep(1)
def _save_device(device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
output_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info(
"Saving GMS checkpoint: device=%d output_dir=%s",
device,
output_dir,
)
t0 = time.monotonic()
client = GMSStorageClient(
output_dir,
socket_path=get_socket_path(device),
device=device,
)
client.save(max_workers=max_workers)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)
max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
logger.info("Checkpoint pod is Ready; starting GMS save")
try:
devices = _list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_save_device, dev, max_workers): dev for dev in devices
}
for future in as_completed(futures):
future.result()
elapsed = time.monotonic() - t0
logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
finally:
(Path(os.environ["GMS_CONTROL_DIR"]) / "checkpoint-done").write_text(
"done",
encoding="utf-8",
)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS Storage Client CLI.
Provides two subcommands for saving and loading GPU Memory Service state:
* ``save`` – connect to a running GMS server in RO mode and write every
allocation plus all metadata to a sharded binary directory.
* ``load`` – connect to a running GMS server in RW mode, read tensor data
from a saved directory, and commit the state so readers can
acquire the RO lock.
Usage examples::
# Save GMS state to disk
gms-storage-client save --output-dir /mnt/nvme/save --device 0
# Load a previous save back into a fresh GMS server
gms-storage-client load --input-dir /mnt/nvme/save --device 0
"""
import argparse
import logging
import sys
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _configure_logging(verbose: bool) -> None:
if verbose:
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
def _resolve_socket(device: int, socket_path) -> str:
if socket_path is not None:
return socket_path
from gpu_memory_service.common.utils import get_socket_path
return get_socket_path(device)
# ---------------------------------------------------------------------------
# Subcommand implementations
# ---------------------------------------------------------------------------
def _run_save(args) -> None:
"""Execute the save subcommand."""
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
_configure_logging(args.verbose)
socket_path = _resolve_socket(args.device, args.socket_path)
logger.info(
"Saving GMS state: device=%s, socket=%s, output_dir=%s, save_workers=%s",
args.device,
socket_path,
args.output_dir,
args.save_workers,
)
client = GMSStorageClient(
args.output_dir,
socket_path=socket_path,
device=args.device,
timeout_ms=args.timeout_ms,
shard_size_bytes=args.shard_size_bytes,
)
manifest = client.save(max_workers=args.save_workers)
shard_count = len({a.tensor_file for a in manifest.allocations})
logger.info(
"Save complete: %s allocations written to %s (%s shards)",
len(manifest.allocations),
args.output_dir,
shard_count,
)
logger.info("Layout hash: %s", manifest.layout_hash)
def _run_load(args) -> None:
"""Execute the load subcommand."""
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
_configure_logging(args.verbose)
socket_path = _resolve_socket(args.device, args.socket_path)
logger.info(
"Loading GMS state: device=%s, socket=%s, input_dir=%s, clear_existing=%s",
args.device,
socket_path,
args.input_dir,
not args.no_clear,
)
client = GMSStorageClient(
socket_path=socket_path,
device=args.device,
timeout_ms=args.timeout_ms,
)
id_map = client.load_to_gms(
args.input_dir,
max_workers=args.workers,
clear_existing=not args.no_clear,
)
logger.info("Load complete: %s allocations committed to GMS", len(id_map))
for old_id, new_id in id_map.items():
logger.info(" %s → %s", old_id, new_id)
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
_SHARD_SIZE_DEFAULT = 4 * 1024**3 # 4 GiB
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="gms-storage-client",
description="Save and load GPU Memory Service state to/from disk.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="subcommand")
# -- save ---------------------------------------------------------------
save_p = subparsers.add_parser(
"save",
help="Save GMS state to a sharded binary directory.",
description=(
"Connect to a running GMS server in RO mode and export every "
"allocation plus all metadata to a compact sharded binary format."
),
)
save_p.add_argument(
"--output-dir",
required=True,
help="Directory to write into (created if absent).",
)
save_p.add_argument(
"--device",
type=int,
default=0,
help="CUDA device index (default: 0).",
)
save_p.add_argument(
"--socket-path",
type=str,
default=None,
help="GMS Unix socket path. Default uses GPU UUID-based path.",
)
save_p.add_argument(
"--timeout-ms",
type=int,
default=None,
help="Timeout in milliseconds for acquiring the RO lock.",
)
save_p.add_argument(
"--shard-size-bytes",
type=int,
default=_SHARD_SIZE_DEFAULT,
help=(
f"Soft upper bound per shard file in bytes "
f"(default: {_SHARD_SIZE_DEFAULT // 1024**3} GiB). "
"Decrease to increase parallelism on save/load; increase to "
"reduce file count."
),
)
save_p.add_argument(
"--save-workers",
type=int,
default=8,
help="Thread pool size for parallel shard writes (default: 8).",
)
save_p.add_argument(
"--verbose",
"-v",
action="store_true",
help="Enable verbose logging.",
)
# -- load ---------------------------------------------------------------
load_p = subparsers.add_parser(
"load",
help="Load a saved GMS state back into a running GMS server.",
description=(
"Connect to a running GMS server in RW mode, read tensor data "
"from a saved directory (reading each shard file sequentially), "
"and commit the state so readers can acquire the RO lock."
),
)
load_p.add_argument(
"--input-dir",
required=True,
help="Directory previously created by the save subcommand.",
)
load_p.add_argument(
"--device",
type=int,
default=0,
help="CUDA device index (default: 0).",
)
load_p.add_argument(
"--socket-path",
type=str,
default=None,
help="GMS Unix socket path. Default uses GPU UUID-based path.",
)
load_p.add_argument(
"--timeout-ms",
type=int,
default=None,
help="Timeout in milliseconds for acquiring the RW lock.",
)
load_p.add_argument(
"--workers",
type=int,
default=8,
help="Thread pool size for parallel shard reads (default: 8).",
)
load_p.add_argument(
"--no-clear",
action="store_true",
default=False,
help=(
"Do not clear existing GMS allocations before loading. "
"Default behaviour clears the server to produce an exact replica."
),
)
load_p.add_argument(
"--verbose",
"-v",
action="store_true",
help="Enable verbose logging.",
)
return parser
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
"""Entry point for the GMS Storage Client CLI."""
parser = _build_parser()
args = parser.parse_args()
if args.subcommand is None:
parser.print_help()
sys.exit(1)
if args.subcommand == "save":
_run_save(args)
elif args.subcommand == "load":
_run_load(args)
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
main()
...@@ -37,7 +37,7 @@ from typing import Dict, List, Optional ...@@ -37,7 +37,7 @@ from typing import Dict, List, Optional
from gpu_memory_service.client.session import _GMSClientSession from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.cuda_utils import ( from gpu_memory_service.common.cuda_utils import (
align_to_granularity, align_to_granularity,
cuda_set_current_device, cuda_ensure_initialized,
cuda_synchronize, cuda_synchronize,
cuda_validate_pointer, cuda_validate_pointer,
cumem_address_free, cumem_address_free,
...@@ -134,22 +134,25 @@ class GMSClientMemoryManager: ...@@ -134,22 +134,25 @@ class GMSClientMemoryManager:
socket_path: str, socket_path: str,
*, *,
device: int = 0, device: int = 0,
tag: Optional[str] = None,
) -> None: ) -> None:
self.socket_path = socket_path self.socket_path = socket_path
self.device = device self.device = device
self.tag = tag
self._client: Optional[_GMSClientSession] = None self._client: Optional[_GMSClientSession] = None
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._inverse_mapping: Dict[str, int] = {} self._inverse_mapping: Dict[str, int] = {}
self._unmapped = False self._unmapped = False
self._aborted = False
self._granted_lock_type: Optional[GrantedLockType] = None self._granted_lock_type: Optional[GrantedLockType] = None
# VA-stable unmap/remap state # VA-stable unmap/remap state
self._va_preserved = False self._va_preserved = False
self._last_memory_layout_hash: str = "" self._last_memory_layout_hash: str = ""
cuda_set_current_device(self.device) cuda_ensure_initialized()
self.granularity = cumem_get_allocation_granularity(device) self.granularity = cumem_get_allocation_granularity(device)
# ==================== Properties ==================== # ==================== Properties ====================
...@@ -183,9 +186,34 @@ class GMSClientMemoryManager: ...@@ -183,9 +186,34 @@ class GMSClientMemoryManager:
Updates self._granted_lock_type based on granted lock type. Saves memory layout hash Updates self._granted_lock_type based on granted lock type. Saves memory layout hash
for stale detection if server is in committed state. for stale detection if server is in committed state.
On reconnect after abort (e.g. after CRIU restore on a different GPU),
refreshes the socket path from the current GPU UUID so we connect to
the correct GMS server.
""" """
if self._client is not None: if self._client is not None:
raise RuntimeError("Memory manager is already connected") raise RuntimeError("Memory manager is already connected")
# After abort + CRIU restore the process may be on a different GPU.
# Re-derive socket path from current UUID so we talk to the right server.
if self._aborted and self.tag is not None:
from gpu_memory_service.common.utils import (
get_socket_path,
invalidate_uuid_cache,
)
invalidate_uuid_cache()
new_path = get_socket_path(self.device, self.tag)
if new_path != self.socket_path:
logger.info(
"Refreshed socket path for tag=%s: %s -> %s",
self.tag,
self.socket_path,
new_path,
)
self.socket_path = new_path
self._aborted = False
self._client = _GMSClientSession( self._client = _GMSClientSession(
self.socket_path, self.socket_path,
lock_type=lock_type, lock_type=lock_type,
...@@ -211,6 +239,7 @@ class GMSClientMemoryManager: ...@@ -211,6 +239,7 @@ class GMSClientMemoryManager:
Clean callers should unmap first. This also supports abrupt session Clean callers should unmap first. This also supports abrupt session
drop with live mappings still present. drop with live mappings still present.
""" """
self._aborted = True
if self._client is not None: if self._client is not None:
try: try:
self._client.close() self._client.close()
...@@ -464,8 +493,6 @@ class GMSClientMemoryManager: ...@@ -464,8 +493,6 @@ class GMSClientMemoryManager:
Checks layout hash for staleness. Validates each allocation still Checks layout hash for staleness. Validates each allocation still
exists and size matches before remapping. exists and size matches before remapping.
""" """
cuda_set_current_device(self.device)
# Stale layout check # Stale layout check
current_hash = self.get_memory_layout_hash() current_hash = self.get_memory_layout_hash()
if ( if (
...@@ -580,18 +607,30 @@ class GMSClientMemoryManager: ...@@ -580,18 +607,30 @@ class GMSClientMemoryManager:
# ==================== Lifecycle ==================== # ==================== Lifecycle ====================
def close(self) -> None: def close(self, *, best_effort: bool = False) -> None:
"""Strict cleanup. """Cleanup mappings and abort.
synchronize + unmap all + free all VAs + abort. synchronize + unmap all + free all VAs + abort.
"""
cuda_synchronize()
for va in list(self._mappings.keys()):
self.unmap_va(va)
self.free_va(va)
self.abort() Args:
best_effort: If True, skip cuda_synchronize and swallow
errors during cleanup. Used after checkpoint where
cuda-checkpoint may have torn down the device context
(cuda_synchronize calls os._exit via fail()).
"""
if best_effort:
try:
self.abort()
except Exception:
pass
self._mappings.clear()
self._inverse_mapping.clear()
else:
cuda_synchronize()
for va in list(self._mappings.keys()):
self.unmap_va(va)
self.free_va(va)
self.abort()
self._unmapped = False self._unmapped = False
self._va_preserved = False self._va_preserved = False
from gpu_memory_service.client.torch.allocator import ( from gpu_memory_service.client.torch.allocator import (
......
...@@ -127,7 +127,7 @@ def get_or_create_gms_client_memory_manager( ...@@ -127,7 +127,7 @@ def get_or_create_gms_client_memory_manager(
) )
return state.manager return state.manager
manager = GMSClientMemoryManager(socket_path, device=device) manager = GMSClientMemoryManager(socket_path, device=device, tag=tag)
manager.connect(mode, timeout_ms=timeout_ms) manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None mem_pool = None
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
from __future__ import annotations from __future__ import annotations
import atexit
import os import os
from gpu_memory_service.common.locks import GrantedLockType from gpu_memory_service.common.locks import GrantedLockType
...@@ -24,9 +23,6 @@ except ImportError: ...@@ -24,9 +23,6 @@ except ImportError:
cuda = _MissingCuda() cuda = _MissingCuda()
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None: def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS: if result != cuda.CUresult.CUDA_SUCCESS:
...@@ -169,28 +165,3 @@ def cuda_validate_pointer(va: int) -> None: ...@@ -169,28 +165,3 @@ def cuda_validate_pointer(va: int) -> None:
def cuda_synchronize() -> None: def cuda_synchronize() -> None:
(result,) = cuda.cuCtxSynchronize() (result,) = cuda.cuCtxSynchronize()
cuda_check_result(result, "cuCtxSynchronize") cuda_check_result(result, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
global _primary_context_release_registered
ctx = _primary_contexts.get(device)
if ctx is None:
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
cuda_check_result(result, "cuDevicePrimaryCtxRetain")
_primary_contexts[device] = ctx
if not _primary_context_release_registered:
_primary_context_release_registered = True
atexit.register(_release_primary_contexts)
(result,) = cuda.cuCtxSetCurrent(ctx)
cuda_check_result(result, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
for device in list(_primary_contexts):
try:
(result,) = cuda.cuDevicePrimaryCtxRelease(device)
except Exception:
continue
if result == cuda.CUresult.CUDA_SUCCESS:
_primary_contexts.pop(device, None)
...@@ -17,11 +17,19 @@ def fail(message: str, *args, exc_info=None) -> NoReturn: ...@@ -17,11 +17,19 @@ def fail(message: str, *args, exc_info=None) -> NoReturn:
os._exit(1) os._exit(1)
_uuid_cache: dict[int, str] = {}
def invalidate_uuid_cache() -> None:
"""Clear cached GPU UUIDs. Call after CRIU restore when GPU assignment may change."""
_uuid_cache.clear()
def get_socket_path(device: int, tag: str = "weights") -> str: def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag. """Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations. CUDA_VISIBLE_DEVICES configurations. UUIDs are cached per device index.
Args: Args:
device: CUDA device index. device: CUDA device index.
...@@ -30,12 +38,16 @@ def get_socket_path(device: int, tag: str = "weights") -> str: ...@@ -30,12 +38,16 @@ def get_socket_path(device: int, tag: str = "weights") -> str:
Socket path Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock"). (e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
""" """
import pynvml uuid = _uuid_cache.get(device)
if uuid is None:
pynvml.nvmlInit() import pynvml # deferred: not available in all environments
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(device) pynvml.nvmlInit()
uuid = pynvml.nvmlDeviceGetUUID(handle) try:
finally: handle = pynvml.nvmlDeviceGetHandleByIndex(device)
pynvml.nvmlShutdown() uuid = pynvml.nvmlDeviceGetUUID(handle)
return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock") finally:
pynvml.nvmlShutdown()
_uuid_cache[device] = uuid
socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir()
return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment