"examples/common/launch_utils.sh" did not exist on "c95bfc2e59374ad225344b5638cc199f6814d0d1"
restore.go 6.98 KB
Newer Older
1
package executor
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
17
	"k8s.io/client-go/kubernetes"
18

19
20
21
22
23
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/cuda"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/logging"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
24
25
26
27
)

// RestoreRequest holds the parameters for a restore operation.
type RestoreRequest struct {
28
29
30
31
32
33
34
	CheckpointID       string
	CheckpointLocation string
	NSRestorePath      string
	PodName            string
	PodNamespace       string
	ContainerName      string
	Clientset          kubernetes.Interface
35
36
37
38
39
40
41
42
43
}

// Restore performs external restore for the given request.
// Returns the namespace-relative PID of the restored process.
// The DaemonSet side inspects the placeholder and launches nsrestore,
// which handles rootfs application, CRIU restore, and CUDA restore inside the namespace.
func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) {
	restoreStart := time.Now()
	log.Info("=== Starting external restore ===",
44
		"checkpoint_id", req.CheckpointID,
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
		"pod", req.PodName,
		"namespace", req.PodNamespace,
		"container", req.ContainerName,
	)

	// Phase 1: Inspect — resolve placeholder, discover target GPUs, build device map
	snap, err := inspectRestore(ctx, ctrd, log, req)
	if err != nil {
		return 0, err
	}

	// Phase 2: Execute — nsrestore handles rootfs, CRIU restore, and CUDA restore inside namespace
	restoredPID, err := execNSRestore(ctx, log, req, snap)
	if err != nil {
		return 0, fmt.Errorf("nsrestore failed: %w", err)
	}
	log.Info("nsrestore completed", "restored_pid", restoredPID)

	// Validate restored process from the host side
	procRoot := filepath.Join(snap.TargetRoot, "proc")
65
	if err := snapshotruntime.ValidateProcessState(procRoot, restoredPID); err != nil {
66
67
68
69
70
71
72
73
74
75
76
		restoreLogPath := filepath.Join(snap.TargetRoot, "var", "criu-work", criu.RestoreLogFilename)
		logging.LogProcessDiagnostics(procRoot, restoredPID, restoreLogPath, log)
		return 0, fmt.Errorf("restored process failed post-restore validation: %w", err)
	}

	log.Info("=== External restore completed ===", "total_duration", time.Since(restoreStart))

	return restoredPID, nil
}

func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
77
78
79
80
81
82
	if req.CheckpointLocation == "" {
		return nil, fmt.Errorf("checkpoint location is required")
	}

	checkpointPath := req.CheckpointLocation
	baseAbs, err := filepath.Abs(filepath.Dir(checkpointPath))
83
84
85
86
87
88
89
90
	if err != nil {
		return nil, fmt.Errorf("failed to resolve checkpoint base path: %w", err)
	}
	checkpointAbs, err := filepath.Abs(checkpointPath)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve checkpoint path: %w", err)
	}
	if checkpointAbs != baseAbs && !strings.HasPrefix(checkpointAbs, baseAbs+string(os.PathSeparator)) {
91
		return nil, fmt.Errorf("invalid checkpoint id %q", req.CheckpointID)
92
93
94
95
96
97
98
99
100
101
102
103
	}

	m, err := types.ReadManifest(checkpointPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read checkpoint manifest: %w", err)
	}

	containerName := req.ContainerName
	if containerName == "" {
		containerName = "main"
	}

104
	placeholderPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, ctrd, req.PodName, req.PodNamespace, containerName)
105
106
107
108
109
	if err != nil {
		return nil, fmt.Errorf("failed to resolve placeholder container: %w", err)
	}
	log.Info("Resolved placeholder container", "pid", placeholderPID)

110
	cgroupRoot, err := snapshotruntime.ResolveCgroupRootFromHostPID(placeholderPID)
111
112
113
114
115
116
117
118
119
120
121
122
123
124
	if err != nil {
		log.Error(err, "Failed to resolve placeholder cgroup root; proceeding without explicit cgroup remap")
		cgroupRoot = ""
	}

	cudaDeviceMap := ""
	if !m.CUDA.IsEmpty() {
		if len(m.CUDA.SourceGPUUUIDs) == 0 {
			return nil, fmt.Errorf("missing source GPU UUIDs in checkpoint manifest")
		}
		targetGPUUUIDs, err := cuda.GetPodGPUUUIDs(ctx, req.PodName, req.PodNamespace, containerName)
		if err != nil {
			return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err)
		}
125
126
		if len(targetGPUUUIDs) == 0 {
			log.Info("PodResources API returned no target GPU UUIDs, falling back to nvidia-smi", "pid", placeholderPID)
127
			targetGPUUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, snapshotruntime.HostProcPath, placeholderPID)
128
129
130
131
132
			if err != nil {
				return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed for restore target: %w", err)
			}
			log.Info("nvidia-smi fallback discovered target GPU UUIDs", "uuids", targetGPUUUIDs)
		}
133
134
135
		if len(targetGPUUUIDs) == 0 {
			return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName)
		}
136
		cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs, log)
137
138
139
		if err != nil {
			return nil, fmt.Errorf("failed to build CUDA device map: %w", err)
		}
140
141
142
143
144
		log.Info("GPU UUIDs for device map",
			"source_uuids", m.CUDA.SourceGPUUUIDs,
			"target_uuids", targetGPUUUIDs,
			"device_map", cudaDeviceMap,
		)
145
146
147
148
149
	}

	return &types.RestoreContainerSnapshot{
		CheckpointPath: checkpointPath,
		PlaceholderPID: placeholderPID,
150
		TargetRoot:     fmt.Sprintf("%s/%d/root", snapshotruntime.HostProcPath, placeholderPID),
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
		CgroupRoot:     cgroupRoot,
		CUDADeviceMap:  cudaDeviceMap,
	}, nil
}

// execNSRestore launches the nsrestore binary inside the placeholder container's
// namespaces via nsenter and returns the restored PID that nsrestore reports on
// stdout as JSON ({"restoredPID": <pid>}).
//
// nsenter targets snap.PlaceholderPID and enters its mount, UTS, IPC, network
// and PID namespaces. Optional flags are forwarded only when set:
// --cuda-device-map when snap.CUDADeviceMap is non-empty, and --cgroup-root
// when snap.CgroupRoot is non-empty.
//
// The command is bound to ctx, so cancelling ctx terminates it. An error is
// returned when:
//   - req.NSRestorePath is empty;
//   - the command fails;
//   - stdout is not valid JSON;
//   - the reported PID is not positive.
//
// The callers add their own "nsrestore failed" context, so errors here do not
// repeat it.
func execNSRestore(ctx context.Context, log logr.Logger, req RestoreRequest, snap *types.RestoreContainerSnapshot) (int, error) {
	// Without a binary path nsenter would be asked to exec "", which fails with
	// an unhelpful message; report the real cause instead.
	if req.NSRestorePath == "" {
		return 0, fmt.Errorf("nsrestore path is required")
	}

	args := []string{
		"-t", strconv.Itoa(snap.PlaceholderPID),
		// Intentionally exclude cgroup namespace (-C): CRIU must manage cgroups
		// from the host-visible hierarchy so --cgroup-root remap works.
		"-m", "-u", "-i", "-n", "-p",
		"--", req.NSRestorePath,
		"--checkpoint-path", snap.CheckpointPath,
	}
	if snap.CUDADeviceMap != "" {
		args = append(args, "--cuda-device-map", snap.CUDADeviceMap)
	}
	if snap.CgroupRoot != "" {
		args = append(args, "--cgroup-root", snap.CgroupRoot)
	}

	cmd := exec.CommandContext(ctx, "nsenter", args...)
	// Inherit the agent environment so nsrestore uses the same logger settings.
	cmd.Env = os.Environ()
	log.V(1).Info("Executing nsenter + nsrestore", "cmd", cmd.String())

	// stdout is captured for the JSON result; stderr is forwarded to the
	// agent's own stderr.
	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		return 0, fmt.Errorf("running nsenter: %w\nstdout: %s", err, stdout.String())
	}

	var result struct {
		RestoredPID int `json:"restoredPID"`
	}
	if err := json.Unmarshal(stdout.Bytes(), &result); err != nil {
		return 0, fmt.Errorf("failed to parse nsrestore result: %w\nstdout: %s", err, stdout.String())
	}
	if result.RestoredPID <= 0 {
		return 0, fmt.Errorf("nsrestore returned invalid PID %d", result.RestoredPID)
	}

	return result.RestoredPID, nil
}