// Package cuda provides CUDA checkpoint and restore operations.
package cuda

import (
	"context"
	"fmt"
	"os/exec"
8
	"regexp"
9
10
	"strconv"
	"strings"
11
	"time"
12
13
14
15
16
17
18
19

	"github.com/go-logr/logr"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"

	podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1"
)

20
21
22
23
const (
	// nvidiaGPUResource is the resource name that GetPodGPUUUIDs compares
	// against the device entries returned by the kubelet PodResources API.
	nvidiaGPUResource  = "nvidia.com/gpu"
	// nvidiaGPUDRADriver is not referenced by the code in this file.
	// NOTE(review): presumably the NVIDIA DRA driver name used to recognise
	// DRA-allocated GPUs elsewhere in the package — confirm against its users.
	nvidiaGPUDRADriver = "gpu.nvidia.com"
)
24
25

// podResourcesSocketPath is the filesystem path of the kubelet PodResources
// unix socket that GetPodGPUUUIDs connects to.
// NOTE(review): declared as a var rather than a const, presumably so tests can
// point it at a fake socket — confirm.
var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock"
26

27
28
// gpuUUIDPattern matches a whole string of the form "GPU-" followed by five
// hyphen-separated groups of 8, 4, 4, 4 and 12 hexadecimal digits (upper or
// lower case). It is not referenced by the functions in this file.
var gpuUUIDPattern = regexp.MustCompile(`^GPU-[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`)

29
30
31
32
33
34
35
36
// CheckpointPhaseTimings reports how long a CUDA checkpoint operation took.
// It is the result type of LockAndCheckpointProcessTree.
type CheckpointPhaseTimings struct {
	// TotalDuration is the wall-clock time from the start of
	// LockAndCheckpointProcessTree until it finished or hit its first error.
	TotalDuration time.Duration
}

// RestorePhaseTimings reports how long a CUDA restore operation took.
// It is the result type of RestoreAndUnlockProcessTree.
type RestorePhaseTimings struct {
	// TotalDuration is the wall-clock time from the start of
	// RestoreAndUnlockProcessTree until it finished or hit a fatal error.
	TotalDuration time.Duration
}

37
38
// GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from kubelet
// PodResources (nvidia.com/gpu entries in GetDevices()).
39
40
41
42
43
func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) {
	if podName == "" || podNamespace == "" {
		return nil, nil
	}

44
45
	conn, err := grpc.NewClient(
		"unix://"+podResourcesSocketPath,
46
47
48
49
50
51
52
53
54
55
56
57
58
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		return nil, err
	}
	defer conn.Close()

	client := podresourcesv1.NewPodResourcesListerClient(conn)
	resp, err := client.List(ctx, &podresourcesv1.ListPodResourcesRequest{})
	if err != nil {
		return nil, err
	}

59
	var uuids []string
60
61
62
63
64
65
66
67
68
69
	for _, pod := range resp.GetPodResources() {
		if pod.GetName() != podName || pod.GetNamespace() != podNamespace {
			continue
		}
		for _, container := range pod.GetContainers() {
			if containerName != "" && container.GetName() != containerName {
				continue
			}
			for _, device := range container.GetDevices() {
				if device.GetResourceName() == nvidiaGPUResource {
70
					uuids = append(uuids, device.GetDeviceIds()...)
71
72
				}
			}
73

74
75
76
		}
	}

77
	return uuids, nil
78
79
}

80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
// GetGPUUUIDsViaNvidiaSmi lists GPU UUIDs by executing nvidia-smi through
// nsenter in the mount namespace of process pid, located at
// <hostProcPath>/<pid>/ns/mnt (trailing slashes on hostProcPath are ignored).
// It serves as the fallback when the kubelet PodResources API does not report
// GPU devices (e.g. when GPUs are allocated via DRA instead of the NVIDIA
// device plugin).
//
// Every non-blank line of nvidia-smi's standard output is returned as one
// UUID, with surrounding whitespace removed. The result is nil when there are
// no such lines. A failure to run the command is returned wrapped with the pid.
func GetGPUUUIDsViaNvidiaSmi(ctx context.Context, hostProcPath string, pid int) ([]string, error) {
	procRoot := strings.TrimRight(hostProcPath, "/")
	mountArg := fmt.Sprintf("--mount=%s/%d/ns/mnt", procRoot, pid)

	raw, err := exec.CommandContext(
		ctx,
		"nsenter", mountArg, "--",
		"nvidia-smi", "--query-gpu=gpu_uuid", "--format=csv,noheader",
	).Output()
	if err != nil {
		return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w", pid, err)
	}

	var uuids []string
	for _, entry := range strings.Split(string(raw), "\n") {
		// Trimming per line also discards leading/trailing blank lines.
		if trimmed := strings.TrimSpace(entry); trimmed != "" {
			uuids = append(uuids, trimmed)
		}
	}
	return uuids, nil
}

107
108
109
110
111
// FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts.
// Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of
// --get-state, because --get-state incorrectly matches coordinator processes like
// cuda-checkpoint --launch-job that share a /proc namespace with CUDA processes but
// don't hold CUDA contexts themselves.
112
113
114
115
116
117
func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int {
	cudaPIDs := make([]int, 0, len(allPIDs))
	for _, pid := range allPIDs {
		if pid <= 0 {
			continue
		}
118
		cmd := exec.CommandContext(ctx, cudaCheckpointHelperBinary, "--get-restore-tid", "--pid", strconv.Itoa(pid))
119
120
		output, err := cmd.CombinedOutput()
		if err != nil {
121
122
123
			if ctx.Err() != nil {
				break
			}
124
			log.V(1).Info("CUDA restore-tid probe negative", "pid", pid)
125
126
			continue
		}
127
128
		tid := strings.TrimSpace(string(output))
		log.V(1).Info("CUDA restore-tid probe positive", "pid", pid, "tid", tid)
129
130
131
132
133
		cudaPIDs = append(cudaPIDs, pid)
	}
	return cudaPIDs
}

134
// BuildDeviceMap creates a cuda-checkpoint-helper --device-map value from source and target GPU UUID lists.
135
136
137
// When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid
// unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order.
// Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally.
138
func BuildDeviceMap(sourceUUIDs, targetUUIDs []string, log logr.Logger) (string, error) {
139
140
141
142
143
144
	if len(sourceUUIDs) != len(targetUUIDs) {
		return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs))
	}
	if len(sourceUUIDs) == 0 {
		return "", fmt.Errorf("GPU UUID list is empty")
	}
145
	log.V(1).Info("BuildDeviceMap inputs", "source_uuids", sourceUUIDs, "target_uuids", targetUUIDs)
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176

	targetSet := make(map[string]bool, len(targetUUIDs))
	for _, t := range targetUUIDs {
		targetSet[t] = true
	}

	// First pass: identity-map any source UUID that exists in the target set
	mapping := make(map[string]string, len(sourceUUIDs))
	usedTargets := make(map[string]bool, len(targetUUIDs))
	for _, src := range sourceUUIDs {
		if targetSet[src] {
			mapping[src] = src
			usedTargets[src] = true
		}
	}

	// Second pass: pair remaining source UUIDs with remaining target UUIDs positionally
	var remainingTargets []string
	for _, t := range targetUUIDs {
		if !usedTargets[t] {
			remainingTargets = append(remainingTargets, t)
		}
	}
	idx := 0
	for _, src := range sourceUUIDs {
		if _, ok := mapping[src]; !ok {
			mapping[src] = remainingTargets[idx]
			idx++
		}
	}

177
	pairs := make([]string, len(sourceUUIDs))
178
179
	for i, src := range sourceUUIDs {
		pairs[i] = src + "=" + mapping[src]
180
181
182
183
184
	}
	return strings.Join(pairs, ","), nil
}

// LockAndCheckpointProcessTree locks and checkpoints CUDA state for all given PIDs.
185
// On failure, the caller is expected to fail the operation and terminate the workload.
186
187
188
189
func LockAndCheckpointProcessTree(ctx context.Context, cudaPIDs []int, log logr.Logger) (CheckpointPhaseTimings, error) {
	var timings CheckpointPhaseTimings

	start := time.Now()
190
191
	for _, pid := range cudaPIDs {
		if err := lock(ctx, pid, log); err != nil {
192
193
			timings.TotalDuration = time.Since(start)
			return timings, err
194
195
196
197
198
		}
	}

	for _, pid := range cudaPIDs {
		if err := checkpoint(ctx, pid, log); err != nil {
199
200
			timings.TotalDuration = time.Since(start)
			return timings, err
201
202
		}
	}
203
	timings.TotalDuration = time.Since(start)
204

205
	return timings, nil
206
207
208
}

// RestoreAndUnlockProcessTree restores and unlocks CUDA state for the given PIDs.
209
210
211
212
func RestoreAndUnlockProcessTree(ctx context.Context, cudaPIDs []int, deviceMap string, log logr.Logger) (RestorePhaseTimings, error) {
	var timings RestorePhaseTimings

	start := time.Now()
213
214
	for _, pid := range cudaPIDs {
		if err := restoreProcess(ctx, pid, deviceMap, log); err != nil {
215
216
			timings.TotalDuration = time.Since(start)
			return timings, err
217
218
		}
	}
219

220
221
	for _, pid := range cudaPIDs {
		if err := unlock(ctx, pid, log); err != nil {
222
			timings.TotalDuration = time.Since(start)
223
224
			state, stateErr := getState(ctx, pid)
			if stateErr == nil && state == "running" {
225
				log.Info("cuda-checkpoint-helper unlock returned error but process is already running", "pid", pid)
226
227
				continue
			}
228
			return timings, err
229
230
		}
	}
231
232
233
	timings.TotalDuration = time.Since(start)

	return timings, nil
234
}