"lib/vscode:/vscode.git/clone" did not exist on "8bd37c96d6899b321730c2433c12fe5d1748b654"
Unverified Commit f3b181a9 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(gms): operator-managed GMS checkpoint/restore support (#8153)

parent 091cdb51
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package gms
import (
"path/filepath"
corev1 "k8s.io/api/core/v1"
"k8s.io/utils/ptr"
)
const (
// ServerContainerName is the name of the GMS server init sidecar.
ServerContainerName = "gms-server"
// SharedVolumeName is the emptyDir volume shared between the GMS server
// sidecar and the main workload container for UDS sockets.
SharedVolumeName = "gms-shared"
// SharedMountPath is the mount path for the shared GMS socket directory.
SharedMountPath = "/shared"
// DRAClaimName is the pod-level DRA ResourceClaim name used by both the
// main container and GMS sidecars.
DRAClaimName = "shared-gpu"
// ControlVolumeName is the checkpoint-specific control volume name.
ControlVolumeName = "gms-control"
// ControlDir is the mount path for the checkpoint control volume.
ControlDir = "/tmp/gms-control"
readyFile = "gms-ready"
serverSidecarModule = "gpu_memory_service.cli.server"
)
// EnsureServerSidecar adds the GMS server as a restartable init sidecar with a
// startup probe. Used for checkpoint jobs and steady-state pods where the main
// container needs GMS sockets before starting.
func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
if podSpec == nil || mainContainer == nil {
return
}
ensureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image)
sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
sidecar.StartupProbe = &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{
Command: []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
},
},
PeriodSeconds: 1,
FailureThreshold: 300, // 1s * 300 = 5 min
}
copyDeviceClaims(mainContainer, &sidecar)
// Idempotent — EnsureServerSidecar may be called by both the
// steady-state operator path and the checkpoint overlay.
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == sidecar.Name {
return
}
}
podSpec.InitContainers = append(podSpec.InitContainers, sidecar)
}
// BuildServerContainer prepares the shared GMS volume/env and returns a GMS
// server container suitable for use as a regular sidecar. The caller must
// append the returned container to podSpec.Containers.
//
// Used for restore pods where the main container is CRIU-restored and does not
// need GMS sockets at startup. The gms-loader polls for sockets internally.
func BuildServerContainer(podSpec *corev1.PodSpec, mainContainer *corev1.Container) corev1.Container {
ensureSharedVolume(podSpec, mainContainer)
sidecar := serverContainer(mainContainer.Image)
copyDeviceClaims(mainContainer, &sidecar)
return sidecar
}
// FindServerContainer returns a pointer to the GMS server container, checking
// both init containers and regular containers. Returns nil if not present.
func FindServerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
return nil
}
for i := range podSpec.InitContainers {
if podSpec.InitContainers[i].Name == ServerContainerName {
return &podSpec.InitContainers[i]
}
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == ServerContainerName {
return &podSpec.Containers[i]
}
}
return nil
}
// ensureSharedVolume adds the shared GMS socket volume, mounts, and env vars.
// Idempotent — may be called by both steady-state and checkpoint paths.
func ensureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
hasVolume := false
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
hasVolume = true
break
}
}
if !hasVolume {
podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
Name: SharedVolumeName,
VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
})
}
// Mount and env injection checked independently of volume existence —
// another code path may have added the volume without configuring main.
hasMount := false
for _, m := range mainContainer.VolumeMounts {
if m.Name == SharedVolumeName {
hasMount = true
break
}
}
if !hasMount {
mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, corev1.VolumeMount{Name: SharedVolumeName, MountPath: SharedMountPath})
}
hasEnv := false
for _, e := range mainContainer.Env {
if e.Name == "GMS_SOCKET_DIR" {
hasEnv = true
break
}
}
if !hasEnv {
mainContainer.Env = append(mainContainer.Env,
corev1.EnvVar{Name: "TMPDIR", Value: SharedMountPath},
corev1.EnvVar{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
)
}
}
// serverContainer builds the base GMS server container without init-specific
// fields (RestartPolicy, StartupProbe). Callers add those as needed.
func serverContainer(image string) corev1.Container {
return corev1.Container{
Name: ServerContainerName,
Image: image,
Command: []string{"python3", "-m", serverSidecarModule},
Env: []corev1.EnvVar{
{Name: "TMPDIR", Value: SharedMountPath},
{Name: "GMS_SOCKET_DIR", Value: SharedMountPath},
},
VolumeMounts: []corev1.VolumeMount{
{Name: SharedVolumeName, MountPath: SharedMountPath},
},
}
}
func copyDeviceClaims(src *corev1.Container, dst *corev1.Container) {
if src == nil || dst == nil || len(src.Resources.Claims) == 0 {
return
}
claims := make([]corev1.ResourceClaim, len(src.Resources.Claims))
copy(claims, src.Resources.Claims)
dst.Resources.Claims = append(dst.Resources.Claims, claims...)
}
/*
* SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*/
package gms
import (
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
corev1 "k8s.io/api/core/v1"
)
func TestEnsureServerSidecar(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}},
},
}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
require.Len(t, podSpec.Containers, 1)
require.Len(t, podSpec.InitContainers, 1)
main := &podSpec.Containers[0]
server := &podSpec.InitContainers[0]
assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command)
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, server, "GMS_SOCKET_DIR"))
assert.Equal(t, corev1.ContainerRestartPolicyAlways, *server.RestartPolicy)
require.NotNil(t, server.StartupProbe)
assert.Equal(t, []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
server.StartupProbe.Exec.Command)
assert.Equal(t, int32(1), server.StartupProbe.PeriodSeconds)
assert.Equal(t, int32(300), server.StartupProbe.FailureThreshold)
// DRA claim copied from main
assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name)
}
func TestBuildServerContainer(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{
Name: "main",
Image: "test-image:latest",
Resources: corev1.ResourceRequirements{
Claims: []corev1.ResourceClaim{{Name: DRAClaimName}},
},
}},
}
server := BuildServerContainer(podSpec, &podSpec.Containers[0])
// Should not be added to init containers
assert.Empty(t, podSpec.InitContainers)
assert.Equal(t, ServerContainerName, server.Name)
assert.Equal(t, []string{"python3", "-m", serverSidecarModule}, server.Command)
// No init-specific fields
assert.Nil(t, server.RestartPolicy)
assert.Nil(t, server.StartupProbe)
// DRA claim copied from main
assert.Len(t, server.Resources.Claims, 1)
assert.Equal(t, DRAClaimName, server.Resources.Claims[0].Name)
// Shared volume and env should be set on main
main := &podSpec.Containers[0]
assert.Equal(t, SharedMountPath, envValue(t, main, "TMPDIR"))
assert.Equal(t, SharedMountPath, envValue(t, main, "GMS_SOCKET_DIR"))
// Shared volume should exist
var hasVolume bool
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
hasVolume = true
}
}
assert.True(t, hasVolume)
}
func TestEnsureServerSidecarDoesNotAddCheckpointControl(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
for _, volume := range podSpec.Volumes {
if volume.Name == ControlVolumeName {
t.Fatal("runtime shaping should not add checkpoint control volume")
}
}
server := FindServerContainer(podSpec)
require.NotNil(t, server)
for _, env := range server.Env {
if env.Name == "GMS_CONTROL_DIR" {
t.Fatal("server should not have checkpoint control env")
}
}
}
func TestEnsureServerSidecarIdempotent(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.Len(t, podSpec.InitContainers, 1)
volumeCount := 0
for _, v := range podSpec.Volumes {
if v.Name == SharedVolumeName {
volumeCount++
}
}
assert.Equal(t, 1, volumeCount)
}
func TestFindServerContainer(t *testing.T) {
podSpec := &corev1.PodSpec{
Containers: []corev1.Container{{Name: "main", Image: "test:latest"}},
}
assert.Nil(t, FindServerContainer(podSpec))
EnsureServerSidecar(podSpec, &podSpec.Containers[0])
assert.NotNil(t, FindServerContainer(podSpec))
assert.Equal(t, ServerContainerName, FindServerContainer(podSpec).Name)
}
func envValue(t *testing.T, container *corev1.Container, name string) string {
t.Helper()
for _, env := range container.Env {
if env.Name == name {
return env.Value
}
}
t.Fatalf("env %s not found", name)
return ""
}
......@@ -95,14 +95,30 @@ func loadPod(manifestPath string) (*corev1.Pod, error) {
if kind := strings.TrimSpace(pod.Kind); kind != "" && kind != "Pod" {
return nil, fmt.Errorf("manifest %s is kind %q, expected Pod", manifestPath, kind)
}
if len(pod.Spec.Containers) != 1 {
if len(pod.Spec.Containers) == 0 {
return nil, fmt.Errorf(
"manifest %s has %d containers; snapshotctl requires exactly one worker container",
"manifest %s has no worker containers; snapshotctl requires at least one worker container",
manifestPath,
)
}
workerContainer := &pod.Spec.Containers[0]
if len(pod.Spec.Containers) > 1 {
workerContainer = nil
for index := range pod.Spec.Containers {
if pod.Spec.Containers[index].Name == "main" {
workerContainer = &pod.Spec.Containers[index]
break
}
}
if workerContainer == nil {
return nil, fmt.Errorf(
"manifest %s has %d containers; snapshotctl requires a worker container named main",
manifestPath,
len(pod.Spec.Containers),
)
}
if strings.TrimSpace(pod.Spec.Containers[0].Image) == "" {
}
if strings.TrimSpace(workerContainer.Image) == "" {
return nil, fmt.Errorf("manifest %s: worker container image is required", manifestPath)
}
if strings.TrimSpace(pod.Name) == "" {
......
......@@ -56,15 +56,16 @@ func NewCheckpointJob(podTemplate *corev1.PodTemplateSpec, opts CheckpointJobOpt
EnsureLocalhostSeccompProfile(&podTemplate.Spec, opts.SeccompProfile)
}
if opts.WrapLaunchJob {
if len(podTemplate.Spec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires one worker container")
container, err := ResolveCheckpointWorkerContainer(&podTemplate.Spec)
if err != nil {
return nil, err
}
if len(podTemplate.Spec.Containers[0].Command) == 0 {
if len(container.Command) == 0 {
return nil, fmt.Errorf("checkpoint job requires container.command when cuda-checkpoint launch-job wrapping is enabled")
}
podTemplate.Spec.Containers[0].Command, podTemplate.Spec.Containers[0].Args = wrapWithCudaCheckpointLaunchJob(
podTemplate.Spec.Containers[0].Command,
podTemplate.Spec.Containers[0].Args,
container.Command, container.Args = wrapWithCudaCheckpointLaunchJob(
container.Command,
container.Args,
)
}
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package protocol
import (
"fmt"
corev1 "k8s.io/api/core/v1"
)
const checkpointWorkerContainerName = "main"
func ResolveCheckpointWorkerContainer(podSpec *corev1.PodSpec) (*corev1.Container, error) {
if podSpec == nil || len(podSpec.Containers) == 0 {
return nil, fmt.Errorf("checkpoint job requires at least one container")
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0], nil
}
for i := range podSpec.Containers {
if podSpec.Containers[i].Name == checkpointWorkerContainerName {
return &podSpec.Containers[i], nil
}
}
return nil, fmt.Errorf("checkpoint job requires a container named %q when multiple containers are present", checkpointWorkerContainerName)
}
......@@ -11,6 +11,17 @@ import (
"k8s.io/utils/ptr"
)
func requireCheckpointContainer(t *testing.T, containers []corev1.Container, name string) *corev1.Container {
t.Helper()
for i := range containers {
if containers[i].Name == name {
return &containers[i]
}
}
t.Fatalf("container %q not found", name)
return nil
}
func TestNewCheckpointJob(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
......@@ -87,6 +98,97 @@ func TestNewCheckpointJob(t *testing.T) {
}
}
func TestNewCheckpointJobPrefersContainerNamedMain(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "main", Command: []string{"python3", "-m", "dynamo.vllm"}, Args: []string{"--model", "Qwen"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err != nil {
t.Fatalf("expected checkpoint job, got error: %v", err)
}
main := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "main")
if len(main.Command) != 1 || main.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected main container to be wrapped, got %#v", main.Command)
}
expectedArgs := []string{"--launch-job", "python3", "-m", "dynamo.vllm", "--model", "Qwen"}
if len(main.Args) != len(expectedArgs) {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
}
for i := range expectedArgs {
if main.Args[i] != expectedArgs[i] {
t.Fatalf("expected launch-job args %#v, got %#v", expectedArgs, main.Args)
}
}
sidecar := requireCheckpointContainer(t, job.Spec.Template.Spec.Containers, "sidecar")
if len(sidecar.Command) != 1 || sidecar.Command[0] != "sleep" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", sidecar.Command)
}
if len(sidecar.Args) != 1 || sidecar.Args[0] != "infinity" {
t.Fatalf("expected sidecar args to remain unchanged, got %#v", sidecar.Args)
}
}
func TestNewCheckpointJobAllowsSingleNonMainContainer(t *testing.T) {
job, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{{
Name: "worker",
Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"},
}},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err != nil {
t.Fatalf("expected checkpoint job, got error: %v", err)
}
container := &job.Spec.Template.Spec.Containers[0]
if len(container.Command) != 1 || container.Command[0] != "cuda-checkpoint" {
t.Fatalf("expected single container to be wrapped, got %#v", container.Command)
}
}
func TestNewCheckpointJobRejectsMultiContainerPodWithoutMain(t *testing.T) {
_, err := NewCheckpointJob(&corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Command: []string{"sleep"}, Args: []string{"infinity"}},
{Name: "worker", Command: []string{"python3", "-m", "dynamo.vllm"}},
},
},
}, CheckpointJobOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Name: "test-job",
TTLSecondsAfterFinish: ptr.To(int32(300)),
WrapLaunchJob: true,
})
if err == nil || err.Error() != "checkpoint job requires a container named \"main\" when multiple containers are present" {
t.Fatalf("expected missing main container error, got %v", err)
}
}
func TestGetCheckpointJobName(t *testing.T) {
name := GetCheckpointJobName("abc123def4567890", "2")
if name != "checkpoint-job-abc123def4567890-2" {
......
......@@ -39,12 +39,31 @@ func NewRestorePod(pod *corev1.Pod, opts PodOptions) *corev1.Pod {
pod.Annotations = map[string]string{}
}
ApplyRestoreTargetMetadata(pod.Labels, pod.Annotations, true, opts.CheckpointID, opts.ArtifactVersion)
PrepareRestorePodSpec(&pod.Spec, &pod.Spec.Containers[0], opts.Storage, opts.SeccompProfile, true)
container := resolveWorkerContainer(&pod.Spec)
if container == nil {
return nil
}
PrepareRestorePodSpec(&pod.Spec, container, opts.Storage, opts.SeccompProfile, true)
pod.Namespace = opts.Namespace
pod.Spec.RestartPolicy = corev1.RestartPolicyNever
return pod
}
func resolveWorkerContainer(podSpec *corev1.PodSpec) *corev1.Container {
if podSpec == nil {
return nil
}
if len(podSpec.Containers) == 1 {
return &podSpec.Containers[0]
}
for index := range podSpec.Containers {
if podSpec.Containers[index].Name == "main" {
return &podSpec.Containers[index]
}
}
return nil
}
func PrepareRestorePodSpec(
podSpec *corev1.PodSpec,
container *corev1.Container,
......@@ -92,10 +111,10 @@ func ValidateRestorePodSpec(
if podSpec == nil {
return fmt.Errorf("pod spec is nil")
}
if len(podSpec.Containers) != 1 {
return fmt.Errorf("restore target must have exactly one container, got %d", len(podSpec.Containers))
container := resolveWorkerContainer(podSpec)
if container == nil {
return fmt.Errorf("restore target must include a worker container named main")
}
container := &podSpec.Containers[0]
if storage.PVCName != "" {
hasVolume := false
for _, volume := range podSpec.Volumes {
......
......@@ -187,6 +187,38 @@ func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T)
}
}
func TestNewRestorePodTargetsMainContainerWhenSidecarsPresent(t *testing.T) {
restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "worker"},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{Name: "sidecar", Image: "sidecar:latest", Command: []string{"sidecar"}, Args: []string{"run"}},
{Name: "main", Image: "test:latest", Command: []string{"python3"}, Args: []string{"-m", "dynamo.vllm"}},
},
},
}, PodOptions{
Namespace: "test-ns",
CheckpointID: "hash",
ArtifactVersion: "2",
Storage: Storage{
Type: StorageTypePVC,
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
},
SeccompProfile: DefaultSeccompLocalhostProfile,
})
if got := restorePod.Spec.Containers[0].Command; len(got) != 1 || got[0] != "sidecar" {
t.Fatalf("expected sidecar command to remain unchanged, got %#v", got)
}
if got := restorePod.Spec.Containers[1].Command; len(got) != 2 || got[0] != "sleep" || got[1] != "infinity" {
t.Fatalf("expected main container placeholder command, got %#v", got)
}
if restorePod.Spec.Containers[1].Args != nil {
t.Fatalf("expected main container args to be cleared: %#v", restorePod.Spec.Containers[1].Args)
}
}
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) {
podSpec := corev1.PodSpec{}
readinessProbe := &corev1.Probe{
......@@ -279,7 +311,7 @@ func TestValidateRestorePodSpec(t *testing.T) {
}
}
func TestValidateRestorePodSpecRequiresExactlyOneContainer(t *testing.T) {
func TestValidateRestorePodSpecRequiresMainContainerWhenMultiContainer(t *testing.T) {
profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{
......@@ -314,8 +346,48 @@ func TestValidateRestorePodSpecRequiresExactlyOneContainer(t *testing.T) {
BasePath: "/checkpoints",
}
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must have exactly one container, got 2" {
t.Fatalf("expected multi-container restore target to be rejected, got %v", err)
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err == nil || err.Error() != "restore target must include a worker container named main" {
t.Fatalf("expected multi-container restore target without main to be rejected, got %v", err)
}
}
func TestValidateRestorePodSpecAllowsMainContainerWithSidecars(t *testing.T) {
profile := DefaultSeccompLocalhostProfile
podSpec := &corev1.PodSpec{
SecurityContext: &corev1.PodSecurityContext{
SeccompProfile: &corev1.SeccompProfile{
Type: corev1.SeccompProfileTypeLocalhost,
LocalhostProfile: &profile,
},
},
Volumes: []corev1.Volume{{
Name: CheckpointVolumeName,
VolumeSource: corev1.VolumeSource{
PersistentVolumeClaim: &corev1.PersistentVolumeClaimVolumeSource{
ClaimName: "snapshot-pvc",
},
},
}},
Containers: []corev1.Container{
{Name: "sidecar"},
{
Name: "main",
VolumeMounts: []corev1.VolumeMount{{
Name: CheckpointVolumeName,
MountPath: "/checkpoints",
}},
},
},
}
storage := Storage{
Type: StorageTypePVC,
PVCName: "snapshot-pvc",
BasePath: "/checkpoints",
}
if err := ValidateRestorePodSpec(podSpec, storage, DefaultSeccompLocalhostProfile); err != nil {
t.Fatalf("expected main container with sidecars to validate, got %v", err)
}
}
......
......@@ -302,6 +302,7 @@ _Appears in:_
| Field | Description | Default | Validation |
| --- | --- | --- | --- |
| `identity` _[DynamoCheckpointIdentity](#dynamocheckpointidentity)_ | Identity defines the inputs that determine checkpoint equivalence | | Required: \{\} <br /> |
| `gpuMemoryService` _[GPUMemoryServiceSpec](#gpumemoryservicespec)_ | GPUMemoryService enables checkpoint-time GPU Memory Service wiring.<br />It is intentionally outside spec.identity, so it does not affect the<br />checkpoint identity hash or deduplication. | | Optional: \{\} <br /> |
| `job` _[DynamoCheckpointJobConfig](#dynamocheckpointjobconfig)_ | Job defines the configuration for the checkpoint creation Job | | Required: \{\} <br /> |
......@@ -852,6 +853,7 @@ via DRA (Dynamic Resource Allocation). The sidecar runs two GMS processes per GP
_Appears in:_
- [DynamoCheckpointSpec](#dynamocheckpointspec)
- [DynamoComponentDeploymentSharedSpec](#dynamocomponentdeploymentsharedspec)
- [DynamoComponentDeploymentSpec](#dynamocomponentdeploymentspec)
......
......@@ -142,6 +142,16 @@ spec:
...
```
If this checkpoint should capture and restore GPU Memory Service helpers, set:
```yaml
spec:
gpuMemoryService:
enabled: true
```
`spec.gpuMemoryService` is outside `spec.identity`, so it does not change the checkpoint identity hash.
For a full working example, see [deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml](https://github.com/ai-dynamo/dynamo/blob/main/deploy/operator/config/samples/nvidia.com_v1alpha1_dynamocheckpoint.yaml).
Apply it:
......@@ -262,6 +272,8 @@ spec:
...
```
Auto mode only hashes `checkpoint.identity`. If you need GMS-specific checkpoint behavior, configure it on the `DynamoCheckpoint` object with `spec.gpuMemoryService.enabled`.
Useful inspection commands:
```bash
......
......@@ -4,7 +4,13 @@
"""CLI for GPU Memory Service."""
from gpu_memory_service.cli.args import Config, parse_args
from gpu_memory_service.cli.runner import main
def main():
from gpu_memory_service.cli.runner import main as runner_main
return runner_main()
__all__ = [
"Config",
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS server entry point.
Launches two GMS server processes per GPU (one for weights, one for kv_cache).
Writes a ready file once all expected UDS sockets are present. Monitors an
optional checkpoint stop file and shuts down cleanly when it appears.
"""
from __future__ import annotations
import logging
import os
import signal
import subprocess
import sys
import time
from pathlib import Path
from gpu_memory_service.common.utils import get_socket_path
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_TAGS = ("weights", "kv_cache")
_READY_FILE = "gms-ready"
def _ready_file_path() -> Path:
return Path(os.environ.get("GMS_SOCKET_DIR", "/tmp")) / _READY_FILE
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _optional_checkpoint_stop_file() -> Path | None:
control_dir = os.environ.get("GMS_CONTROL_DIR")
if not control_dir:
return None
return Path(control_dir) / "checkpoint-done"
def main() -> None:
ready_file = _ready_file_path()
ready_file.unlink(missing_ok=True)
devices = _list_devices()
processes = []
for device in devices:
for tag in _TAGS:
proc = subprocess.Popen(
[
sys.executable,
"-m",
"gpu_memory_service",
"--device",
str(device),
"--tag",
tag,
]
)
logger.info("Started GMS device=%d tag=%s pid=%d", device, tag, proc.pid)
processes.append(proc)
def shutdown() -> None:
for process in processes:
if process.poll() is None:
process.terminate()
def terminate(*_args) -> None:
shutdown()
raise SystemExit(0)
signal.signal(signal.SIGTERM, terminate)
signal.signal(signal.SIGINT, terminate)
stop_file = _optional_checkpoint_stop_file()
ready_written = False
while True:
stop_requested = stop_file is not None and stop_file.exists()
if stop_requested:
logger.info("checkpoint stop requested; shutting down GMS servers")
shutdown()
if not ready_written:
sockets_ready = all(
os.path.exists(get_socket_path(device, tag))
for device in devices
for tag in _TAGS
)
if sockets_ready:
ready_file.write_text("ready", encoding="utf-8")
ready_written = True
running = False
for process in processes:
exit_code = process.poll()
if exit_code is None:
running = True
continue
if stop_requested:
continue
shutdown()
raise SystemExit(exit_code)
if not running:
return
time.sleep(1)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS checkpoint loader entry point.
Waits for the GMS server UDS socket on each device, then loads saved GMS
state from a checkpoint directory into the running GMS servers. Devices
are loaded in parallel to saturate PVC bandwidth.
"""
from __future__ import annotations
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_DEFAULT_MAX_WORKERS = 8
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _load_device(checkpoint_dir: str, device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
input_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info("Loading GMS checkpoint: device=%d input_dir=%s", device, input_dir)
t0 = time.monotonic()
client = GMSStorageClient(
socket_path=get_socket_path(device),
device=device,
)
client.load_to_gms(
input_dir,
max_workers=max_workers,
clear_existing=True,
)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint loaded: device=%d elapsed=%.2fs", device, elapsed)
def main() -> None:
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
max_workers = int(os.environ.get("GMS_LOAD_WORKERS", str(_DEFAULT_MAX_WORKERS)))
devices = _list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_load_device, checkpoint_dir, dev, max_workers): dev
for dev in devices
}
for future in as_completed(futures):
dev = futures[future]
future.result()
logger.info("Device %d load complete", dev)
elapsed = time.monotonic() - t0
logger.info("All %d devices loaded in %.2fs", len(devices), elapsed)
while True:
time.sleep(3600)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS checkpoint saver entry point.
Waits for the checkpoint pod to reach Ready=True, then saves GMS state from
each device in parallel. Writes a stop file to signal the GMS server to shut
down after save completes.
"""
from __future__ import annotations
import json
import logging
import os
import ssl
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any
from gpu_memory_service.common.utils import get_socket_path
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
_WEIGHTS_TAG = "weights"
_SERVICE_ACCOUNT_TOKEN = Path("/var/run/secrets/kubernetes.io/serviceaccount/token")
_SERVICE_ACCOUNT_CA = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"
def _list_devices() -> list[int]:
import pynvml
pynvml.nvmlInit()
try:
count = pynvml.nvmlDeviceGetCount()
finally:
pynvml.nvmlShutdown()
if count == 0:
raise SystemExit("no nvidia devices found")
return list(range(count))
def _wait_for_weights_socket(device: int) -> None:
socket_path = get_socket_path(device, _WEIGHTS_TAG)
while not os.path.exists(socket_path):
time.sleep(1)
def _checkpoint_pod_ready(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
if str(status.get("phase", "")).strip() != "Running":
return False
for condition in status.get("conditions") or []:
if (
condition.get("type") == "Ready"
and str(condition.get("status", "")).strip() == "True"
):
return True
return False
def _main_terminated(pod: dict[str, Any]) -> bool:
status = pod.get("status") or {}
for container in status.get("containerStatuses") or []:
if container.get("name") != "main":
continue
return bool((container.get("state") or {}).get("terminated"))
return False
def main() -> None:
service_token = _SERVICE_ACCOUNT_TOKEN.read_text(encoding="utf-8").strip()
ssl_context = ssl.create_default_context(cafile=_SERVICE_ACCOUNT_CA)
pod_api_url = (
"https://"
+ os.environ["KUBERNETES_SERVICE_HOST"]
+ ":"
+ os.environ.get("KUBERNETES_SERVICE_PORT_HTTPS", "443")
+ f"/api/v1/namespaces/{os.environ['POD_NAMESPACE']}/pods/{os.environ['POD_NAME']}"
)
checkpoint_dir = os.environ["GMS_CHECKPOINT_DIR"]
def checkpoint_pod() -> dict[str, Any]:
request = urllib.request.Request(
pod_api_url,
headers={"Authorization": f"Bearer {service_token}"},
)
with urllib.request.urlopen(
request,
context=ssl_context,
timeout=5,
) as response:
return json.load(response)
logger.info("Waiting for checkpoint pod Ready=True before GMS save")
while True:
try:
pod = checkpoint_pod()
except (urllib.error.URLError, TimeoutError, ssl.SSLError, OSError):
time.sleep(1)
continue
if _checkpoint_pod_ready(pod):
break
if _main_terminated(pod):
raise SystemExit("main container terminated before GMS save could start")
time.sleep(1)
def _save_device(device: int, max_workers: int) -> None:
_wait_for_weights_socket(device)
output_dir = os.path.join(checkpoint_dir, f"device-{device}")
logger.info(
"Saving GMS checkpoint: device=%d output_dir=%s",
device,
output_dir,
)
t0 = time.monotonic()
client = GMSStorageClient(
output_dir,
socket_path=get_socket_path(device),
device=device,
)
client.save(max_workers=max_workers)
elapsed = time.monotonic() - t0
logger.info("GMS checkpoint saved: device=%d elapsed=%.2fs", device, elapsed)
max_workers = int(os.environ.get("GMS_SAVE_WORKERS", "8"))
logger.info("Checkpoint pod is Ready; starting GMS save")
try:
devices = _list_devices()
t0 = time.monotonic()
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
futures = {
pool.submit(_save_device, dev, max_workers): dev for dev in devices
}
for future in as_completed(futures):
future.result()
elapsed = time.monotonic() - t0
logger.info("All %d devices saved in %.2fs", len(devices), elapsed)
finally:
(Path(os.environ["GMS_CONTROL_DIR"]) / "checkpoint-done").write_text(
"done",
encoding="utf-8",
)
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""GMS Storage Client CLI.
Provides two subcommands for saving and loading GPU Memory Service state:
* ``save`` – connect to a running GMS server in RO mode and write every
allocation plus all metadata to a sharded binary directory.
* ``load`` – connect to a running GMS server in RW mode, read tensor data
from a saved directory, and commit the state so readers can
acquire the RO lock.
Usage examples::
# Save GMS state to disk
gms-storage-client save --output-dir /mnt/nvme/save --device 0
# Load a previous save back into a fresh GMS server
gms-storage-client load --input-dir /mnt/nvme/save --device 0
"""
import argparse
import logging
import sys
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _configure_logging(verbose: bool) -> None:
if verbose:
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger("gpu_memory_service").setLevel(logging.DEBUG)
def _resolve_socket(device: int, socket_path) -> str:
if socket_path is not None:
return socket_path
from gpu_memory_service.common.utils import get_socket_path
return get_socket_path(device)
# ---------------------------------------------------------------------------
# Subcommand implementations
# ---------------------------------------------------------------------------
def _run_save(args) -> None:
"""Execute the save subcommand."""
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
_configure_logging(args.verbose)
socket_path = _resolve_socket(args.device, args.socket_path)
logger.info(
"Saving GMS state: device=%s, socket=%s, output_dir=%s, save_workers=%s",
args.device,
socket_path,
args.output_dir,
args.save_workers,
)
client = GMSStorageClient(
args.output_dir,
socket_path=socket_path,
device=args.device,
timeout_ms=args.timeout_ms,
shard_size_bytes=args.shard_size_bytes,
)
manifest = client.save(max_workers=args.save_workers)
shard_count = len({a.tensor_file for a in manifest.allocations})
logger.info(
"Save complete: %s allocations written to %s (%s shards)",
len(manifest.allocations),
args.output_dir,
shard_count,
)
logger.info("Layout hash: %s", manifest.layout_hash)
def _run_load(args) -> None:
"""Execute the load subcommand."""
from gpu_memory_service.snapshot.storage_client import GMSStorageClient
_configure_logging(args.verbose)
socket_path = _resolve_socket(args.device, args.socket_path)
logger.info(
"Loading GMS state: device=%s, socket=%s, input_dir=%s, clear_existing=%s",
args.device,
socket_path,
args.input_dir,
not args.no_clear,
)
client = GMSStorageClient(
socket_path=socket_path,
device=args.device,
timeout_ms=args.timeout_ms,
)
id_map = client.load_to_gms(
args.input_dir,
max_workers=args.workers,
clear_existing=not args.no_clear,
)
logger.info("Load complete: %s allocations committed to GMS", len(id_map))
for old_id, new_id in id_map.items():
logger.info(" %s → %s", old_id, new_id)
# ---------------------------------------------------------------------------
# Argument parsing
# ---------------------------------------------------------------------------
_SHARD_SIZE_DEFAULT = 4 * 1024**3 # 4 GiB
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="gms-storage-client",
description="Save and load GPU Memory Service state to/from disk.",
formatter_class=argparse.RawDescriptionHelpFormatter,
)
subparsers = parser.add_subparsers(dest="subcommand")
# -- save ---------------------------------------------------------------
save_p = subparsers.add_parser(
"save",
help="Save GMS state to a sharded binary directory.",
description=(
"Connect to a running GMS server in RO mode and export every "
"allocation plus all metadata to a compact sharded binary format."
),
)
save_p.add_argument(
"--output-dir",
required=True,
help="Directory to write into (created if absent).",
)
save_p.add_argument(
"--device",
type=int,
default=0,
help="CUDA device index (default: 0).",
)
save_p.add_argument(
"--socket-path",
type=str,
default=None,
help="GMS Unix socket path. Default uses GPU UUID-based path.",
)
save_p.add_argument(
"--timeout-ms",
type=int,
default=None,
help="Timeout in milliseconds for acquiring the RO lock.",
)
save_p.add_argument(
"--shard-size-bytes",
type=int,
default=_SHARD_SIZE_DEFAULT,
help=(
f"Soft upper bound per shard file in bytes "
f"(default: {_SHARD_SIZE_DEFAULT // 1024**3} GiB). "
"Decrease to increase parallelism on save/load; increase to "
"reduce file count."
),
)
save_p.add_argument(
"--save-workers",
type=int,
default=8,
help="Thread pool size for parallel shard writes (default: 8).",
)
save_p.add_argument(
"--verbose",
"-v",
action="store_true",
help="Enable verbose logging.",
)
# -- load ---------------------------------------------------------------
load_p = subparsers.add_parser(
"load",
help="Load a saved GMS state back into a running GMS server.",
description=(
"Connect to a running GMS server in RW mode, read tensor data "
"from a saved directory (reading each shard file sequentially), "
"and commit the state so readers can acquire the RO lock."
),
)
load_p.add_argument(
"--input-dir",
required=True,
help="Directory previously created by the save subcommand.",
)
load_p.add_argument(
"--device",
type=int,
default=0,
help="CUDA device index (default: 0).",
)
load_p.add_argument(
"--socket-path",
type=str,
default=None,
help="GMS Unix socket path. Default uses GPU UUID-based path.",
)
load_p.add_argument(
"--timeout-ms",
type=int,
default=None,
help="Timeout in milliseconds for acquiring the RW lock.",
)
load_p.add_argument(
"--workers",
type=int,
default=8,
help="Thread pool size for parallel shard reads (default: 8).",
)
load_p.add_argument(
"--no-clear",
action="store_true",
default=False,
help=(
"Do not clear existing GMS allocations before loading. "
"Default behaviour clears the server to produce an exact replica."
),
)
load_p.add_argument(
"--verbose",
"-v",
action="store_true",
help="Enable verbose logging.",
)
return parser
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
"""Entry point for the GMS Storage Client CLI."""
parser = _build_parser()
args = parser.parse_args()
if args.subcommand is None:
parser.print_help()
sys.exit(1)
if args.subcommand == "save":
_run_save(args)
elif args.subcommand == "load":
_run_load(args)
else:
parser.print_help()
sys.exit(1)
if __name__ == "__main__":
main()
......@@ -37,7 +37,7 @@ from typing import Dict, List, Optional
from gpu_memory_service.client.session import _GMSClientSession
from gpu_memory_service.common.cuda_utils import (
align_to_granularity,
cuda_set_current_device,
cuda_ensure_initialized,
cuda_synchronize,
cuda_validate_pointer,
cumem_address_free,
......@@ -134,22 +134,25 @@ class GMSClientMemoryManager:
socket_path: str,
*,
device: int = 0,
tag: Optional[str] = None,
) -> None:
self.socket_path = socket_path
self.device = device
self.tag = tag
self._client: Optional[_GMSClientSession] = None
self._mappings: Dict[int, LocalMapping] = {} # va -> mapping
self._inverse_mapping: Dict[str, int] = {}
self._unmapped = False
self._aborted = False
self._granted_lock_type: Optional[GrantedLockType] = None
# VA-stable unmap/remap state
self._va_preserved = False
self._last_memory_layout_hash: str = ""
cuda_set_current_device(self.device)
cuda_ensure_initialized()
self.granularity = cumem_get_allocation_granularity(device)
# ==================== Properties ====================
......@@ -183,9 +186,34 @@ class GMSClientMemoryManager:
Updates self._granted_lock_type based on granted lock type. Saves memory layout hash
for stale detection if server is in committed state.
On reconnect after abort (e.g. after CRIU restore on a different GPU),
refreshes the socket path from the current GPU UUID so we connect to
the correct GMS server.
"""
if self._client is not None:
raise RuntimeError("Memory manager is already connected")
# After abort + CRIU restore the process may be on a different GPU.
# Re-derive socket path from current UUID so we talk to the right server.
if self._aborted and self.tag is not None:
from gpu_memory_service.common.utils import (
get_socket_path,
invalidate_uuid_cache,
)
invalidate_uuid_cache()
new_path = get_socket_path(self.device, self.tag)
if new_path != self.socket_path:
logger.info(
"Refreshed socket path for tag=%s: %s -> %s",
self.tag,
self.socket_path,
new_path,
)
self.socket_path = new_path
self._aborted = False
self._client = _GMSClientSession(
self.socket_path,
lock_type=lock_type,
......@@ -211,6 +239,7 @@ class GMSClientMemoryManager:
Clean callers should unmap first. This also supports abrupt session
drop with live mappings still present.
"""
self._aborted = True
if self._client is not None:
try:
self._client.close()
......@@ -464,8 +493,6 @@ class GMSClientMemoryManager:
Checks layout hash for staleness. Validates each allocation still
exists and size matches before remapping.
"""
cuda_set_current_device(self.device)
# Stale layout check
current_hash = self.get_memory_layout_hash()
if (
......@@ -580,17 +607,29 @@ class GMSClientMemoryManager:
# ==================== Lifecycle ====================
def close(self) -> None:
"""Strict cleanup.
def close(self, *, best_effort: bool = False) -> None:
"""Cleanup mappings and abort.
synchronize + unmap all + free all VAs + abort.
Args:
best_effort: If True, skip cuda_synchronize and swallow
errors during cleanup. Used after checkpoint where
cuda-checkpoint may have torn down the device context
(cuda_synchronize calls os._exit via fail()).
"""
if best_effort:
try:
self.abort()
except Exception:
pass
self._mappings.clear()
self._inverse_mapping.clear()
else:
cuda_synchronize()
for va in list(self._mappings.keys()):
self.unmap_va(va)
self.free_va(va)
self.abort()
self._unmapped = False
self._va_preserved = False
......
......@@ -127,7 +127,7 @@ def get_or_create_gms_client_memory_manager(
)
return state.manager
manager = GMSClientMemoryManager(socket_path, device=device)
manager = GMSClientMemoryManager(socket_path, device=device, tag=tag)
manager.connect(mode, timeout_ms=timeout_ms)
mem_pool = None
......
......@@ -5,7 +5,6 @@
from __future__ import annotations
import atexit
import os
from gpu_memory_service.common.locks import GrantedLockType
......@@ -24,9 +23,6 @@ except ImportError:
cuda = _MissingCuda()
_primary_contexts: dict[int, object] = {}
_primary_context_release_registered = False
def cuda_check_result(result: cuda.CUresult, name: str) -> None:
if result != cuda.CUresult.CUDA_SUCCESS:
......@@ -169,28 +165,3 @@ def cuda_validate_pointer(va: int) -> None:
def cuda_synchronize() -> None:
(result,) = cuda.cuCtxSynchronize()
cuda_check_result(result, "cuCtxSynchronize")
def cuda_set_current_device(device: int) -> None:
global _primary_context_release_registered
ctx = _primary_contexts.get(device)
if ctx is None:
result, ctx = cuda.cuDevicePrimaryCtxRetain(device)
cuda_check_result(result, "cuDevicePrimaryCtxRetain")
_primary_contexts[device] = ctx
if not _primary_context_release_registered:
_primary_context_release_registered = True
atexit.register(_release_primary_contexts)
(result,) = cuda.cuCtxSetCurrent(ctx)
cuda_check_result(result, "cuCtxSetCurrent")
def _release_primary_contexts() -> None:
for device in list(_primary_contexts):
try:
(result,) = cuda.cuDevicePrimaryCtxRelease(device)
except Exception:
continue
if result == cuda.CUresult.CUDA_SUCCESS:
_primary_contexts.pop(device, None)
......@@ -17,11 +17,19 @@ def fail(message: str, *args, exc_info=None) -> NoReturn:
os._exit(1)
_uuid_cache: dict[int, str] = {}
def invalidate_uuid_cache() -> None:
"""Clear cached GPU UUIDs. Call after CRIU restore when GPU assignment may change."""
_uuid_cache.clear()
def get_socket_path(device: int, tag: str = "weights") -> str:
"""Get GMS socket path for the given CUDA device and tag.
The socket path is based on GPU UUID, making it stable across different
CUDA_VISIBLE_DEVICES configurations.
CUDA_VISIBLE_DEVICES configurations. UUIDs are cached per device index.
Args:
device: CUDA device index.
......@@ -30,7 +38,9 @@ def get_socket_path(device: int, tag: str = "weights") -> str:
Socket path
(e.g., "<tempdir>/gms_GPU-12345678-1234-1234-1234-123456789abc_weights.sock").
"""
import pynvml
uuid = _uuid_cache.get(device)
if uuid is None:
import pynvml # deferred: not available in all environments
pynvml.nvmlInit()
try:
......@@ -38,4 +48,6 @@ def get_socket_path(device: int, tag: str = "weights") -> str:
uuid = pynvml.nvmlDeviceGetUUID(handle)
finally:
pynvml.nvmlShutdown()
return os.path.join(tempfile.gettempdir(), f"gms_{uuid}_{tag}.sock")
_uuid_cache[device] = uuid
socket_dir = os.environ.get("GMS_SOCKET_DIR") or tempfile.gettempdir()
return os.path.join(socket_dir, f"gms_{uuid}_{tag}.sock")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment