cuda.go 8.63 KB
Newer Older
1
2
3
4
5
6
7
// Package cuda provides CUDA checkpoint and restore operations.
package cuda

import (
	"context"
	"fmt"
	"os/exec"
8
	"regexp"
9
10
	"strconv"
	"strings"
11
	"time"
12
13
14
15

	"github.com/go-logr/logr"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
16
	"k8s.io/client-go/kubernetes"
17
18
19
	podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)

20
21
22
23
// GPU allocation identifiers.
//
// nvidiaGPUResource is the resource name that GetPodGPUUUIDs compares against
// each PodResources device entry to pick out NVIDIA GPU allocations.
//
// nvidiaGPUDRADriver is not referenced in this file.
// NOTE(review): presumably the driver name used to recognise NVIDIA DRA
// allocations — confirm at its use site.
const (
	nvidiaGPUResource  = "nvidia.com/gpu"
	nvidiaGPUDRADriver = "gpu.nvidia.com"
)
24
25

// podResourcesSocketPath is the filesystem path of the kubelet PodResources
// unix socket; GetPodGPUUUIDs dials it as "unix://" + podResourcesSocketPath.
// NOTE(review): declared as a var rather than a const, presumably so it can be
// overridden (e.g. in tests) — confirm.
var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock"
26

27
28
// gpuUUIDPattern matches a whole string of the form
// "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", where each x is a hexadecimal
// digit in either case. The pattern is anchored at both ends, so surrounding
// text makes the match fail. It is not referenced in this file.
// NOTE(review): presumably used elsewhere in the package to validate
// discovered UUIDs — confirm.
var gpuUUIDPattern = regexp.MustCompile(`^GPU-[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`)

29
30
31
32
33
34
35
36
// CheckpointPhaseTimings reports how long a CUDA checkpoint operation took.
// It is the timing result returned by LockAndCheckpointProcessTree.
type CheckpointPhaseTimings struct {
	// TotalDuration is the wall-clock time from the start of the lock phase
	// until the operation finished, either successfully or at the first error.
	TotalDuration time.Duration
}

// RestorePhaseTimings reports how long a CUDA restore operation took.
// It is the timing result returned by RestoreAndUnlockProcessTree.
type RestorePhaseTimings struct {
	// TotalDuration is the wall-clock time from the start of the restore phase
	// until the operation finished, either successfully or at the first error.
	TotalDuration time.Duration
}

37
38
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from kubelet
// PodResources (nvidia.com/gpu entries in GetDevices()).
39
40
41
42
43
func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) {
	if podName == "" || podNamespace == "" {
		return nil, nil
	}

44
45
	conn, err := grpc.NewClient(
		"unix://"+podResourcesSocketPath,
46
47
48
49
50
51
52
53
54
55
56
57
58
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	client := podresourcesv1.NewPodResourcesListerClient(conn)
	resp, err := client.List(ctx, &podresourcesv1.ListPodResourcesRequest{})
	if err != nil {
		return nil, err
	}

59
	var uuids []string
60
61
62
63
64
65
66
67
68
69
	for _, pod := range resp.GetPodResources() {
		if pod.GetName() != podName || pod.GetNamespace() != podNamespace {
			continue
		}
		for _, container := range pod.GetContainers() {
			if containerName != "" && container.GetName() != containerName {
				continue
			}
			for _, device := range container.GetDevices() {
				if device.GetResourceName() == nvidiaGPUResource {
70
					uuids = append(uuids, device.GetDeviceIds()...)
71
72
				}
			}
73

74
75
76
		}
	}

77
	return uuids, nil
78
79
}

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// GetGPUUUIDsViaNvidiaSmi discovers GPU UUIDs by running nvidia-smi inside the
// container's mount namespace. This is the fallback path when the kubelet
// PodResources API does not report GPU devices (e.g. when GPUs are allocated
// via DRA instead of the NVIDIA device plugin).
//
// hostProcPath is the path under which the host's /proc is visible (trailing
// slashes are ignored) and pid identifies the process whose mount namespace
// (<hostProcPath>/<pid>/ns/mnt) is entered via nsenter.
//
// It returns one entry per non-blank line of nvidia-smi output, trimmed of
// surrounding whitespace. A non-positive pid is rejected without spawning a
// process. If the command fails, the returned error wraps the underlying exec
// error and, when available, appends whatever the command wrote to stderr so
// the failure can be diagnosed from logs.
func GetGPUUUIDsViaNvidiaSmi(ctx context.Context, hostProcPath string, pid int) ([]string, error) {
	if pid <= 0 {
		return nil, fmt.Errorf("nvidia-smi via nsenter: invalid pid %d", pid)
	}

	mountPath := fmt.Sprintf("%s/%d/ns/mnt", strings.TrimRight(hostProcPath, "/"), pid)
	cmd := exec.CommandContext(
		ctx,
		"nsenter",
		fmt.Sprintf("--mount=%s", mountPath),
		"--",
		"nvidia-smi", "--query-gpu=gpu_uuid", "--format=csv,noheader",
	)

	// Capture stderr explicitly: the exec error alone only says "exit status N",
	// which hides the actual nsenter/nvidia-smi diagnostic.
	var stderr strings.Builder
	cmd.Stderr = &stderr

	output, err := cmd.Output()
	if err != nil {
		if detail := strings.TrimSpace(stderr.String()); detail != "" {
			return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w: %s", pid, err, detail)
		}
		return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w", pid, err)
	}

	var uuids []string
	for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") {
		line = strings.TrimSpace(line)
		if line != "" {
			uuids = append(uuids, line)
		}
	}
	return uuids, nil
}

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// DiscoverGPUUUIDs resolves the GPU UUIDs of a pod by trying discovery sources
// in order and returning the first non-empty result:
//
//  1. The DRA API (GetGPUUUIDsViaDRAAPI). A failure here is logged and does not
//     abort discovery.
//  2. The kubelet PodResources API (GetPodGPUUUIDs), consulted only when the
//     DRA lookup did not report an NVIDIA DRA allocation. A failure here is
//     returned to the caller.
//  3. nvidia-smi executed in the mount namespace of pid
//     (GetGPUUUIDsViaNvidiaSmi), as the last resort. Its result is returned
//     even when empty; a failure is returned to the caller.
func DiscoverGPUUUIDs(ctx context.Context, clientset kubernetes.Interface, podName, podNamespace, containerName, hostProcPath string, pid int, log logr.Logger) ([]string, error) {
	draUUIDs, draAllocated, draErr := GetGPUUUIDsViaDRAAPI(ctx, clientset, podName, podNamespace, log)

	// reason records why the nvidia-smi fallback is eventually needed.
	var reason string
	switch {
	case draErr != nil:
		log.Error(
			draErr,
			"DRA API GPU UUID lookup failed, trying other discovery paths",
			"pod", podNamespace+"/"+podName,
			"has_nvidia_dra_allocation", draAllocated,
		)
		reason = "DRA API GPU UUID lookup failed"
	case len(draUUIDs) > 0:
		return draUUIDs, nil
	default:
		reason = "DRA API returned no GPU UUIDs"
	}

	if !draAllocated {
		classicUUIDs, classicErr := GetPodGPUUUIDs(ctx, podName, podNamespace, containerName)
		switch {
		case classicErr != nil:
			return nil, fmt.Errorf("PodResources GPU UUID lookup failed: %w", classicErr)
		case len(classicUUIDs) > 0:
			return classicUUIDs, nil
		}
		reason = "PodResources API returned no GPU UUIDs"
	}

	log.Info(reason+", falling back to nvidia-smi", "pid", pid)
	smiUUIDs, smiErr := GetGPUUUIDsViaNvidiaSmi(ctx, hostProcPath, pid)
	if smiErr != nil {
		return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed: %w", smiErr)
	}
	log.Info("nvidia-smi fallback discovered GPU UUIDs", "uuids", smiUUIDs)
	return smiUUIDs, nil
}

146
147
148
149
150
// FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts.
// Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of
// --get-state, because --get-state incorrectly matches coordinator processes like
// cuda-checkpoint --launch-job that share a /proc namespace with CUDA processes but
// don't hold CUDA contexts themselves.
151
152
153
154
155
156
func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int {
	cudaPIDs := make([]int, 0, len(allPIDs))
	for _, pid := range allPIDs {
		if pid <= 0 {
			continue
		}
157
		cmd := exec.CommandContext(ctx, cudaCheckpointHelperBinary, "--get-restore-tid", "--pid", strconv.Itoa(pid))
158
159
		output, err := cmd.CombinedOutput()
		if err != nil {
160
161
162
			if ctx.Err() != nil {
				break
			}
163
			log.V(1).Info("CUDA restore-tid probe negative", "pid", pid)
164
165
			continue
		}
166
167
		tid := strings.TrimSpace(string(output))
		log.V(1).Info("CUDA restore-tid probe positive", "pid", pid, "tid", tid)
168
169
170
171
172
		cudaPIDs = append(cudaPIDs, pid)
	}
	return cudaPIDs
}

173
// BuildDeviceMap creates a cuda-checkpoint-helper --device-map value from source and target GPU UUID lists.
174
175
176
// When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid
// unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order.
// Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally.
177
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string, log logr.Logger) (string, error) {
178
179
180
181
182
183
	if len(sourceUUIDs) != len(targetUUIDs) {
		return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs))
	}
	if len(sourceUUIDs) == 0 {
		return "", fmt.Errorf("GPU UUID list is empty")
	}
184
	log.V(1).Info("BuildDeviceMap inputs", "source_uuids", sourceUUIDs, "target_uuids", targetUUIDs)
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215

	targetSet := make(map[string]bool, len(targetUUIDs))
	for _, t := range targetUUIDs {
		targetSet[t] = true
	}

	// First pass: identity-map any source UUID that exists in the target set
	mapping := make(map[string]string, len(sourceUUIDs))
	usedTargets := make(map[string]bool, len(targetUUIDs))
	for _, src := range sourceUUIDs {
		if targetSet[src] {
			mapping[src] = src
			usedTargets[src] = true
		}
	}

	// Second pass: pair remaining source UUIDs with remaining target UUIDs positionally
	var remainingTargets []string
	for _, t := range targetUUIDs {
		if !usedTargets[t] {
			remainingTargets = append(remainingTargets, t)
		}
	}
	idx := 0
	for _, src := range sourceUUIDs {
		if _, ok := mapping[src]; !ok {
			mapping[src] = remainingTargets[idx]
			idx++
		}
	}

216
	pairs := make([]string, len(sourceUUIDs))
217
218
	for i, src := range sourceUUIDs {
		pairs[i] = src + "=" + mapping[src]
219
220
221
222
223
	}
	return strings.Join(pairs, ","), nil
}

// LockAndCheckpointProcessTree locks and checkpoints CUDA state for all given PIDs.
224
// On failure, the caller is expected to fail the operation and terminate the workload.
225
226
227
228
func LockAndCheckpointProcessTree(ctx context.Context, cudaPIDs []int, log logr.Logger) (CheckpointPhaseTimings, error) {
	var timings CheckpointPhaseTimings

	start := time.Now()
229
230
	for _, pid := range cudaPIDs {
		if err := lock(ctx, pid, log); err != nil {
231
232
			timings.TotalDuration = time.Since(start)
			return timings, err
233
234
235
236
237
		}
	}

	for _, pid := range cudaPIDs {
		if err := checkpoint(ctx, pid, log); err != nil {
238
239
			timings.TotalDuration = time.Since(start)
			return timings, err
240
241
		}
	}
242
	timings.TotalDuration = time.Since(start)
243

244
	return timings, nil
245
246
247
}

// RestoreAndUnlockProcessTree restores and unlocks CUDA state for the given PIDs.
248
249
250
251
func RestoreAndUnlockProcessTree(ctx context.Context, cudaPIDs []int, deviceMap string, log logr.Logger) (RestorePhaseTimings, error) {
	var timings RestorePhaseTimings

	start := time.Now()
252
253
	for _, pid := range cudaPIDs {
		if err := restoreProcess(ctx, pid, deviceMap, log); err != nil {
254
255
			timings.TotalDuration = time.Since(start)
			return timings, err
256
257
		}
	}
258

259
260
	for _, pid := range cudaPIDs {
		if err := unlock(ctx, pid, log); err != nil {
261
			timings.TotalDuration = time.Since(start)
262
263
			state, stateErr := getState(ctx, pid)
			if stateErr == nil && state == "running" {
264
				log.Info("cuda-checkpoint-helper unlock returned error but process is already running", "pid", pid)
265
266
				continue
			}
267
			return timings, err
268
269
		}
	}
270
271
272
	timings.TotalDuration = time.Since(start)

	return timings, nil
273
}