// Package executor provides the top-level checkpoint and restore executors.
// These wire together the lib packages (criu, cuda, etc.) into multi-step workflows.
package executor

import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
	"github.com/google/uuid"
	"k8s.io/client-go/kubernetes"

	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/cuda"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
)

// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
27
28
29
30
31
32
33
34
	ContainerID        string
	ContainerName      string
	CheckpointID       string
	CheckpointLocation string
	NodeName           string
	PodName            string
	PodNamespace       string
	Clientset          kubernetes.Interface
35
36
37
38
39
}

// Checkpoint performs a CRIU dump of a container.
// The operation has three phases: inspect, configure, capture.
//
40
41
42
// The checkpoint directory is staged under tmp/<uuid> during the operation.
// On success, the previous checkpoint is removed and the staged directory is
// renamed into place at the base path root.
43
44
45
46
func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error {
	checkpointStart := time.Now()
	log.Info("=== Starting checkpoint operation ===")

47
48
	if strings.TrimSpace(req.CheckpointID) == "" {
		return fmt.Errorf("checkpoint ID is required")
49
	}
50
51
	if req.CheckpointLocation == "" {
		return fmt.Errorf("checkpoint location is required")
52
53
	}

54
55
56
57
58
59
60
61
62
63
64
	finalDir := req.CheckpointLocation
	tmpRoot := filepath.Join(filepath.Dir(finalDir), "tmp")
	if err := os.MkdirAll(tmpRoot, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging root: %w", err)
	}
	tmpDir := filepath.Join(tmpRoot, uuid.NewString())
	if err := os.Mkdir(tmpDir, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging directory: %w", err)
	}
	defer os.RemoveAll(tmpDir)

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
	// Phase 1: Inspect container state
	state, err := inspectContainer(ctx, ctrd, log, req)
	if err != nil {
		return err
	}

	// Phase 2: Configure CRIU options and build checkpoint manifest
	criuOpts, data, err := configureCheckpoint(log, state, req, cfg, tmpDir)
	if err != nil {
		return err
	}

	// Phase 3: Capture — CRIU dump, rootfs diff
	criuDumpDuration, err := captureCheckpoint(ctx, criuOpts, &cfg.CRIU, data, state, tmpDir, log)
	if err != nil {
		return err
	}

	// Remove any previous checkpoint with the same identity hash before finalizing
84
85
86
	if err := os.RemoveAll(finalDir); err != nil {
		return fmt.Errorf("failed to remove previous checkpoint directory: %w", err)
	}
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
	if err := os.Rename(tmpDir, finalDir); err != nil {
		return fmt.Errorf("failed to finalize checkpoint directory: %w", err)
	}

	totalDuration := time.Since(checkpointStart)
	log.Info("=== Checkpoint operation completed ===",
		"total_duration", totalDuration,
		"criu_dump_duration", criuDumpDuration,
	)

	return nil
}

func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest) (*types.CheckpointContainerSnapshot, error) {
	containerID := req.ContainerID
102
	pid, ociSpec, err := snapshotruntime.ResolveContainer(ctx, ctrd, containerID)
103
104
105
106
107
	if err != nil {
		return nil, fmt.Errorf("failed to resolve container: %w", err)
	}

	var hostCgroupPath string
108
109
	if cgPath, err := snapshotruntime.ResolveCgroupRootFromHostPID(pid); err == nil && cgPath != "" {
		hostCgroupPath = filepath.Join(snapshotruntime.HostCgroupPath, cgPath)
110
111
	}

112
	rootFS, err := snapshotruntime.GetRootFS(pid)
113
114
115
116
	if err != nil {
		return nil, fmt.Errorf("failed to get rootfs: %w", err)
	}

117
	upperDir, err := snapshotruntime.GetOverlayUpperDir(pid)
118
119
120
121
	if err != nil {
		return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
	}

122
	mountInfo, err := snapshotruntime.ReadMountInfo(pid)
123
124
125
	if err != nil {
		return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
	}
126
	mounts := snapshotruntime.ClassifyMounts(mountInfo, ociSpec, rootFS)
127

128
	netNSInode, err := snapshotruntime.GetNetNSInode(pid)
129
130
131
132
133
134
135
	if err != nil {
		return nil, fmt.Errorf("failed to get net namespace inode: %w", err)
	}

	// Read stdio FD targets (like runc's getPipeFds / descriptors.json).
	stdioFDs := make([]string, 3)
	for i := range 3 {
136
		target, err := os.Readlink(fmt.Sprintf("%s/%d/fd/%d", snapshotruntime.HostProcPath, pid, i))
137
138
139
140
141
142
143
144
		if err != nil {
			log.V(1).Info("Failed to readlink stdio FD", "fd", i, "error", err)
			continue
		}
		stdioFDs[i] = target
	}

	// Discover CUDA processes and GPU UUIDs
145
	allPIDs := snapshotruntime.ProcessTreePIDs(pid)
146
147
148
	cudaHostPIDs := cuda.FilterProcesses(ctx, allPIDs, log)
	cudaNamespacePIDs := make([]int, 0, len(cudaHostPIDs))
	for _, cudaHostPID := range cudaHostPIDs {
149
		process, err := snapshotruntime.ReadProcessDetails(snapshotruntime.HostProcPath, cudaHostPID)
150
151
152
153
154
155
156
157
158
159
160
		if err != nil {
			return nil, fmt.Errorf("failed to read process details for CUDA process %d: %w", cudaHostPID, err)
		}
		if len(process.NamespacePIDs) != 2 {
			return nil, fmt.Errorf("CUDA process %d has namespace depth %d, want 2", cudaHostPID, len(process.NamespacePIDs))
		}
		cudaNamespacePIDs = append(cudaNamespacePIDs, process.InnermostPID)
	}
	if len(cudaHostPIDs) > 0 {
		log.Info("Resolved checkpoint CUDA PID mapping", "host_pids", cudaHostPIDs, "namespace_pids", cudaNamespacePIDs)
	}
161
	var gpuUUIDs []string
162
	if len(cudaHostPIDs) > 0 {
163
164
165
166
		gpuUUIDs, err = cuda.GetPodGPUUUIDs(ctx, req.PodName, req.PodNamespace, req.ContainerName)
		if err != nil {
			return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
		}
167
168
		if len(gpuUUIDs) == 0 {
			log.Info("PodResources API returned no GPU UUIDs, falling back to nvidia-smi", "pid", pid)
169
			gpuUUIDs, err = cuda.GetGPUUUIDsViaNvidiaSmi(ctx, snapshotruntime.HostProcPath, pid)
170
171
172
173
174
			if err != nil {
				return nil, fmt.Errorf("nvidia-smi GPU UUID fallback failed: %w", err)
			}
			log.Info("nvidia-smi fallback discovered GPU UUIDs", "uuids", gpuUUIDs)
		}
175
176
177
178
179
180
181
182
183
184
185
	}

	return &types.CheckpointContainerSnapshot{
		PID:            pid,
		RootFS:         rootFS,
		UpperDir:       upperDir,
		OCISpec:        ociSpec,
		Mounts:         mounts,
		NetNSInode:     netNSInode,
		StdioFDs:       stdioFDs,
		HostCgroupPath: hostCgroupPath,
186
187
		CUDAHostPIDs:   cudaHostPIDs,
		CUDANSPIDs:     cudaNamespacePIDs,
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
		GPUUUIDs:       gpuUUIDs,
	}, nil
}

func configureCheckpoint(
	log logr.Logger,
	state *types.CheckpointContainerSnapshot,
	req CheckpointRequest,
	cfg *types.AgentConfig,
	checkpointDir string,
) (*criurpc.CriuOpts, *types.CheckpointManifest, error) {
	criuOpts, err := criu.BuildDumpOptions(state, &cfg.CRIU, checkpointDir, log)
	if err != nil {
		return nil, nil, err
	}

	m := types.NewCheckpointManifest(
205
		req.CheckpointID,
206
207
208
209
		types.NewCRIUDumpManifest(criuOpts, cfg.CRIU),
		types.NewSourcePodManifest(req.ContainerID, state.PID, req.NodeName, req.PodName, req.PodNamespace, state.StdioFDs),
		types.NewOverlayManifest(cfg.Overlay, state.UpperDir, state.OCISpec),
	)
210
211
	if len(state.CUDANSPIDs) > 0 {
		m.CUDA = types.NewCUDAManifest(state.CUDANSPIDs, state.GPUUUIDs)
212
213
214
215
216
217
218
219
220
221
222
	}

	if err := types.WriteManifest(checkpointDir, m); err != nil {
		return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", err)
	}

	return criuOpts, m, nil
}

// captureCheckpoint runs the capture phase, writing into checkpointDir.
//
// Ordering matters: when CUDA processes were found, their CUDA state is locked
// and checkpointed via cuda.LockAndCheckpointProcessTree before CRIU runs. A
// failure of either step is returned. The overlay rootfs diff and deleted-file
// capture that follow are best-effort: their failures are logged, not
// returned. A checkpoint without overlay diffs is still valid for restore,
// because the base container image provides the filesystem.
//
// It returns the dump duration reported by criu.ExecuteDump.
//
// NOTE(review): if the CRIU dump fails after the CUDA lock/checkpoint has
// succeeded, nothing here unlocks or restores the CUDA processes. Confirm
// whether a caller or the cuda package handles that.
func captureCheckpoint(ctx context.Context, criuOpts *criurpc.CriuOpts, criuSettings *types.CRIUSettings, data *types.CheckpointManifest, state *types.CheckpointContainerSnapshot, checkpointDir string, log logr.Logger) (time.Duration, error) {
	if cudaPIDs := state.CUDAHostPIDs; len(cudaPIDs) != 0 {
		if lockErr := cuda.LockAndCheckpointProcessTree(ctx, cudaPIDs, log); lockErr != nil {
			return 0, fmt.Errorf("CUDA checkpoint failed: %w", lockErr)
		}
	}

	dumpDuration, dumpErr := criu.ExecuteDump(criuOpts, checkpointDir, criuSettings, log)
	if dumpErr != nil {
		return 0, dumpErr
	}

	upper := state.UpperDir
	if upper == "" {
		return dumpDuration, nil
	}

	if _, diffErr := snapshotruntime.CaptureRootfsDiff(upper, checkpointDir, data.Overlay.Exclusions, data.Overlay.BindMountDests); diffErr != nil {
		log.Error(diffErr, "Failed to capture rootfs diff")
	}
	if _, delErr := snapshotruntime.CaptureDeletedFiles(upper, checkpointDir); delErr != nil {
		log.Error(delErr, "Failed to capture deleted files")
	}
	return dumpDuration, nil
}