// Package cuda provides CUDA checkpoint and restore operations. package cuda import ( "context" "fmt" "os/exec" "regexp" "strconv" "strings" "github.com/go-logr/logr" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" podresourcesv1 "k8s.io/kubelet/pkg/apis/podresources/v1" ) const ( nvidiaGPUResource = "nvidia.com/gpu" nvidiaGPUDRADriver = "gpu.nvidia.com" ) var podResourcesSocketPath = "/var/lib/kubelet/pod-resources/kubelet.sock" var gpuUUIDPattern = regexp.MustCompile(`^GPU-[a-fA-F0-9]{8}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{4}-[a-fA-F0-9]{12}$`) // GetPodGPUUUIDs resolves GPU UUIDs for a pod/container from kubelet // PodResources (nvidia.com/gpu entries in GetDevices()). func GetPodGPUUUIDs(ctx context.Context, podName, podNamespace, containerName string) ([]string, error) { if podName == "" || podNamespace == "" { return nil, nil } conn, err := grpc.NewClient( "unix://"+podResourcesSocketPath, grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { return nil, err } defer conn.Close() client := podresourcesv1.NewPodResourcesListerClient(conn) resp, err := client.List(ctx, &podresourcesv1.ListPodResourcesRequest{}) if err != nil { return nil, err } var uuids []string for _, pod := range resp.GetPodResources() { if pod.GetName() != podName || pod.GetNamespace() != podNamespace { continue } for _, container := range pod.GetContainers() { if containerName != "" && container.GetName() != containerName { continue } for _, device := range container.GetDevices() { if device.GetResourceName() == nvidiaGPUResource { uuids = append(uuids, device.GetDeviceIds()...) } } } } return uuids, nil } // GetGPUUUIDsViaNvidiaSmi discovers GPU UUIDs by running nvidia-smi inside the // container's mount namespace. This is the fallback path when the kubelet // PodResources API does not report GPU devices (e.g. when GPUs are allocated // via DRA instead of the NVIDIA device plugin). func GetGPUUUIDsViaNvidiaSmi(ctx context.Context, hostProcPath string, pid int) ([]string, error) { mountPath := fmt.Sprintf("%s/%d/ns/mnt", strings.TrimRight(hostProcPath, "/"), pid) cmd := exec.CommandContext( ctx, "nsenter", fmt.Sprintf("--mount=%s", mountPath), "--", "nvidia-smi", "--query-gpu=gpu_uuid", "--format=csv,noheader", ) output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("nvidia-smi via nsenter (pid %d) failed: %w", pid, err) } var uuids []string for _, line := range strings.Split(strings.TrimSpace(string(output)), "\n") { line = strings.TrimSpace(line) if line != "" { uuids = append(uuids, line) } } return uuids, nil } // FilterProcesses returns the subset of candidate PIDs that hold actual CUDA contexts. // Uses --get-restore-tid (the same technique as the CRIU CUDA plugin) instead of // --get-state, because --get-state incorrectly matches coordinator processes like // cuda-checkpoint --launch-job that share a /proc namespace with CUDA processes but // don't hold CUDA contexts themselves. func FilterProcesses(ctx context.Context, allPIDs []int, log logr.Logger) []int { cudaPIDs := make([]int, 0, len(allPIDs)) for _, pid := range allPIDs { if pid <= 0 { continue } cmd := exec.CommandContext(ctx, cudaCheckpointBinary, "--get-restore-tid", "--pid", strconv.Itoa(pid)) output, err := cmd.CombinedOutput() if err != nil { if ctx.Err() != nil { break } log.V(1).Info("CUDA restore-tid probe negative", "pid", pid) continue } tid := strings.TrimSpace(string(output)) log.V(1).Info("CUDA restore-tid probe positive", "pid", pid, "tid", tid) cudaPIDs = append(cudaPIDs, pid) } return cudaPIDs } // BuildDeviceMap creates a cuda-checkpoint --device-map value from source and target GPU UUID lists. // When a source UUID exists in the target set, it maps to itself (identity mapping) to avoid // unnecessary cross-GPU restore on same-node restores where kubelet returns GPUs in different order. // Remaining unmatched source UUIDs are paired with remaining unmatched target UUIDs positionally. func BuildDeviceMap(sourceUUIDs, targetUUIDs []string, log logr.Logger) (string, error) { if len(sourceUUIDs) != len(targetUUIDs) { return "", fmt.Errorf("GPU count mismatch: source has %d, target has %d", len(sourceUUIDs), len(targetUUIDs)) } if len(sourceUUIDs) == 0 { return "", fmt.Errorf("GPU UUID list is empty") } log.V(1).Info("BuildDeviceMap inputs", "source_uuids", sourceUUIDs, "target_uuids", targetUUIDs) targetSet := make(map[string]bool, len(targetUUIDs)) for _, t := range targetUUIDs { targetSet[t] = true } // First pass: identity-map any source UUID that exists in the target set mapping := make(map[string]string, len(sourceUUIDs)) usedTargets := make(map[string]bool, len(targetUUIDs)) for _, src := range sourceUUIDs { if targetSet[src] { mapping[src] = src usedTargets[src] = true } } // Second pass: pair remaining source UUIDs with remaining target UUIDs positionally var remainingTargets []string for _, t := range targetUUIDs { if !usedTargets[t] { remainingTargets = append(remainingTargets, t) } } idx := 0 for _, src := range sourceUUIDs { if _, ok := mapping[src]; !ok { mapping[src] = remainingTargets[idx] idx++ } } pairs := make([]string, len(sourceUUIDs)) for i, src := range sourceUUIDs { pairs[i] = src + "=" + mapping[src] } return strings.Join(pairs, ","), nil } // LockAndCheckpointProcessTree locks and checkpoints CUDA state for all given PIDs. // On failure, the caller is expected to fail the operation and terminate the workload. func LockAndCheckpointProcessTree(ctx context.Context, cudaPIDs []int, log logr.Logger) error { for _, pid := range cudaPIDs { if err := lock(ctx, pid, log); err != nil { return err } } for _, pid := range cudaPIDs { if err := checkpoint(ctx, pid, log); err != nil { return err } } return nil } // RestoreAndUnlockProcessTree restores and unlocks CUDA state for the given PIDs. func RestoreAndUnlockProcessTree(ctx context.Context, cudaPIDs []int, deviceMap string, log logr.Logger) error { for _, pid := range cudaPIDs { if err := restoreProcess(ctx, pid, deviceMap, log); err != nil { return err } } for _, pid := range cudaPIDs { if err := unlock(ctx, pid, log); err != nil { state, stateErr := getState(ctx, pid) if stateErr == nil && state == "running" { log.Info("cuda-checkpoint unlock returned error but process is already running", "pid", pid) continue } return err } } return nil }