gms.go 7.68 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
/*
 * SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

package checkpoint

import (
	"context"
	"fmt"
	"path/filepath"

	gmsruntime "github.com/ai-dynamo/dynamo/deploy/operator/internal/gms"
	snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
	appsv1 "k8s.io/api/apps/v1"
	corev1 "k8s.io/api/core/v1"
	ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)

const (
	// GMSLoaderContainer is the container name given to the sidecar built by
	// gmsCheckpointLoaderContainer (used for restore pods).
	GMSLoaderContainer = "gms-loader"
	// GMSSaverContainer is the container name given to the sidecar built by
	// gmsCheckpointSaverContainer (used for checkpoint jobs).
	GMSSaverContainer = "gms-saver"

	// gmsCheckpointLoaderModule is the Python module the loader sidecar runs
	// via `python3 -m`.
	gmsCheckpointLoaderModule = "gpu_memory_service.cli.snapshot.loader"
	// gmsCheckpointSaverModule is the Python module the saver sidecar runs
	// via `python3 -m`.
	gmsCheckpointSaverModule = "gpu_memory_service.cli.snapshot.saver"
)

// ResolveGMSCheckpointStorage resolves the checkpoint storage for the given
// checkpoint ID and artifact version.
//
// It lists the DaemonSets in namespace that carry the snapshot-agent label,
// hands them to snapshotprotocol.DiscoverStorageFromDaemonSets, and passes the
// discovered storage on to snapshotprotocol.ResolveCheckpointStorage.
//
// An error is returned when reader is nil, when the list call fails (wrapped
// with the namespace), or when discovery/resolution fails (returned as-is).
// On error the returned Storage is the zero value for the first two cases and
// the discovery case.
func ResolveGMSCheckpointStorage(
	ctx context.Context,
	reader ctrlclient.Reader,
	namespace string,
	checkpointID string,
	artifactVersion string,
) (snapshotprotocol.Storage, error) {
	var none snapshotprotocol.Storage
	if reader == nil {
		return none, fmt.Errorf("checkpoint client is required")
	}

	// Only DaemonSets labelled as snapshot agents are of interest.
	agentSelector := ctrlclient.MatchingLabels{
		snapshotprotocol.SnapshotAgentLabelKey: snapshotprotocol.SnapshotAgentLabelValue,
	}
	var agents appsv1.DaemonSetList
	err := reader.List(ctx, &agents, ctrlclient.InNamespace(namespace), agentSelector)
	if err != nil {
		return none, fmt.Errorf("list snapshot-agent daemonsets in %s: %w", namespace, err)
	}

	discovered, err := snapshotprotocol.DiscoverStorageFromDaemonSets(namespace, agents.Items)
	if err != nil {
		return none, err
	}
	return snapshotprotocol.ResolveCheckpointStorage(checkpointID, artifactVersion, discovered)
}

// BuildGMSRestoreSidecars prepares GMS infrastructure for a restore pod and
// returns the extra containers (GMS server, then gms-loader) that the caller
// must append to podSpec.Containers. It returns nil when podSpec or
// mainContainer is nil.
//
// The GMS server is returned as a regular container rather than an init
// container because the CRIU-restored main process already has GPU memory
// mapped and does not need sockets at startup. The gms-loader polls for
// sockets internally via wait_for_weights_socket.
//
// Side effects on podSpec: a gms-server entry is dropped from InitContainers
// if present, and the checkpoint PVC volume is added when storage.PVCName is
// non-empty.
//
// NOTE(review): storage is not validated here. With an empty PVCName the
// loader still receives a mount for the checkpoint volume even though
// ensureCheckpointVolume adds no such volume — confirm callers always pass
// resolved storage.
func BuildGMSRestoreSidecars(
	podSpec *corev1.PodSpec,
	mainContainer *corev1.Container,
	storage snapshotprotocol.Storage,
) []corev1.Container {
	if podSpec == nil || mainContainer == nil {
		return nil
	}

	// The DGD-level applyGPUMemoryService may already have placed gms-server
	// among the init containers. Restore pods want every container to start
	// in parallel, so take the first such entry out again.
	inits := podSpec.InitContainers
	for idx := 0; idx < len(inits); idx++ {
		if inits[idx].Name != gmsruntime.ServerContainerName {
			continue
		}
		podSpec.InitContainers = append(inits[:idx], inits[idx+1:]...)
		break
	}

	serverSidecar := gmsruntime.BuildServerContainer(podSpec, mainContainer)

	loaderSidecar := gmsCheckpointLoaderContainer(mainContainer.Image)
	copyGMSDeviceClaims(mainContainer, &loaderSidecar)
	ensureCheckpointVolume(podSpec, storage.PVCName)

	checkpointMount := corev1.VolumeMount{
		Name:      snapshotprotocol.CheckpointVolumeName,
		MountPath: storage.BasePath,
	}
	checkpointDirEnv := corev1.EnvVar{
		Name:  "GMS_CHECKPOINT_DIR",
		Value: resolveGMSArtifactDir(storage),
	}
	loaderSidecar.VolumeMounts = append(loaderSidecar.VolumeMounts, checkpointMount)
	loaderSidecar.Env = append(loaderSidecar.Env, checkpointDirEnv)

	return []corev1.Container{serverSidecar, loaderSidecar}
}

// BuildGMSCheckpointJobSidecars prepares GMS infrastructure for a checkpoint
// job and returns the extra containers (the gms-saver) that the caller must
// append to podSpec.Containers.
//
// It returns (nil, nil) when podSpec or mainContainer is nil, and an error
// when the main container has no resource claims or when storage is missing
// its PVC name, base path, or location.
//
// Side effects on podSpec: the GMS server sidecar is ensured via
// gmsruntime.EnsureServerSidecar, the control volume is wired into it, and the
// checkpoint PVC volume is added if not already present.
func BuildGMSCheckpointJobSidecars(
	podSpec *corev1.PodSpec,
	mainContainer *corev1.Container,
	storage snapshotprotocol.Storage,
) ([]corev1.Container, error) {
	if podSpec == nil || mainContainer == nil {
		return nil, nil
	}

	// Preconditions, checked in order: device claims first, then storage.
	switch {
	case len(mainContainer.Resources.Claims) == 0:
		return nil, fmt.Errorf("gms sidecars require main container resource claims")
	case storage.PVCName == "" || storage.BasePath == "" || storage.Location == "":
		return nil, fmt.Errorf("gms checkpoint jobs require resolved checkpoint storage")
	}

	gmsruntime.EnsureServerSidecar(podSpec, mainContainer)
	ensureGMSCheckpointControl(podSpec)

	saverSidecar := gmsCheckpointSaverContainer(mainContainer.Image)
	copyGMSDeviceClaims(mainContainer, &saverSidecar)
	ensureCheckpointVolume(podSpec, storage.PVCName)

	checkpointMount := corev1.VolumeMount{
		Name:      snapshotprotocol.CheckpointVolumeName,
		MountPath: storage.BasePath,
	}
	checkpointDirEnv := corev1.EnvVar{
		Name:  "GMS_CHECKPOINT_DIR",
		Value: resolveGMSArtifactDir(storage),
	}
	saverSidecar.VolumeMounts = append(saverSidecar.VolumeMounts, checkpointMount)
	saverSidecar.Env = append(saverSidecar.Env, checkpointDirEnv)

	return []corev1.Container{saverSidecar}, nil
}

// resolveGMSArtifactDir derives the directory used for GMS artifacts from the
// resolved checkpoint storage.
//
// The artifact version is the last path element of storage.Location and the
// checkpoint ID is the last element of its grandparent directory. The result
// is <BasePath>/gms/<checkpointID>/versions/<version>.
func resolveGMSArtifactDir(storage snapshotprotocol.Storage) string {
	// GMS data is kept under /checkpoints/gms/<hash>/versions/<version>,
	// apart from the CRIU tree (/checkpoints/<hash>/versions/<version>), so
	// the non-root saver can create directories at the PVC root.
	location := storage.Location
	versionsDir := filepath.Dir(location)
	checkpointDir := filepath.Dir(versionsDir)

	return filepath.Join(
		storage.BasePath,
		"gms",
		filepath.Base(checkpointDir),
		"versions",
		filepath.Base(location),
	)
}

// gmsCheckpointLoaderContainer returns the base gms-loader container for the
// given image. It runs the loader Python module, mounts the GMS shared volume,
// and points both TMPDIR and GMS_SOCKET_DIR at that mount path.
func gmsCheckpointLoaderContainer(image string) corev1.Container {
	sharedDir := gmsruntime.SharedMountPath

	return corev1.Container{
		Name:    GMSLoaderContainer,
		Image:   image,
		Command: []string{"python3", "-m", gmsCheckpointLoaderModule},
		Env: []corev1.EnvVar{
			{Name: "TMPDIR", Value: sharedDir},
			{Name: "GMS_SOCKET_DIR", Value: sharedDir},
		},
		VolumeMounts: []corev1.VolumeMount{
			{Name: gmsruntime.SharedVolumeName, MountPath: sharedDir},
		},
	}
}

// gmsCheckpointSaverContainer returns the base gms-saver container for the
// given image. It runs the saver Python module, exposes the pod name and
// namespace through field references, and mounts both the GMS shared volume
// and the control volume, with TMPDIR/GMS_SOCKET_DIR and GMS_CONTROL_DIR set
// to the respective mount paths.
func gmsCheckpointSaverContainer(image string) corev1.Container {
	// fieldEnv builds an env var whose value comes from a pod field.
	fieldEnv := func(name, fieldPath string) corev1.EnvVar {
		return corev1.EnvVar{
			Name: name,
			ValueFrom: &corev1.EnvVarSource{
				FieldRef: &corev1.ObjectFieldSelector{FieldPath: fieldPath},
			},
		}
	}

	sharedDir := gmsruntime.SharedMountPath
	controlDir := gmsruntime.ControlDir

	return corev1.Container{
		Name:    GMSSaverContainer,
		Image:   image,
		Command: []string{"python3", "-m", gmsCheckpointSaverModule},
		Env: []corev1.EnvVar{
			fieldEnv("POD_NAME", "metadata.name"),
			fieldEnv("POD_NAMESPACE", "metadata.namespace"),
			{Name: "TMPDIR", Value: sharedDir},
			{Name: "GMS_SOCKET_DIR", Value: sharedDir},
			{Name: "GMS_CONTROL_DIR", Value: controlDir},
		},
		VolumeMounts: []corev1.VolumeMount{
			{Name: gmsruntime.SharedVolumeName, MountPath: sharedDir},
			{Name: gmsruntime.ControlVolumeName, MountPath: controlDir},
		},
	}
}

// ensureGMSCheckpointControl makes sure the pod has the emptyDir control
// volume and, when a GMS server container is found in podSpec, that the server
// mounts it at gmsruntime.ControlDir and has GMS_CONTROL_DIR set to that path.
//
// The function is idempotent, matching ensureCheckpointVolume: anything that
// is already in place is left alone instead of being appended again, so
// calling it twice (or on a pod spec that already carries the control volume)
// does not produce duplicate volume names or mount paths. A nil podSpec is a
// no-op.
func ensureGMSCheckpointControl(podSpec *corev1.PodSpec) {
	if podSpec == nil {
		return
	}

	// Volume names must be unique within a pod, so add the control volume
	// only when no volume of that name exists yet.
	hasVolume := false
	for i := range podSpec.Volumes {
		if podSpec.Volumes[i].Name == gmsruntime.ControlVolumeName {
			hasVolume = true
			break
		}
	}
	if !hasVolume {
		podSpec.Volumes = append(podSpec.Volumes, corev1.Volume{
			Name:         gmsruntime.ControlVolumeName,
			VolumeSource: corev1.VolumeSource{EmptyDir: &corev1.EmptyDirVolumeSource{}},
		})
	}

	server := gmsruntime.FindServerContainer(podSpec)
	if server == nil {
		return
	}

	// Skip the mount only when the exact same mount (volume and path) is
	// already present; a mount of the volume at another path does not
	// satisfy the server's need for ControlDir.
	hasMount := false
	for i := range server.VolumeMounts {
		mount := &server.VolumeMounts[i]
		if mount.Name == gmsruntime.ControlVolumeName && mount.MountPath == gmsruntime.ControlDir {
			hasMount = true
			break
		}
	}
	if !hasMount {
		server.VolumeMounts = append(server.VolumeMounts, corev1.VolumeMount{Name: gmsruntime.ControlVolumeName, MountPath: gmsruntime.ControlDir})
	}

	// Track the last GMS_CONTROL_DIR entry: the append is skipped only when
	// that entry already carries the literal ControlDir value, so a differing
	// entry still gets the correct value appended after it, as before.
	envAlreadySet := false
	for i := range server.Env {
		if server.Env[i].Name == "GMS_CONTROL_DIR" {
			envAlreadySet = server.Env[i].Value == gmsruntime.ControlDir && server.Env[i].ValueFrom == nil
		}
	}
	if !envAlreadySet {
		server.Env = append(server.Env, corev1.EnvVar{Name: "GMS_CONTROL_DIR", Value: gmsruntime.ControlDir})
	}
}

// copyGMSDeviceClaims gives container its own copy of the main container's
// resource claims, replacing whatever claims container had. Nothing happens
// when either pointer is nil or the main container has no claims.
func copyGMSDeviceClaims(mainContainer *corev1.Container, container *corev1.Container) {
	if mainContainer == nil || container == nil {
		return
	}
	source := mainContainer.Resources.Claims
	if len(source) == 0 {
		return
	}

	// Copy into fresh storage so the two containers never share a backing
	// array.
	duplicated := make([]corev1.ResourceClaim, len(source))
	copy(duplicated, source)
	container.Resources.Claims = duplicated
}

// ensureCheckpointVolume adds a PVC-backed volume named
// snapshotprotocol.CheckpointVolumeName to podSpec unless a volume with that
// name is already present. An empty pvcName is a no-op.
func ensureCheckpointVolume(podSpec *corev1.PodSpec, pvcName string) {
	if pvcName == "" {
		return
	}

	for _, existing := range podSpec.Volumes {
		if existing.Name == snapshotprotocol.CheckpointVolumeName {
			// Already declared; leave the existing definition untouched.
			return
		}
	}

	checkpointVolume := corev1.Volume{Name: snapshotprotocol.CheckpointVolumeName}
	checkpointVolume.VolumeSource.PersistentVolumeClaim = &corev1.PersistentVolumeClaimVolumeSource{
		ClaimName: pvcName,
	}
	podSpec.Volumes = append(podSpec.Volumes, checkpointVolume)
}