restore.go 7.74 KB
Newer Older
1
package executor
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
17
	"k8s.io/client-go/kubernetes"
18

19
20
21
22
23
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/cuda"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/logging"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
24
25
26
27
)

// RestoreRequest holds the parameters for a restore operation.
type RestoreRequest struct {
28
29
	CheckpointID       string
	CheckpointLocation string
30
	StartedAt          time.Time
31
32
33
34
35
	NSRestorePath      string
	PodName            string
	PodNamespace       string
	ContainerName      string
	Clientset          kubernetes.Interface
36
37
38
39
40
41
}

// Restore performs external restore for the given request.
// Returns the namespace-relative PID of the restored process.
// The DaemonSet side inspects the placeholder and launches nsrestore,
// which handles rootfs application, CRIU restore, and CUDA restore inside the namespace.
42
43
44
45
//
// Returns the placeholder container's host PID so callers can reach into the
// container's mount namespace (e.g. to write sentinels under /snapshot-control)
// without re-resolving via containerd.
46
47
48
func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) {
	restoreStart := time.Now()
	log.Info("=== Starting external restore ===",
49
		"checkpoint_id", req.CheckpointID,
50
51
52
53
54
		"pod", req.PodName,
		"namespace", req.PodNamespace,
		"container", req.ContainerName,
	)

55
56
	// Phase 1: Host inspect — resolve placeholder, discover target GPUs, build device map
	hostInspectStart := time.Now()
57
58
59
60
	snap, err := inspectRestore(ctx, ctrd, log, req)
	if err != nil {
		return 0, err
	}
61
	hostInspectDuration := time.Since(hostInspectStart)
62
63

	// Phase 2: Execute — nsrestore handles rootfs, CRIU restore, and CUDA restore inside namespace
64
	result, err := execNSRestore(ctx, log, req, snap)
65
66
67
	if err != nil {
		return 0, fmt.Errorf("nsrestore failed: %w", err)
	}
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
	restoreDuration := hostInspectDuration + result.NSRestoreSetupDuration + result.CRIURestoreDuration + result.CUDADuration
	log.Info("Restore timing summary",
		"restore", map[string]any{
			"duration": restoreDuration.String(),
			"phases": map[string]string{
				"host_inspect_duration":    hostInspectDuration.String(),
				"nsrestore_setup_duration": result.NSRestoreSetupDuration.String(),
				"criu_restore_duration":    result.CRIURestoreDuration.String(),
				"cuda_duration":            result.CUDADuration.String(),
			},
		},
	)
	if !req.StartedAt.IsZero() {
		log.Info("Restore wall time from agent detection",
			"started_to_restore_complete", time.Since(req.StartedAt),
		)
	}
85
86

	// Validate restored process from the host side
87
	validationStart := time.Now()
88
	procRoot := filepath.Join(snap.TargetRoot, "proc")
89
	if err := snapshotruntime.ValidateProcessState(procRoot, result.RestoredPID); err != nil {
90
		restoreLogPath := filepath.Join(snap.TargetRoot, "var", "criu-work", criu.RestoreLogFilename)
91
		logging.LogProcessDiagnostics(procRoot, result.RestoredPID, restoreLogPath, log)
92
93
94
		return 0, fmt.Errorf("restored process failed post-restore validation: %w", err)
	}

95
96
	log.Info("=== External restore completed ===",
		"restored_pid", result.RestoredPID,
97
		"placeholder_host_pid", snap.PlaceholderPID,
98
99
100
		"validation_duration", time.Since(validationStart),
		"total_duration", time.Since(restoreStart),
	)
101

102
	return snap.PlaceholderPID, nil
103
104
105
}

func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
106
107
108
109
110
111
	if req.CheckpointLocation == "" {
		return nil, fmt.Errorf("checkpoint location is required")
	}

	checkpointPath := req.CheckpointLocation
	baseAbs, err := filepath.Abs(filepath.Dir(checkpointPath))
112
113
114
115
116
117
118
119
	if err != nil {
		return nil, fmt.Errorf("failed to resolve checkpoint base path: %w", err)
	}
	checkpointAbs, err := filepath.Abs(checkpointPath)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve checkpoint path: %w", err)
	}
	if checkpointAbs != baseAbs && !strings.HasPrefix(checkpointAbs, baseAbs+string(os.PathSeparator)) {
120
		return nil, fmt.Errorf("invalid checkpoint id %q", req.CheckpointID)
121
122
123
124
125
126
127
128
129
130
131
132
	}

	m, err := types.ReadManifest(checkpointPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read checkpoint manifest: %w", err)
	}

	containerName := req.ContainerName
	if containerName == "" {
		containerName = "main"
	}

133
	placeholderPID, _, err := snapshotruntime.ResolveContainerByPod(ctx, ctrd, req.PodName, req.PodNamespace, containerName)
134
135
136
	if err != nil {
		return nil, fmt.Errorf("failed to resolve placeholder container: %w", err)
	}
137
	log.V(1).Info("Resolved placeholder container", "pid", placeholderPID)
138

139
	cgroupRoot, err := snapshotruntime.ResolveCgroupRootFromHostPID(placeholderPID)
140
141
142
143
144
145
146
147
148
149
	if err != nil {
		log.Error(err, "Failed to resolve placeholder cgroup root; proceeding without explicit cgroup remap")
		cgroupRoot = ""
	}

	cudaDeviceMap := ""
	if !m.CUDA.IsEmpty() {
		if len(m.CUDA.SourceGPUUUIDs) == 0 {
			return nil, fmt.Errorf("missing source GPU UUIDs in checkpoint manifest")
		}
150
151
152
153
154
155
156
157
158
159
		targetGPUUUIDs, err := cuda.DiscoverGPUUUIDs(
			ctx,
			req.Clientset,
			req.PodName,
			req.PodNamespace,
			containerName,
			snapshotruntime.HostProcPath,
			placeholderPID,
			log,
		)
160
161
162
163
164
165
		if err != nil {
			return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err)
		}
		if len(targetGPUUUIDs) == 0 {
			return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName)
		}
166
		cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs, log)
167
168
169
		if err != nil {
			return nil, fmt.Errorf("failed to build CUDA device map: %w", err)
		}
170
		log.V(1).Info("GPU UUIDs for device map",
171
172
173
174
			"source_uuids", m.CUDA.SourceGPUUUIDs,
			"target_uuids", targetGPUUUIDs,
			"device_map", cudaDeviceMap,
		)
175
176
177
178
179
	}

	return &types.RestoreContainerSnapshot{
		CheckpointPath: checkpointPath,
		PlaceholderPID: placeholderPID,
180
		TargetRoot:     fmt.Sprintf("%s/%d/root", snapshotruntime.HostProcPath, placeholderPID),
181
182
183
184
185
186
187
		CgroupRoot:     cgroupRoot,
		CUDADeviceMap:  cudaDeviceMap,
	}, nil
}

// execNSRestore launches the nsrestore binary inside the placeholder container's
// namespaces via nsenter and parses the restored PID from stdout JSON.
188
func execNSRestore(ctx context.Context, log logr.Logger, req RestoreRequest, snap *types.RestoreContainerSnapshot) (*RestoreInNamespaceResult, error) {
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
	args := []string{
		"-t", strconv.Itoa(snap.PlaceholderPID),
		// Intentionally exclude cgroup namespace (-C): CRIU must manage cgroups
		// from the host-visible hierarchy so --cgroup-root remap works.
		"-m", "-u", "-i", "-n", "-p",
		"--", req.NSRestorePath,
		"--checkpoint-path", snap.CheckpointPath,
	}
	if snap.CUDADeviceMap != "" {
		args = append(args, "--cuda-device-map", snap.CUDADeviceMap)
	}
	if snap.CgroupRoot != "" {
		args = append(args, "--cgroup-root", snap.CgroupRoot)
	}

	cmd := exec.CommandContext(ctx, "nsenter", args...)
205
206
	// Inherit the agent environment so nsrestore uses the same logger settings.
	cmd.Env = os.Environ()
207
208
209
210
211
212
213
	log.V(1).Info("Executing nsenter + nsrestore", "cmd", cmd.String())

	var stdout bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
214
		return nil, fmt.Errorf("nsrestore failed: %w\nstdout: %s", err, stdout.String())
215
216
	}

217
	var result RestoreInNamespaceResult
218
	if err := json.Unmarshal(stdout.Bytes(), &result); err != nil {
219
		return nil, fmt.Errorf("failed to parse nsrestore result: %w\nstdout: %s", err, stdout.String())
220
221
	}
	if result.RestoredPID <= 0 {
222
		return nil, fmt.Errorf("nsrestore returned invalid PID %d", result.RestoredPID)
223
224
	}

225
	return &result, nil
226
}