// Package executor provides the top-level checkpoint and restore executors.
// They wire together the lower-level packages (criu, cuda, etc.) into multi-step workflows.
package executor
import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"time"

	criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
	"github.com/containerd/containerd"
	"github.com/go-logr/logr"
	"github.com/google/uuid"

	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/common"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/cuda"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/pkg/types"
)

// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
	// ContainerID is the containerd container ID of the container to dump.
	ContainerID string
	// ContainerName is the container's name within the pod; it is used
	// together with PodName/PodNamespace to discover the pod's GPU UUIDs.
	ContainerName string
	// CheckpointHash identifies the checkpoint and is recorded in the manifest.
	CheckpointHash string
	// CheckpointLocation is the final directory the checkpoint is placed at.
	// Any existing entry at this path is replaced.
	CheckpointLocation string
	// CheckpointStorageType selects the storage backend; only "pvc" is
	// currently accepted by Checkpoint.
	CheckpointStorageType string
	// NodeName, PodName and PodNamespace describe the source pod and are
	// recorded in the checkpoint manifest.
	NodeName     string
	PodName      string
	PodNamespace string
}

// Checkpoint performs a CRIU dump of a container.
// The operation has three phases: inspect, configure, capture.
//
38
39
40
// The checkpoint directory is staged under tmp/<uuid> during the operation.
// On success, the previous checkpoint is removed and the staged directory is
// renamed into place at the base path root.
41
42
43
44
func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error {
	checkpointStart := time.Now()
	log.Info("=== Starting checkpoint operation ===")

45
46
	if req.CheckpointStorageType != "pvc" {
		return fmt.Errorf("checkpoint storage type %q is not supported", req.CheckpointStorageType)
47
	}
48
49
	if req.CheckpointLocation == "" {
		return fmt.Errorf("checkpoint location is required")
50
51
	}

52
53
54
55
56
57
58
59
60
61
62
	finalDir := req.CheckpointLocation
	tmpRoot := filepath.Join(filepath.Dir(finalDir), "tmp")
	if err := os.MkdirAll(tmpRoot, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging root: %w", err)
	}
	tmpDir := filepath.Join(tmpRoot, uuid.NewString())
	if err := os.Mkdir(tmpDir, 0700); err != nil {
		return fmt.Errorf("failed to create checkpoint staging directory: %w", err)
	}
	defer os.RemoveAll(tmpDir)

63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
	// Phase 1: Inspect container state
	state, err := inspectContainer(ctx, ctrd, log, req)
	if err != nil {
		return err
	}

	// Phase 2: Configure CRIU options and build checkpoint manifest
	criuOpts, data, err := configureCheckpoint(log, state, req, cfg, tmpDir)
	if err != nil {
		return err
	}

	// Phase 3: Capture — CRIU dump, rootfs diff
	criuDumpDuration, err := captureCheckpoint(ctx, criuOpts, &cfg.CRIU, data, state, tmpDir, log)
	if err != nil {
		return err
	}

	// Remove any previous checkpoint with the same identity hash before finalizing
82
83
84
	if err := os.RemoveAll(finalDir); err != nil {
		return fmt.Errorf("failed to remove previous checkpoint directory: %w", err)
	}
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
	if err := os.Rename(tmpDir, finalDir); err != nil {
		return fmt.Errorf("failed to finalize checkpoint directory: %w", err)
	}

	totalDuration := time.Since(checkpointStart)
	log.Info("=== Checkpoint operation completed ===",
		"total_duration", totalDuration,
		"criu_dump_duration", criuDumpDuration,
	)

	return nil
}

// inspectContainer collects the container state the later checkpoint phases
// need:
//   - the host PID and OCI spec (via common.ResolveContainer);
//   - the host cgroup path;
//   - the rootfs and overlay upperdir;
//   - the classified mounts and the net namespace inode;
//   - the targets of stdio FDs 0-2;
//   - the CUDA process PIDs found in the container's process tree, plus the
//     pod's GPU UUIDs when there are any such PIDs.
//
// Two lookups are best-effort. If the host cgroup path cannot be resolved, it
// is left empty and the failure is logged. If a stdio FD link cannot be read,
// that entry stays "" and the failure is logged at V(1). Every other lookup
// failure is returned as an error.
func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest) (*types.CheckpointContainerSnapshot, error) {
	pid, ociSpec, err := common.ResolveContainer(ctx, ctrd, req.ContainerID)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve container %q: %w", req.ContainerID, err)
	}

	// Best-effort: a missing cgroup path is tolerated, but it is logged rather
	// than silently discarded so later restore problems can be traced back here.
	var hostCgroupPath string
	cgPath, err := common.ResolveCgroupRootFromHostPID(pid)
	switch {
	case err != nil:
		log.Info("Could not resolve host cgroup path; continuing without it", "pid", pid, "error", err)
	case cgPath != "":
		hostCgroupPath = filepath.Join(common.HostCgroupPath, cgPath)
	}

	rootFS, err := common.GetRootFS(pid)
	if err != nil {
		return nil, fmt.Errorf("failed to get rootfs: %w", err)
	}

	upperDir, err := common.GetOverlayUpperDir(pid)
	if err != nil {
		return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
	}

	mountInfo, err := common.ReadMountInfo(pid)
	if err != nil {
		return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
	}
	mounts := common.ClassifyMounts(mountInfo, ociSpec, rootFS)

	netNSInode, err := common.GetNetNSInode(pid)
	if err != nil {
		return nil, fmt.Errorf("failed to get net namespace inode: %w", err)
	}

	// Read stdio FD targets (like runc's getPipeFds / descriptors.json).
	// Unreadable FDs leave an empty string in their slot.
	stdioFDs := make([]string, 3)
	for i := range stdioFDs {
		target, err := os.Readlink(fmt.Sprintf("%s/%d/fd/%d", common.HostProcPath, pid, i))
		if err != nil {
			log.V(1).Info("Failed to readlink stdio FD", "fd", i, "error", err)
			continue
		}
		stdioFDs[i] = target
	}

	// Discover CUDA processes and GPU UUIDs
	allPIDs := common.ProcessTreePIDs(pid)
	cudaPIDs := cuda.FilterProcesses(ctx, allPIDs, log)
	var gpuUUIDs []string
	if len(cudaPIDs) > 0 {
		gpuUUIDs, err = cuda.GetPodGPUUUIDs(ctx, req.PodName, req.PodNamespace, req.ContainerName)
		if err != nil {
			return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
		}
	}

	return &types.CheckpointContainerSnapshot{
		PID:            pid,
		RootFS:         rootFS,
		UpperDir:       upperDir,
		OCISpec:        ociSpec,
		Mounts:         mounts,
		NetNSInode:     netNSInode,
		StdioFDs:       stdioFDs,
		HostCgroupPath: hostCgroupPath,
		CUDAPIDs:       cudaPIDs,
		GPUUUIDs:       gpuUUIDs,
	}, nil
}

// configureCheckpoint builds the CRIU dump options for checkpointDir and
// assembles the checkpoint manifest. The manifest contains:
//   - the CRIU dump section;
//   - the source-pod section;
//   - the overlay section;
//   - a CUDA section, added only when the snapshot lists CUDA PIDs.
//
// The manifest is persisted into checkpointDir before returning, so a failed
// write aborts the checkpoint. It returns the dump options and the manifest.
func configureCheckpoint(
	log logr.Logger,
	state *types.CheckpointContainerSnapshot,
	req CheckpointRequest,
	cfg *types.AgentConfig,
	checkpointDir string,
) (*criurpc.CriuOpts, *types.CheckpointManifest, error) {
	opts, buildErr := criu.BuildDumpOptions(state, &cfg.CRIU, checkpointDir, log)
	if buildErr != nil {
		return nil, nil, buildErr
	}

	// Sections are built in the same order the manifest constructor takes them.
	dumpSection := types.NewCRIUDumpManifest(opts, cfg.CRIU)
	sourceSection := types.NewSourcePodManifest(req.ContainerID, state.PID, req.NodeName, req.PodName, req.PodNamespace, state.StdioFDs)
	overlaySection := types.NewOverlayManifest(cfg.Overlay, state.UpperDir, state.OCISpec)
	manifest := types.NewCheckpointManifest(req.CheckpointHash, dumpSection, sourceSection, overlaySection)

	if len(state.CUDAPIDs) != 0 {
		manifest.CUDA = types.NewCUDAManifest(state.CUDAPIDs, state.GPUUUIDs)
	}

	if writeErr := types.WriteManifest(checkpointDir, manifest); writeErr != nil {
		return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", writeErr)
	}
	return opts, manifest, nil
}

// captureCheckpoint runs the capture phase into checkpointDir, in order:
//  1. If the snapshot lists CUDA PIDs, lock and checkpoint them via
//     cuda.LockAndCheckpointProcessTree. This must happen before the CRIU dump.
//  2. Run the CRIU dump with criuOpts and criuSettings.
//  3. If an overlay upperdir is known, capture the rootfs diff (honouring the
//     manifest's overlay exclusions and bind-mount destinations) and the list
//     of deleted files.
//
// Step 3 is best-effort: its failures are logged, never returned. The returned
// duration is the one reported by criu.ExecuteDump.
//
// NOTE(review): if the CRIU dump fails after step 1 succeeded, nothing here
// unlocks the CUDA processes. Confirm whether the cuda package or the caller
// handles that.
func captureCheckpoint(ctx context.Context, criuOpts *criurpc.CriuOpts, criuSettings *types.CRIUSettings, data *types.CheckpointManifest, state *types.CheckpointContainerSnapshot, checkpointDir string, log logr.Logger) (time.Duration, error) {
	if hasCUDA := len(state.CUDAPIDs) > 0; hasCUDA {
		if lockErr := cuda.LockAndCheckpointProcessTree(ctx, state.CUDAPIDs, log); lockErr != nil {
			return 0, fmt.Errorf("CUDA checkpoint failed: %w", lockErr)
		}
	}

	dumpTime, dumpErr := criu.ExecuteDump(criuOpts, checkpointDir, criuSettings, log)
	if dumpErr != nil {
		return 0, dumpErr
	}

	// Without an upperdir there is no overlay diff to capture.
	if state.UpperDir == "" {
		return dumpTime, nil
	}

	// A checkpoint without overlay diffs is still restorable (the base image
	// supplies the filesystem), so these failures are only logged.
	if _, diffErr := common.CaptureRootfsDiff(state.UpperDir, checkpointDir, data.Overlay.Exclusions, data.Overlay.BindMountDests); diffErr != nil {
		log.Error(diffErr, "Failed to capture rootfs diff")
	}
	if _, delErr := common.CaptureDeletedFiles(state.UpperDir, checkpointDir); delErr != nil {
		log.Error(delErr, "Failed to capture deleted files")
	}

	return dumpTime, nil
}