gms.go 3.62 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

// Package gms provides GMS (GPU Memory Service) server container building
// for both steady-state DGD pods and checkpoint/restore flows.
package gms

import (
	"path/filepath"

	"github.com/ai-dynamo/dynamo/deploy/operator/internal/dra"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/utils/ptr"
)

const (
	// ServerContainerName is the name of the GMS server init sidecar.
	ServerContainerName = "gms-server"

	// SharedVolumeName is the emptyDir volume shared between the GMS server
	// sidecar and the main workload container for UDS sockets.
	SharedVolumeName = "gms-shared"

	// SharedMountPath is the mount path for the shared GMS socket directory.
	SharedMountPath = "/shared"

	// EnvSocketDir is the environment variable name for the GMS UDS socket directory.
	EnvSocketDir = "GMS_SOCKET_DIR"

	// ServerModule is the Python module for the GMS server entry point.
	ServerModule = "gpu_memory_service.cli.server"

	readyFile = "gms-ready"
)

// EnsureServerSidecar adds the GMS server as a restartable init sidecar with a
// startup probe. Idempotent — safe to call from both the DGD and checkpoint paths.
func EnsureServerSidecar(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
	if podSpec == nil || mainContainer == nil {
		return
	}

	EnsureSharedVolume(podSpec, mainContainer)

	sidecar := Container(ServerContainerName, ServerModule, mainContainer.Image)
	sidecar.RestartPolicy = ptr.To(corev1.ContainerRestartPolicyAlways)
	sidecar.StartupProbe = &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			Exec: &corev1.ExecAction{
				Command: []string{"test", "-f", filepath.Join(SharedMountPath, readyFile)},
			},
		},
		PeriodSeconds:    1,
		FailureThreshold: 300, // 1s * 300 = 5 min
	}
	for i := range podSpec.InitContainers {
		if podSpec.InitContainers[i].Name == sidecar.Name {
			return
		}
	}
	podSpec.InitContainers = append(podSpec.InitContainers, sidecar)
}

// EnsureSharedVolume adds the GMS UDS socket volume, mount, and GMS_SOCKET_DIR
// env var to the main container. Idempotent.
func EnsureSharedVolume(podSpec *corev1.PodSpec, mainContainer *corev1.Container) {
	hasVolume := false
	for _, v := range podSpec.Volumes {
		if v.Name == SharedVolumeName {
			hasVolume = true
			break
		}
	}
	if !hasVolume {
		podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
			Name:         SharedVolumeName,
			VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
		})
	}

	hasMount := false
	for _, m := range mainContainer.VolumeMounts {
		if m.Name == SharedVolumeName {
			hasMount = true
			break
		}
	}
	if !hasMount {
		mainContainer.VolumeMounts = append(mainContainer.VolumeMounts, corev1.VolumeMount{Name: SharedVolumeName, MountPath: SharedMountPath})
	}

	hasEnv := false
	for _, e := range mainContainer.Env {
		if e.Name == EnvSocketDir {
			hasEnv = true
			break
		}
	}
	if !hasEnv {
		mainContainer.Env = append(mainContainer.Env, corev1.EnvVar{Name: EnvSocketDir, Value: SharedMountPath})
	}
}

// Container builds a GMS container with the shared socket volume, env, and
// DRA claim. Used for the server, loader, and saver.
func Container(name, module, image string) corev1.Container {
	return corev1.Container{
		Name:    name,
		Image:   image,
		Command: []string{"python3", "-m", module},
		Env: []corev1.EnvVar{
			{Name: EnvSocketDir, Value: SharedMountPath},
		},
		VolumeMounts: []corev1.VolumeMount{
			{Name: SharedVolumeName, MountPath: SharedMountPath},
		},
		Resources: corev1.ResourceRequirements{
			Claims: []corev1.ResourceClaim{{Name: dra.ClaimName}},
		},
	}
}