nsrestore.go 5.78 KB
Newer Older
1
package executor
2
3
4
5
6
7
8
9
10
11

import (
	"context"
	"fmt"
	"syscall"
	"time"

	criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
	"github.com/go-logr/logr"

12
13
14
15
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/criu"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/cuda"
	snapshotruntime "github.com/ai-dynamo/dynamo/deploy/snapshot/internal/runtime"
	"github.com/ai-dynamo/dynamo/deploy/snapshot/internal/types"
16
17
18
19
20
21
22
23
24
)

// RestoreOptions carries the caller-supplied settings for RestoreInNamespace,
// which runs the whole restore from inside the target container's namespaces.
type RestoreOptions struct {
	// CheckpointPath locates the checkpoint. It is handed to the manifest
	// reader, the CRIU option builder, the rootfs-diff and deleted-file
	// helpers, and the CRIU restore call.
	CheckpointPath string

	// CUDADeviceMap is passed unchanged to the CUDA restore step, which only
	// runs when the checkpoint manifest records CUDA state.
	CUDADeviceMap string

	// CgroupRoot is forwarded to the CRIU restore option builder.
	CgroupRoot string
}

25
26
27
28
29
30
31
// RestoreInNamespaceResult is what RestoreInNamespace returns on success: the
// restored PID plus wall-clock durations of the restore phases.
type RestoreInNamespaceResult struct {
	// RestoredPID is the PID reported by the CRIU restore step.
	RestoredPID int `json:"restoredPID"`

	// NSRestoreSetupDuration adds up manifest reading, CRIU option building
	// and the in-namespace filesystem/mount preparation before CRIU runs.
	NSRestoreSetupDuration time.Duration `json:"nsrestoreSetupDuration"`

	// CRIURestoreDuration is the time spent in the CRIU restore call itself.
	CRIURestoreDuration time.Duration `json:"criuRestoreDuration"`

	// CUDADuration covers the post-CRIU phase: reading the restored process
	// table and, when the checkpoint has CUDA state, the CUDA restore.
	CUDADuration time.Duration `json:"cudaDuration"`
}

32
// RestoreInNamespace performs a full restore from inside the target container's namespaces.
33
func RestoreInNamespace(ctx context.Context, opts RestoreOptions, log logr.Logger) (*RestoreInNamespaceResult, error) {
34
35
36
37
38
39
40
	restoreStart := time.Now()
	log.Info("Starting nsrestore workflow",
		"checkpoint_path", opts.CheckpointPath,
		"has_cuda_map", opts.CUDADeviceMap != "",
		"cgroup_root", opts.CgroupRoot,
	)

41
	manifestReadStart := time.Now()
42
43
	m, err := types.ReadManifest(opts.CheckpointPath)
	if err != nil {
44
		return nil, fmt.Errorf("failed to read manifest: %w", err)
45
	}
46
47
	manifestReadDuration := time.Since(manifestReadStart)
	log.V(1).Info("Loaded checkpoint manifest",
48
49
50
51
52
		"ext_mounts", len(m.CRIUDump.ExtMnt),
		"criu_log_level", m.CRIUDump.CRIU.LogLevel,
		"manage_cgroups_mode", m.CRIUDump.CRIU.ManageCgroupsMode,
		"checkpoint_has_cuda", !m.CUDA.IsEmpty(),
	)
53

54
	// Phase 1: Configure — build CRIU opts from manifest
55
	configureStart := time.Now()
56
	criuOpts, err := criu.BuildRestoreOpts(m, opts.CheckpointPath, opts.CgroupRoot, log)
57
	if err != nil {
58
		return nil, err
59
	}
60
	configureDuration := time.Since(configureStart)
61
62

	// Phase 2: Execute — rootfs, CRIU restore, CUDA restore
63
	executeTimings, restoredPID, err := executeRestore(ctx, criuOpts, m, opts, log)
64
	if err != nil {
65
		return nil, err
66
67
	}

68
69
70
71
72
73
74
75
76
77
78
79
80
81
	result := &RestoreInNamespaceResult{
		RestoredPID:            restoredPID,
		NSRestoreSetupDuration: manifestReadDuration + configureDuration + executeTimings.nsrestoreSetupDuration,
		CRIURestoreDuration:    executeTimings.criuRestoreDuration,
		CUDADuration:           executeTimings.cudaDuration,
	}
	log.V(1).Info("nsrestore timing summary",
		"restored_pid", restoredPID,
		"nsrestore_setup_duration", result.NSRestoreSetupDuration,
		"criu_restore_duration", result.CRIURestoreDuration,
		"cuda_duration", result.CUDADuration,
		"total_duration", time.Since(restoreStart),
	)
	return result, nil
82
83
}

84
85
86
87
88
89
90
91
92
// nsrestorePhaseTimings carries the durations measured inside executeRestore
// back to RestoreInNamespace, in declaration order:
//   - nsrestoreSetupDuration: rootfs diff, deleted files, /dev/shm unmount and
//     the read-write /proc/sys remount;
//   - criuRestoreDuration: the CRIU restore call;
//   - cudaDuration: the process-table read plus the optional CUDA restore.
type nsrestorePhaseTimings struct {
	nsrestoreSetupDuration, criuRestoreDuration, cudaDuration time.Duration
}

func executeRestore(ctx context.Context, criuOpts *criurpc.CriuOpts, m *types.CheckpointManifest, opts RestoreOptions, log logr.Logger) (*nsrestorePhaseTimings, int, error) {
	timings := &nsrestorePhaseTimings{}

93
	// Apply rootfs diff inside the namespace (target root is /)
94
	nsrestoreSetupStart := time.Now()
95
	if err := snapshotruntime.ApplyRootfsDiff(opts.CheckpointPath, "/", log); err != nil {
96
		return nil, 0, fmt.Errorf("rootfs diff failed: %w", err)
97
	}
98

99
	if err := snapshotruntime.ApplyDeletedFiles(opts.CheckpointPath, "/", log); err != nil {
100
101
102
103
104
		log.Error(err, "Failed to apply deleted files")
	}

	// Unmount placeholder's /dev/shm so CRIU can recreate tmpfs with checkpointed content
	if err := syscall.Unmount("/dev/shm", 0); err != nil {
105
		return nil, 0, fmt.Errorf("failed to unmount /dev/shm before restore: %w", err)
106
107
	}

108
	if err := snapshotruntime.RemountProcSys(true); err != nil {
109
		return nil, 0, fmt.Errorf("failed to remount /proc/sys read-write for restore: %w", err)
110
	}
111
	timings.nsrestoreSetupDuration = time.Since(nsrestoreSetupStart)
112
	defer func() {
113
		if err := snapshotruntime.RemountProcSys(false); err != nil {
114
115
116
117
118
			log.Error(err, "Failed to remount /proc/sys read-only after restore")
		}
	}()

	// CRIU restore
119
	criuRestoreStart := time.Now()
120
121
	restoredPID, err := criu.ExecuteRestore(criuOpts, m, opts.CheckpointPath, log)
	if err != nil {
122
		return nil, 0, err
123
	}
124
125
126
	timings.criuRestoreDuration = time.Since(criuRestoreStart)

	cudaStart := time.Now()
127
	processes, err := snapshotruntime.ReadProcessTable("/proc")
128
	if err != nil {
129
		return nil, 0, fmt.Errorf("failed to read restored process table: %w", err)
130
	}
131
	log.V(1).Info("Restored process table snapshot",
132
133
134
135
136
137
		"proc_root", "/proc",
		"criu_callback_pid", restoredPID,
		"process_count", len(processes),
		"manifest_cuda_pids", m.CUDA.PIDs,
	)
	for _, process := range processes {
138
		log.V(1).Info("Restored process entry",
139
140
141
142
143
144
145
146
			"observed_pid", process.ObservedPID,
			"parent_pid", process.ParentPID,
			"outermost_pid", process.OutermostPID,
			"innermost_pid", process.InnermostPID,
			"namespace_pids", process.NamespacePIDs,
			"cmdline", process.Cmdline,
		)
	}
147

148
149
	// CUDA restore — remap checkpoint-time innermost namespace PIDs onto the
	// current visible restored PIDs before invoking cuda-checkpoint.
150
	if !m.CUDA.IsEmpty() {
151
		restorePIDs, err := snapshotruntime.ResolveManifestPIDsToObservedPIDs(processes, int(restoredPID), m.CUDA.PIDs)
152
		if err != nil {
153
			return nil, 0, fmt.Errorf("failed to resolve restored CUDA PIDs: %w", err)
154
		}
155
		log.V(1).Info("Resolved manifest CUDA PIDs to current restore PIDs",
156
157
158
159
			"manifest_cuda_pids", m.CUDA.PIDs,
			"restored_cuda_pids", restorePIDs,
			"criu_callback_pid", restoredPID,
		)
160
161
162
		_, err = cuda.RestoreAndUnlockProcessTree(ctx, restorePIDs, opts.CUDADeviceMap, log)
		if err != nil {
			return nil, 0, fmt.Errorf("CUDA restore failed: %w", err)
163
164
		}
	}
165
	timings.cudaDuration = time.Since(cudaStart)
166

167
	return timings, int(restoredPID), nil
168
}