// Package executor provides the top-level checkpoint and restore executors.
// These wire together the lib packages (criu, cuda, etc.) into multi-step workflows.
package executor
import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
	"github.com/google/uuid"
	"k8s.io/client-go/kubernetes"

	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/cuda"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
)

// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
27
28
29
30
	ContainerID        string
	ContainerName      string
	CheckpointID       string
	CheckpointLocation string
31
	StartedAt          time.Time
32
33
34
35
	NodeName           string
	PodName            string
	PodNamespace       string
	Clientset          kubernetes.Interface
36
37
}

// checkpointPhaseTimings records how long each phase of a checkpoint took.
// Phases that did not run (for example CUDA when no CUDA processes were
// found, or overlay capture when no upperdir exists) keep a zero duration.
type checkpointPhaseTimings struct {
	PrepareDuration        time.Duration // validation, staging setup, inspect and configure
	CUDADuration           time.Duration // CUDA lock+checkpoint of the process tree
	CRIUDumpDuration       time.Duration // duration reported by the CRIU dump
	OverlayCaptureDuration time.Duration // rootfs diff + deleted-files capture
	FinalizeDuration       time.Duration // removing the previous checkpoint and renaming the staged one
}

46
47
48
// Checkpoint performs a CRIU dump of a container.
// The operation has three phases: inspect, configure, capture.
//
49
50
51
// The checkpoint directory is staged under tmp/<uuid> during the operation.
// On success, the previous checkpoint is removed and the staged directory is
// renamed into place at the base path root.
52
53
func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error {
	checkpointStart := time.Now()
54
55
	phaseTimings := checkpointPhaseTimings{}
	prepareStart := time.Now()
56
57
	log.Info("=== Starting checkpoint operation ===")

58
59
	if strings.TrimSpace(req.CheckpointID) == "" {
		return fmt.Errorf("checkpoint ID is required")
60
	}
61
62
	if req.CheckpointLocation == "" {
		return fmt.Errorf("checkpoint location is required")
63
64
	}

65
66
67
68
69
70
71
72
73
74
75
	finalDir := req.CheckpointLocation
	tmpRoot := filepath.Join(filepath.Dir(finalDir), "tmp")
	if err := os.MkdirAll(tmpRoot, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging root: %w", err)
	}
	tmpDir := filepath.Join(tmpRoot, uuid.NewString())
	if err := os.Mkdir(tmpDir, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging directory: %w", err)
	}
	defer os.RemoveAll(tmpDir)

76
77
78
79
80
81
82
83
84
85
86
	// Phase 1: Inspect container state
	state, err := inspectContainer(ctx, ctrd, log, req)
	if err != nil {
		return err
	}

	// Phase 2: Configure CRIU options and build checkpoint manifest
	criuOpts, data, err := configureCheckpoint(log, state, req, cfg, tmpDir)
	if err != nil {
		return err
	}
87
	phaseTimings.PrepareDuration = time.Since(prepareStart)
88
89

	// Phase 3: Capture — CRIU dump, rootfs diff
90
	captureTimings, err := captureCheckpoint(ctx, criuOpts, &cfg.CRIU, data, state, tmpDir, log)
91
92
93
	if err != nil {
		return err
	}
94
95
96
	phaseTimings.CUDADuration = captureTimings.CUDADuration
	phaseTimings.CRIUDumpDuration = captureTimings.CRIUDumpDuration
	phaseTimings.OverlayCaptureDuration = captureTimings.OverlayCaptureDuration
97

98
99
100
	// Remove any previous checkpoint with the same identity hash, then
	// promote the staged checkpoint directory into place.
	finalizeStart := time.Now()
101
102
103
	if err := os.RemoveAll(finalDir); err != nil {
		return fmt.Errorf("failed to remove previous checkpoint directory: %w", err)
	}
104
105
106
	if err := os.Rename(tmpDir, finalDir); err != nil {
		return fmt.Errorf("failed to finalize checkpoint directory: %w", err)
	}
107
	phaseTimings.FinalizeDuration = time.Since(finalizeStart)
108
109

	totalDuration := time.Since(checkpointStart)
110
111
112
113
114
115
116
117
118
119
120
	log.Info("Checkpoint timing summary",
		"checkpoint", map[string]any{
			"duration": totalDuration.String(),
			"phases": map[string]string{
				"prepare_duration":         phaseTimings.PrepareDuration.String(),
				"cuda_duration":            phaseTimings.CUDADuration.String(),
				"criu_dump_duration":       phaseTimings.CRIUDumpDuration.String(),
				"overlay_capture_duration": phaseTimings.OverlayCaptureDuration.String(),
				"finalize_duration":        phaseTimings.FinalizeDuration.String(),
			},
		},
121
	)
122
123
124
125
126
	if !req.StartedAt.IsZero() {
		log.Info("Checkpoint wall time from agent detection",
			"started_to_checkpoint_complete", time.Since(req.StartedAt),
		)
	}
127
128
129
130
131
132

	return nil
}

func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest) (*types.CheckpointContainerSnapshot, error) {
	containerID := req.ContainerID
133
	pid, ociSpec, err := snapshotruntime.ResolveContainer(ctx, ctrd, containerID)
134
135
136
137
138
	if err != nil {
		return nil, fmt.Errorf("failed to resolve container: %w", err)
	}

	var hostCgroupPath string
139
140
	if cgPath, err := snapshotruntime.ResolveCgroupRootFromHostPID(pid); err == nil && cgPath != "" {
		hostCgroupPath = filepath.Join(snapshotruntime.HostCgroupPath, cgPath)
141
142
	}

143
	rootFS, err := snapshotruntime.GetRootFS(pid)
144
145
146
147
	if err != nil {
		return nil, fmt.Errorf("failed to get rootfs: %w", err)
	}

148
	upperDir, err := snapshotruntime.GetOverlayUpperDir(pid)
149
150
151
152
	if err != nil {
		return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
	}

153
	mountInfo, err := snapshotruntime.ReadMountInfo(pid)
154
155
156
	if err != nil {
		return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
	}
157
	mounts := snapshotruntime.ClassifyMounts(mountInfo, ociSpec, rootFS)
158

159
	netNSInode, err := snapshotruntime.GetNetNSInode(pid)
160
161
162
163
164
165
166
	if err != nil {
		return nil, fmt.Errorf("failed to get net namespace inode: %w", err)
	}

	// Read stdio FD targets (like runc's getPipeFds / descriptors.json).
	stdioFDs := make([]string, 3)
	for i := range 3 {
167
		target, err := os.Readlink(fmt.Sprintf("%s/%d/fd/%d", snapshotruntime.HostProcPath, pid, i))
168
169
170
171
172
173
174
175
		if err != nil {
			log.V(1).Info("Failed to readlink stdio FD", "fd", i, "error", err)
			continue
		}
		stdioFDs[i] = target
	}

	// Discover CUDA processes and GPU UUIDs
176
	allPIDs := snapshotruntime.ProcessTreePIDs(pid)
177
178
179
	cudaHostPIDs := cuda.FilterProcesses(ctx, allPIDs, log)
	cudaNamespacePIDs := make([]int, 0, len(cudaHostPIDs))
	for _, cudaHostPID := range cudaHostPIDs {
180
		process, err := snapshotruntime.ReadProcessDetails(snapshotruntime.HostProcPath, cudaHostPID)
181
182
183
184
185
186
187
188
189
		if err != nil {
			return nil, fmt.Errorf("failed to read process details for CUDA process %d: %w", cudaHostPID, err)
		}
		if len(process.NamespacePIDs) != 2 {
			return nil, fmt.Errorf("CUDA process %d has namespace depth %d, want 2", cudaHostPID, len(process.NamespacePIDs))
		}
		cudaNamespacePIDs = append(cudaNamespacePIDs, process.InnermostPID)
	}
	if len(cudaHostPIDs) > 0 {
190
		log.V(1).Info("Resolved checkpoint CUDA PID mapping", "host_pids", cudaHostPIDs, "namespace_pids", cudaNamespacePIDs)
191
	}
192
	var gpuUUIDs []string
193
	if len(cudaHostPIDs) > 0 {
194
195
196
197
198
199
200
201
202
203
		gpuUUIDs, err = cuda.DiscoverGPUUUIDs(
			ctx,
			req.Clientset,
			req.PodName,
			req.PodNamespace,
			req.ContainerName,
			snapshotruntime.HostProcPath,
			pid,
			log,
		)
204
205
206
207
208
209
210
211
212
213
214
215
216
217
		if err != nil {
			return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
		}
	}

	return &types.CheckpointContainerSnapshot{
		PID:            pid,
		RootFS:         rootFS,
		UpperDir:       upperDir,
		OCISpec:        ociSpec,
		Mounts:         mounts,
		NetNSInode:     netNSInode,
		StdioFDs:       stdioFDs,
		HostCgroupPath: hostCgroupPath,
218
219
		CUDAHostPIDs:   cudaHostPIDs,
		CUDANSPIDs:     cudaNamespacePIDs,
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
		GPUUUIDs:       gpuUUIDs,
	}, nil
}

func configureCheckpoint(
	log logr.Logger,
	state *types.CheckpointContainerSnapshot,
	req CheckpointRequest,
	cfg *types.AgentConfig,
	checkpointDir string,
) (*criurpc.CriuOpts, *types.CheckpointManifest, error) {
	criuOpts, err := criu.BuildDumpOptions(state, &cfg.CRIU, checkpointDir, log)
	if err != nil {
		return nil, nil, err
	}

	m := types.NewCheckpointManifest(
237
		req.CheckpointID,
238
239
240
241
		types.NewCRIUDumpManifest(criuOpts, cfg.CRIU),
		types.NewSourcePodManifest(req.ContainerID, state.PID, req.NodeName, req.PodName, req.PodNamespace, state.StdioFDs),
		types.NewOverlayManifest(cfg.Overlay, state.UpperDir, state.OCISpec),
	)
242
243
	if len(state.CUDANSPIDs) > 0 {
		m.CUDA = types.NewCUDAManifest(state.CUDANSPIDs, state.GPUUUIDs)
244
245
246
247
248
249
250
251
252
	}

	if err := types.WriteManifest(checkpointDir, m); err != nil {
		return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", err)
	}

	return criuOpts, m, nil
}

253
254
255
func captureCheckpoint(ctx context.Context, criuOpts *criurpc.CriuOpts, criuSettings *types.CRIUSettings, data *types.CheckpointManifest, state *types.CheckpointContainerSnapshot, checkpointDir string, log logr.Logger) (*checkpointPhaseTimings, error) {
	timings := &checkpointPhaseTimings{}

256
	// CUDA lock+checkpoint must happen before CRIU dump
257
	if len(state.CUDAHostPIDs) > 0 {
258
259
260
		cudaTimings, err := cuda.LockAndCheckpointProcessTree(ctx, state.CUDAHostPIDs, log)
		if err != nil {
			return nil, fmt.Errorf("CUDA checkpoint failed: %w", err)
261
		}
262
		timings.CUDADuration = cudaTimings.TotalDuration
263
264
265
266
	}

	criuDumpDuration, err := criu.ExecuteDump(criuOpts, checkpointDir, criuSettings, log)
	if err != nil {
267
		return nil, err
268
	}
269
	timings.CRIUDumpDuration = criuDumpDuration
270
271
272
273
274

	// Overlay rootfs diff capture is best-effort. Failures are logged but not
	// propagated — a checkpoint without overlay diffs is still valid for restore
	// (the base container image provides the filesystem).
	if state.UpperDir != "" {
275
		overlayCaptureStart := time.Now()
276
		if _, err := snapshotruntime.CaptureRootfsDiff(state.UpperDir, checkpointDir, data.Overlay.Exclusions, data.Overlay.BindMountDests); err != nil {
277
278
			log.Error(err, "Failed to capture rootfs diff")
		}
279
		if _, err := snapshotruntime.CaptureDeletedFiles(state.UpperDir, checkpointDir); err != nil {
280
281
			log.Error(err, "Failed to capture deleted files")
		}
282
		timings.OverlayCaptureDuration = time.Since(overlayCaptureStart)
283
284
	}

285
	return timings, nil
286
}