"recipes/llama-3-70b/vllm/vscode:/vscode.git/clone" did not exist on "cf55e8b85aa3f4d212fd73ca8f4504732db5bbd7"
Unverified Commit bb8fc8a4 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(chrek): external restore, signal-based IPC, and package refactor (#6286)


Co-authored-by: default avatarDan Feigin <dfeigin@nvidia.com>
parent c8423b57
// server.go provides the HTTP server for the checkpoint agent.
package httpApiServer
import (
"context"
"log"
"net/http"
"time"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// ServerConfig holds the configuration for the HTTP API server.
type ServerConfig struct {
ListenAddr string
NodeName string
CheckpointSpec *checkpoint.CheckpointSpec
}
// Server is the HTTP API server for checkpoint operations.
type Server struct {
cfg ServerConfig
handlers *Handlers
httpServer *http.Server
}
// NewServer creates a new Server instance.
func NewServer(cfg ServerConfig, checkpointer *checkpoint.Checkpointer) *Server {
handlers := NewHandlers(cfg, checkpointer)
// Setup routes
mux := http.NewServeMux()
mux.HandleFunc("/health", handlers.HandleHealth)
mux.HandleFunc("/checkpoint", handlers.HandleCheckpoint)
mux.HandleFunc("/checkpoints", handlers.HandleListCheckpoints)
// WriteTimeout must exceed the CRIU checkpoint timeout since /checkpoint
// blocks until the dump completes. Add 60s buffer for pre/post work.
writeTimeout := time.Duration(cfg.CheckpointSpec.CRIU.Timeout)*time.Second + 60*time.Second
if writeTimeout < 300*time.Second {
writeTimeout = 300 * time.Second
}
httpServer := &http.Server{
Addr: cfg.ListenAddr,
Handler: LoggingMiddleware(mux),
ReadTimeout: 30 * time.Second,
WriteTimeout: writeTimeout,
IdleTimeout: 120 * time.Second,
}
return &Server{
cfg: cfg,
handlers: handlers,
httpServer: httpServer,
}
}
// Start starts the HTTP server.
// This method blocks until the server is shut down.
func (s *Server) Start() error {
log.Printf("HTTP API server listening on %s", s.cfg.ListenAddr)
return s.httpServer.ListenAndServe()
}
// Shutdown gracefully shuts down the server.
func (s *Server) Shutdown(ctx context.Context) error {
log.Println("Shutting down HTTP server...")
return s.httpServer.Shutdown(ctx)
}
// Addr returns the server's listen address.
func (s *Server) Addr() string {
return s.cfg.ListenAddr
}
// Package server provides HTTP server functionality for the checkpoint agent.
package httpApiServer
import "time"
// CheckpointRequest is the request body for checkpoint operations.
type CheckpointRequest struct {
ContainerID string `json:"container_id"`
ContainerName string `json:"container_name,omitempty"` // K8s container name (for volume type lookup)
CheckpointID string `json:"checkpoint_id"`
PodName string `json:"pod_name,omitempty"`
PodNamespace string `json:"pod_namespace,omitempty"`
DisableCUDA bool `json:"disable_cuda,omitempty"` // Disable CUDA plugin for non-GPU workloads
}
// CheckpointResponse is the response for checkpoint operations.
type CheckpointResponse struct {
Success bool `json:"success"`
CheckpointID string `json:"checkpoint_id,omitempty"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
// CheckpointInfo represents information about a checkpoint.
type CheckpointInfo struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
SourceNode string `json:"source_node"`
ContainerID string `json:"container_id"`
PodName string `json:"pod_name"`
PodNamespace string `json:"pod_namespace"`
}
// ListCheckpointsResponse is the response for list checkpoints.
type ListCheckpointsResponse struct {
Checkpoints []CheckpointInfo `json:"checkpoints"`
}
// HealthResponse is the response for health check.
type HealthResponse struct {
Status string `json:"status"`
NodeName string `json:"node_name"`
}
package logging
import (
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// LogProcessDiagnostics logs process state and CRIU restore log for debugging a failed restore.
func LogProcessDiagnostics(procRoot string, pid int, restoreLogPath string, log logr.Logger) {
entry := log.WithValues("restored_pid", pid)
// Process status and cmdline
pidStr := strconv.Itoa(pid)
if data, err := os.ReadFile(filepath.Join(procRoot, pidStr, "status")); err == nil {
entry.Info("Process status", "content", strings.TrimSpace(string(data)))
}
if data, err := os.ReadFile(filepath.Join(procRoot, pidStr, "cmdline")); err == nil {
cmdline := strings.TrimSpace(strings.ReplaceAll(string(data), "\x00", " "))
if cmdline != "" {
entry.Info("Process cmdline", "cmdline", cmdline)
}
}
// Exit code from /proc/stat
if data, err := os.ReadFile(filepath.Join(procRoot, pidStr, "stat")); err == nil {
if ws, err := common.ParseProcExitCode(string(data)); err == nil {
entry.Info("Process exit code", "exit_status", ws.ExitStatus(), "term_signal", ws.Signal(), "core_dumped", ws.CoreDump())
}
}
// PID 1 children in restored namespace
if data, err := os.ReadFile(filepath.Join(procRoot, "1", "task", "1", "children")); err == nil {
entry.Info("PID 1 children", "children", strings.TrimSpace(string(data)))
}
// CRIU restore log summary
logRestoreLog(restoreLogPath, entry)
}
// LogRestoreErrors finds the CRIU restore.log and logs key lines from it.
func LogRestoreErrors(checkpointPath, workDir string, log logr.Logger) {
// Try workdir first, then checkpoint dir
for _, dir := range []string{workDir, checkpointPath} {
if dir == "" {
continue
}
logPath := filepath.Join(dir, "restore.log")
if _, err := os.Stat(logPath); err == nil {
logRestoreLog(logPath, log)
return
}
}
}
// logRestoreLog extracts key lines and tail from a CRIU restore log file.
func logRestoreLog(path string, log logr.Logger) {
data, err := os.ReadFile(path)
if err != nil {
return
}
lines := strings.Split(string(data), "\n")
// Extract error/warning/notable lines
var keyLines []string
for _, line := range lines {
trimmed := strings.TrimSpace(line)
if trimmed == "" {
continue
}
lower := strings.ToLower(trimmed)
if strings.Contains(lower, "error") ||
strings.Contains(lower, "warn") ||
strings.Contains(lower, "fail") ||
strings.Contains(lower, "cuda") ||
strings.Contains(lower, "restore finished successfully") {
keyLines = append(keyLines, trimmed)
if len(keyLines) >= 80 {
break
}
}
}
if len(keyLines) > 0 {
log.Info("CRIU restore key lines", "path", path, "lines", strings.Join(keyLines, " | "))
}
// Last 40 non-empty lines
var tail []string
for i := len(lines) - 1; i >= 0 && len(tail) < 40; i-- {
if trimmed := strings.TrimSpace(lines[i]); trimmed != "" {
tail = append(tail, trimmed)
}
}
slices.Reverse(tail)
if len(tail) > 0 {
log.Info("CRIU restore tail", "path", path, "lines", strings.Join(tail, " | "))
}
}
// Package logging provides shared logger configuration for chrek binaries.
package logging
import (
"fmt"
"os"
"strings"
"github.com/go-logr/logr"
"github.com/go-logr/zapr"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// ConfigureLogger creates a logr.Logger from the CHREK_LOG_LEVEL environment variable.
// output controls where log lines are written ("stdout" or "stderr").
// Supported levels: trace, debug, info, warn, error. Defaults to info.
func ConfigureLogger(output string) logr.Logger {
level := strings.TrimSpace(strings.ToLower(os.Getenv("CHREK_LOG_LEVEL")))
if level == "" {
level = "info"
}
zapLevel := zapcore.InfoLevel
var parseErr error
switch level {
case "trace", "debug":
zapLevel = zapcore.DebugLevel
case "info":
zapLevel = zapcore.InfoLevel
case "warn", "warning":
zapLevel = zapcore.WarnLevel
case "error":
zapLevel = zapcore.ErrorLevel
default:
parseErr = fmt.Errorf("invalid level %q", level)
}
if output == "" {
output = "stdout"
}
zapCfg := zap.Config{
Level: zap.NewAtomicLevelAt(zapLevel),
Development: false,
Encoding: "console",
EncoderConfig: zap.NewProductionEncoderConfig(),
OutputPaths: []string{output},
ErrorOutputPaths: []string{"stderr"},
}
zapCfg.EncoderConfig.EncodeTime = zapcore.RFC3339NanoTimeEncoder
zapCfg.EncoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
zapLog, err := zapCfg.Build()
if err != nil {
zapLog, _ = zap.NewDevelopment()
}
log := zapr.NewLogger(zapLog)
if parseErr != nil {
log.WithName("setup").Info("Invalid CHREK_LOG_LEVEL, falling back to info", "value", level, "error", parseErr)
}
return log
}
// Package orchestrate provides the top-level checkpoint and restore orchestrators.
// These wire together the lib packages (criu, cuda, etc.) into multi-step workflows.
package orchestrate
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/containerd/containerd"
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/criu"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/cuda"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
ContainerID string
ContainerName string
CheckpointHash string
CheckpointDir string
NodeName string
PodName string
PodNamespace string
}
// Checkpoint performs a CRIU dump of a container.
// The operation has three phases: inspect, configure, capture.
//
// The checkpoint directory is staged under tmp/<hash> during the operation.
// On success, it is atomically renamed to <hash> at the base path root.
func Checkpoint(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest, cfg *types.AgentConfig) error {
checkpointStart := time.Now()
log.Info("=== Starting checkpoint operation ===")
finalDir := filepath.Join(req.CheckpointDir, req.CheckpointHash)
tmpDir := filepath.Join(req.CheckpointDir, "tmp", req.CheckpointHash)
if err := os.RemoveAll(tmpDir); err != nil {
return fmt.Errorf("failed to clean checkpoint staging directory: %w", err)
}
if err := os.MkdirAll(tmpDir, 0700); err != nil {
return fmt.Errorf("failed to create checkpoint directory: %w", err)
}
// Phase 1: Inspect container state
state, err := inspectContainer(ctx, ctrd, log, req)
if err != nil {
return err
}
// Phase 2: Configure CRIU options and build checkpoint manifest
criuOpts, data, err := configureCheckpoint(log, state, req, cfg, tmpDir)
if err != nil {
return err
}
// Phase 3: Capture — CRIU dump, rootfs diff
criuDumpDuration, err := captureCheckpoint(ctx, criuOpts, &cfg.CRIU, data, state, tmpDir, log)
if err != nil {
return err
}
// Remove any previous checkpoint with the same identity hash before finalizing
os.RemoveAll(finalDir)
if err := os.Rename(tmpDir, finalDir); err != nil {
return fmt.Errorf("failed to finalize checkpoint directory: %w", err)
}
totalDuration := time.Since(checkpointStart)
log.Info("=== Checkpoint operation completed ===",
"total_duration", totalDuration,
"criu_dump_duration", criuDumpDuration,
)
return nil
}
func inspectContainer(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req CheckpointRequest) (*types.CheckpointContainerSnapshot, error) {
containerID := req.ContainerID
pid, ociSpec, err := common.ResolveContainer(ctx, ctrd, containerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
}
var hostCgroupPath string
if cgPath, err := common.ResolveCgroupRootFromHostPID(pid); err == nil && cgPath != "" {
hostCgroupPath = filepath.Join(common.HostCgroupPath, cgPath)
}
rootFS, err := common.GetRootFS(pid)
if err != nil {
return nil, fmt.Errorf("failed to get rootfs: %w", err)
}
upperDir, err := common.GetOverlayUpperDir(pid)
if err != nil {
return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
}
mountInfo, err := common.ReadMountInfo(pid)
if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
}
mounts := common.ClassifyMounts(mountInfo, ociSpec, rootFS)
netNSInode, err := common.GetNetNSInode(pid)
if err != nil {
return nil, fmt.Errorf("failed to get net namespace inode: %w", err)
}
// Read stdio FD targets (like runc's getPipeFds / descriptors.json).
stdioFDs := make([]string, 3)
for i := range 3 {
target, err := os.Readlink(fmt.Sprintf("%s/%d/fd/%d", common.HostProcPath, pid, i))
if err != nil {
log.V(1).Info("Failed to readlink stdio FD", "fd", i, "error", err)
continue
}
stdioFDs[i] = target
}
// Discover CUDA processes and GPU UUIDs
allPIDs := common.ProcessTreePIDs(pid)
cudaPIDs := cuda.FilterProcesses(ctx, allPIDs, log)
var gpuUUIDs []string
if len(cudaPIDs) > 0 {
gpuUUIDs, err = cuda.GetPodGPUUUIDs(ctx, req.PodName, req.PodNamespace, req.ContainerName)
if err != nil {
return nil, fmt.Errorf("failed to discover source GPU UUIDs: %w", err)
}
}
return &types.CheckpointContainerSnapshot{
PID: pid,
RootFS: rootFS,
UpperDir: upperDir,
OCISpec: ociSpec,
Mounts: mounts,
NetNSInode: netNSInode,
StdioFDs: stdioFDs,
HostCgroupPath: hostCgroupPath,
CUDAPIDs: cudaPIDs,
GPUUUIDs: gpuUUIDs,
}, nil
}
func configureCheckpoint(
log logr.Logger,
state *types.CheckpointContainerSnapshot,
req CheckpointRequest,
cfg *types.AgentConfig,
checkpointDir string,
) (*criurpc.CriuOpts, *types.CheckpointManifest, error) {
criuOpts, err := criu.BuildDumpOptions(state, &cfg.CRIU, checkpointDir, log)
if err != nil {
return nil, nil, err
}
m := types.NewCheckpointManifest(
req.CheckpointHash,
types.NewCRIUDumpManifest(criuOpts, cfg.CRIU),
types.NewSourcePodManifest(req.ContainerID, state.PID, req.NodeName, req.PodName, req.PodNamespace, state.StdioFDs),
types.NewOverlayManifest(cfg.Overlay, state.UpperDir, state.OCISpec),
)
if len(state.CUDAPIDs) > 0 {
m.CUDA = types.NewCUDAManifest(state.CUDAPIDs, state.GPUUUIDs)
}
if err := types.WriteManifest(checkpointDir, m); err != nil {
return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", err)
}
return criuOpts, m, nil
}
func captureCheckpoint(ctx context.Context, criuOpts *criurpc.CriuOpts, criuSettings *types.CRIUSettings, data *types.CheckpointManifest, state *types.CheckpointContainerSnapshot, checkpointDir string, log logr.Logger) (time.Duration, error) {
// CUDA lock+checkpoint must happen before CRIU dump
if len(state.CUDAPIDs) > 0 {
if err := cuda.LockAndCheckpointProcessTree(ctx, state.CUDAPIDs, log); err != nil {
return 0, fmt.Errorf("CUDA checkpoint failed: %w", err)
}
}
criuDumpDuration, err := criu.ExecuteDump(criuOpts, checkpointDir, criuSettings, log)
if err != nil {
return 0, err
}
// Overlay rootfs diff capture is best-effort. Failures are logged but not
// propagated — a checkpoint without overlay diffs is still valid for restore
// (the base container image provides the filesystem).
if state.UpperDir != "" {
if _, err := common.CaptureRootfsDiff(state.UpperDir, checkpointDir, data.Overlay.Exclusions, data.Overlay.BindMountDests); err != nil {
log.Error(err, "Failed to capture rootfs diff")
}
if _, err := common.CaptureDeletedFiles(state.UpperDir, checkpointDir); err != nil {
log.Error(err, "Failed to capture deleted files")
}
}
return criuDumpDuration, nil
}
package orchestrate
import (
"context"
"fmt"
"syscall"
"time"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/criu"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/cuda"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// RestoreOptions holds configuration for an in-namespace restore.
type RestoreOptions struct {
CheckpointPath string
CUDADeviceMap string
CgroupRoot string
}
// RestoreInNamespace performs a full restore from inside the target container's namespaces.
func RestoreInNamespace(ctx context.Context, opts RestoreOptions, log logr.Logger) (int, error) {
restoreStart := time.Now()
log.Info("Starting nsrestore workflow",
"checkpoint_path", opts.CheckpointPath,
"has_cuda_map", opts.CUDADeviceMap != "",
"cgroup_root", opts.CgroupRoot,
)
m, err := types.ReadManifest(opts.CheckpointPath)
if err != nil {
return 0, fmt.Errorf("failed to read manifest: %w", err)
}
log.Info("Loaded checkpoint manifest",
"ext_mounts", len(m.CRIUDump.ExtMnt),
"criu_log_level", m.CRIUDump.CRIU.LogLevel,
"manage_cgroups_mode", m.CRIUDump.CRIU.ManageCgroupsMode,
"checkpoint_has_cuda", !m.CUDA.IsEmpty(),
)
// Phase 1: Configure — build CRIU opts from manifest
criuOpts, err := criu.BuildRestoreOpts(m, opts.CgroupRoot, log)
if err != nil {
return 0, err
}
// Phase 2: Execute — rootfs, CRIU restore, CUDA restore
restoredPID, err := executeRestore(ctx, criuOpts, m, opts, log)
if err != nil {
return 0, err
}
log.Info("nsrestore completed", "restored_pid", restoredPID, "duration", time.Since(restoreStart))
return restoredPID, nil
}
func executeRestore(ctx context.Context, criuOpts *criurpc.CriuOpts, m *types.CheckpointManifest, opts RestoreOptions, log logr.Logger) (int, error) {
// Apply rootfs diff inside the namespace (target root is /)
if err := common.ApplyRootfsDiff(opts.CheckpointPath, "/", log); err != nil {
return 0, fmt.Errorf("rootfs diff failed: %w", err)
}
if err := common.ApplyDeletedFiles(opts.CheckpointPath, "/", log); err != nil {
log.Error(err, "Failed to apply deleted files")
}
// Unmount placeholder's /dev/shm so CRIU can recreate tmpfs with checkpointed content
if err := syscall.Unmount("/dev/shm", 0); err != nil {
return 0, fmt.Errorf("failed to unmount /dev/shm before restore: %w", err)
}
if err := common.RemountProcSys(true); err != nil {
return 0, fmt.Errorf("failed to remount /proc/sys read-write for restore: %w", err)
}
defer func() {
if err := common.RemountProcSys(false); err != nil {
log.Error(err, "Failed to remount /proc/sys read-only after restore")
}
}()
// CRIU restore
restoredPID, err := criu.ExecuteRestore(criuOpts, m, opts.CheckpointPath, log)
if err != nil {
return 0, err
}
// CUDA restore — discover PIDs in the restored process tree, then restore+unlock
if !m.CUDA.IsEmpty() {
candidates := common.ProcessTreePIDs(int(restoredPID))
cudaPIDs := cuda.FilterProcesses(ctx, candidates, log)
if len(cudaPIDs) == 0 {
return 0, fmt.Errorf("checkpoint has %d CUDA PIDs but none found in restored process tree", len(m.CUDA.PIDs))
}
if err := cuda.RestoreAndUnlockProcessTree(ctx, cudaPIDs, opts.CUDADeviceMap, log); err != nil {
return 0, fmt.Errorf("CUDA restore failed: %w", err)
}
}
return int(restoredPID), nil
}
package orchestrate
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/containerd/containerd"
"github.com/go-logr/logr"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/criu"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/cuda"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/logging"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/types"
)
// RestoreRequest holds the parameters for a restore operation.
type RestoreRequest struct {
CheckpointHash string
CheckpointBase string
NSRestorePath string
PodName string
PodNamespace string
ContainerName string
}
// Restore performs external restore for the given request.
// Returns the namespace-relative PID of the restored process.
// The DaemonSet side inspects the placeholder and launches nsrestore,
// which handles rootfs application, CRIU restore, and CUDA restore inside the namespace.
func Restore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (int, error) {
restoreStart := time.Now()
log.Info("=== Starting external restore ===",
"checkpoint_hash", req.CheckpointHash,
"pod", req.PodName,
"namespace", req.PodNamespace,
"container", req.ContainerName,
)
// Phase 1: Inspect — resolve placeholder, discover target GPUs, build device map
snap, err := inspectRestore(ctx, ctrd, log, req)
if err != nil {
return 0, err
}
// Phase 2: Execute — nsrestore handles rootfs, CRIU restore, and CUDA restore inside namespace
restoredPID, err := execNSRestore(ctx, log, req, snap)
if err != nil {
return 0, fmt.Errorf("nsrestore failed: %w", err)
}
log.Info("nsrestore completed", "restored_pid", restoredPID)
// Validate restored process from the host side
procRoot := filepath.Join(snap.TargetRoot, "proc")
if err := common.ValidateProcessState(procRoot, restoredPID); err != nil {
restoreLogPath := filepath.Join(snap.TargetRoot, "var", "criu-work", criu.RestoreLogFilename)
logging.LogProcessDiagnostics(procRoot, restoredPID, restoreLogPath, log)
return 0, fmt.Errorf("restored process failed post-restore validation: %w", err)
}
log.Info("=== External restore completed ===", "total_duration", time.Since(restoreStart))
return restoredPID, nil
}
func inspectRestore(ctx context.Context, ctrd *containerd.Client, log logr.Logger, req RestoreRequest) (*types.RestoreContainerSnapshot, error) {
checkpointPath := filepath.Join(req.CheckpointBase, req.CheckpointHash)
baseAbs, err := filepath.Abs(req.CheckpointBase)
if err != nil {
return nil, fmt.Errorf("failed to resolve checkpoint base path: %w", err)
}
checkpointAbs, err := filepath.Abs(checkpointPath)
if err != nil {
return nil, fmt.Errorf("failed to resolve checkpoint path: %w", err)
}
if checkpointAbs != baseAbs && !strings.HasPrefix(checkpointAbs, baseAbs+string(os.PathSeparator)) {
return nil, fmt.Errorf("invalid checkpoint hash %q", req.CheckpointHash)
}
m, err := types.ReadManifest(checkpointPath)
if err != nil {
return nil, fmt.Errorf("failed to read checkpoint manifest: %w", err)
}
containerName := req.ContainerName
if containerName == "" {
containerName = "main"
}
placeholderPID, _, err := common.ResolveContainerByPod(ctx, ctrd, req.PodName, req.PodNamespace, containerName)
if err != nil {
return nil, fmt.Errorf("failed to resolve placeholder container: %w", err)
}
log.Info("Resolved placeholder container", "pid", placeholderPID)
cgroupRoot, err := common.ResolveCgroupRootFromHostPID(placeholderPID)
if err != nil {
log.Error(err, "Failed to resolve placeholder cgroup root; proceeding without explicit cgroup remap")
cgroupRoot = ""
}
cudaDeviceMap := ""
if !m.CUDA.IsEmpty() {
if len(m.CUDA.SourceGPUUUIDs) == 0 {
return nil, fmt.Errorf("missing source GPU UUIDs in checkpoint manifest")
}
targetGPUUUIDs, err := cuda.GetPodGPUUUIDs(ctx, req.PodName, req.PodNamespace, containerName)
if err != nil {
return nil, fmt.Errorf("failed to get target GPU UUIDs: %w", err)
}
if len(targetGPUUUIDs) == 0 {
return nil, fmt.Errorf("missing target GPU UUIDs for %s/%s container %s", req.PodNamespace, req.PodName, containerName)
}
cudaDeviceMap, err = cuda.BuildDeviceMap(m.CUDA.SourceGPUUUIDs, targetGPUUUIDs)
if err != nil {
return nil, fmt.Errorf("failed to build CUDA device map: %w", err)
}
}
return &types.RestoreContainerSnapshot{
CheckpointPath: checkpointPath,
PlaceholderPID: placeholderPID,
TargetRoot: fmt.Sprintf("%s/%d/root", common.HostProcPath, placeholderPID),
CgroupRoot: cgroupRoot,
CUDADeviceMap: cudaDeviceMap,
}, nil
}
// execNSRestore launches the nsrestore binary inside the placeholder container's
// namespaces via nsenter and parses the restored PID from stdout JSON.
func execNSRestore(ctx context.Context, log logr.Logger, req RestoreRequest, snap *types.RestoreContainerSnapshot) (int, error) {
args := []string{
"-t", strconv.Itoa(snap.PlaceholderPID),
// Intentionally exclude cgroup namespace (-C): CRIU must manage cgroups
// from the host-visible hierarchy so --cgroup-root remap works.
"-m", "-u", "-i", "-n", "-p",
"--", req.NSRestorePath,
"--checkpoint-path", snap.CheckpointPath,
}
if snap.CUDADeviceMap != "" {
args = append(args, "--cuda-device-map", snap.CUDADeviceMap)
}
if snap.CgroupRoot != "" {
args = append(args, "--cgroup-root", snap.CgroupRoot)
}
cmd := exec.CommandContext(ctx, "nsenter", args...)
log.V(1).Info("Executing nsenter + nsrestore", "cmd", cmd.String())
var stdout bytes.Buffer
cmd.Stdout = &stdout
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
return 0, fmt.Errorf("nsrestore failed: %w\nstdout: %s", err, stdout.String())
}
var result struct {
RestoredPID int `json:"restoredPID"`
}
if err := json.Unmarshal(stdout.Bytes(), &result); err != nil {
return 0, fmt.Errorf("failed to parse nsrestore result: %w\nstdout: %s", err, stdout.String())
}
if result.RestoredPID <= 0 {
return 0, fmt.Errorf("nsrestore returned invalid PID %d", result.RestoredPID)
}
return result.RestoredPID, nil
}
// config.go defines the RestoreRequest struct for CRIU restore operations.
// CRIU options come from the saved CheckpointManifest, not from this request.
//
// The restore-entrypoint runs in placeholder containers which do NOT mount the
// ConfigMap. Static defaults are hardcoded here; per-pod dynamic values come
// from environment variables injected by the operator.
package restore
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
const (
// RestoreLogFilename is the CRIU restore log filename.
RestoreLogFilename = "restore.log"
// CRIULogDir is the directory where CRIU restore logs are copied for debugging.
CRIULogDir = "/checkpoints/restore-logs"
// RestoreTriggerPath is the default path to the trigger file for trigger-based restore.
RestoreTriggerPath = "/tmp/restore-trigger"
)
// RestoreRequest holds runtime request inputs for the restore entrypoint.
// CRIU options are NOT stored here - they come from the saved CheckpointManifest.
type RestoreRequest struct {
// === Per-pod dynamic values (from operator-injected env vars) ===
// CheckpointPath is the base directory containing checkpoints.
CheckpointPath string
// CheckpointHash is the ID/hash of the checkpoint to restore.
CheckpointHash string
// CheckpointLocation is the full resolved path to the checkpoint directory.
CheckpointLocation string
// SkipWaitForCheckpoint controls the entrypoint behavior.
SkipWaitForCheckpoint bool
// ColdStartArgs is the command+args to exec if no checkpoint is available.
ColdStartArgs []string
// Debug enables debug logging.
Debug bool
// === Static defaults (hardcoded) ===
// RestoreMarkerFilePath is where restore-entrypoint writes a marker before CRIU restore.
RestoreMarkerFilePath string
// RestoreTrigger is the path to the trigger file that signals restore should start.
RestoreTrigger string
// WaitTimeout is the maximum time to wait for a checkpoint.
// Zero means wait indefinitely.
WaitTimeout time.Duration
}
// ConfigError represents a configuration validation error.
type ConfigError struct {
Field string
Message string
}
func (e *ConfigError) Error() string {
return fmt.Sprintf("config error: %s: %s", e.Field, e.Message)
}
// NewRestoreRequest creates a RestoreRequest with hardcoded defaults and
// operator-injected environment variable values.
func NewRestoreRequest(args []string) (*RestoreRequest, error) {
cfg := &RestoreRequest{
RestoreTrigger: RestoreTriggerPath,
ColdStartArgs: args,
}
if v := os.Getenv("DYN_CHECKPOINT_PATH"); v != "" {
cfg.CheckpointPath = v
}
if v := os.Getenv("DYN_CHECKPOINT_HASH"); v != "" {
cfg.CheckpointHash = v
}
if v := os.Getenv("DYN_CHECKPOINT_LOCATION"); v != "" {
cfg.CheckpointLocation = v
} else if cfg.CheckpointPath != "" && cfg.CheckpointHash != "" {
cfg.CheckpointLocation = cfg.CheckpointPath + "/" + cfg.CheckpointHash
}
cfg.SkipWaitForCheckpoint = os.Getenv("SKIP_WAIT_FOR_CHECKPOINT") == "1"
cfg.Debug = os.Getenv("DEBUG") == "1"
cfg.RestoreMarkerFilePath = os.Getenv("DYN_RESTORE_MARKER_FILE")
if cfg.RestoreMarkerFilePath == "" {
return nil, &ConfigError{
Field: "DYN_RESTORE_MARKER_FILE",
Message: "must be set",
}
}
return cfg, nil
}
type checkpointDoneMarker struct {
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
func checkpointDoneSucceeded(donePath string, log *logrus.Entry) bool {
data, err := os.ReadFile(donePath)
if err != nil {
log.WithError(err).WithField("path", donePath).Warn("Failed to read checkpoint.done marker")
return false
}
var marker checkpointDoneMarker
if err := json.Unmarshal(data, &marker); err != nil {
log.WithError(err).WithField("path", donePath).Warn("Failed to parse checkpoint.done marker")
return false
}
if !marker.Success {
fields := logrus.Fields{"path": donePath}
if marker.Error != "" {
fields["error"] = marker.Error
}
log.WithFields(fields).Warn("checkpoint.done marker reports failed checkpoint")
return false
}
return true
}
// ShouldRestore checks if a restore should be performed.
// Returns the checkpoint path and true if restore should proceed.
func ShouldRestore(cfg *RestoreRequest, log *logrus.Entry) (string, bool) {
// Method 1: Checkpoint location is set and checkpoint is fully complete
if cfg.CheckpointLocation != "" {
donePath := cfg.CheckpointLocation + "/" + checkpoint.CheckpointDoneFilename
if _, err := os.Stat(donePath); err == nil {
if checkpointDoneSucceeded(donePath, log) {
log.WithField("path", cfg.CheckpointLocation).Info("Checkpoint found (checkpoint.done success=true)")
return cfg.CheckpointLocation, true
}
}
// Fallback: check for manifest.yaml but warn about potential race condition.
manifestPath := cfg.CheckpointLocation + "/" + checkpoint.CheckpointManifestFilename
if _, err := os.Stat(manifestPath); err == nil {
log.WithFields(logrus.Fields{
"path": cfg.CheckpointLocation,
"warning": "checkpoint.done marker not found, checkpoint may be incomplete",
}).Warn("Checkpoint manifest found but checkpoint.done missing - checkpoint may still be in progress")
}
}
// Method 2: Restore trigger file exists with checkpoint path
if cfg.RestoreTrigger != "" {
data, err := os.ReadFile(cfg.RestoreTrigger)
if err == nil {
checkpointPath := strings.TrimSpace(string(data))
if checkpointPath != "" {
donePath := checkpointPath + "/" + checkpoint.CheckpointDoneFilename
if _, err := os.Stat(donePath); err == nil {
if checkpointDoneSucceeded(donePath, log) {
log.WithField("path", checkpointPath).Info("Restore triggered via file (checkpoint.done success=true)")
return checkpointPath, true
}
}
}
}
}
return "", false
}
// WaitForCheckpoint waits for a checkpoint to become available.
// If cfg.WaitTimeout is zero, waits indefinitely (until ctx is cancelled).
func WaitForCheckpoint(ctx context.Context, cfg *RestoreRequest, log *logrus.Entry) (string, error) {
if cfg.WaitTimeout > 0 {
log.WithField("timeout", cfg.WaitTimeout).Info("Waiting for checkpoint")
} else {
log.Info("Waiting for checkpoint indefinitely")
}
startTime := time.Now()
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
lastLog := time.Now()
for {
select {
case <-ctx.Done():
return "", ctx.Err()
case <-ticker.C:
if path, ok := ShouldRestore(cfg, log); ok {
return path, nil
}
// Log progress every 30 seconds
if time.Since(lastLog) >= 30*time.Second {
elapsed := time.Since(startTime)
log.WithField("elapsed", elapsed).Info("Still waiting for checkpoint...")
lastLog = time.Now()
}
// Only enforce deadline if WaitTimeout is set (non-zero)
if cfg.WaitTimeout > 0 && time.Since(startTime) >= cfg.WaitTimeout {
return "", fmt.Errorf("timed out waiting for checkpoint after %s", cfg.WaitTimeout)
}
}
}
}
// criu provides CRIU-specific configuration and utilities for restore operations.
package restore
import (
"os"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
"github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// CRIURestorePlan holds configuration for CRIU restore operations.
// Most fields come from the saved CheckpointManifest.CRIUDump.CRIU settings.
type CRIURestorePlan struct {
// File descriptors
ImageDirFD int32
WorkDirFD int32
NetNsFD int32
// Paths
RootPath string
LogFile string
// Options from CheckpointManifest.CRIUDump.CRIU.
LogLevel int32
Timeout uint32 // CRIU timeout in seconds (0 = no timeout, required for CUDA)
ShellJob bool // Allow session leaders (containers are often session leaders)
TcpClose bool // Close TCP connections (pod IPs change on restore)
FileLocks bool // Allow file locks
ExtUnixSk bool // Allow external Unix sockets
LinkRemap bool // Handle deleted-but-open files via CRIU link remap
ManageCgroupsMode string // Cgroup handling mode: "ignore" lets K8s manage cgroups
// External mount mappings (from CheckpointManifest.CRIUDump.ExtMnt).
ExtMountMaps []*criurpc.ExtMountMap
}
// OpenImageDir opens a checkpoint directory and clears CLOEXEC for CRIU.
// Returns the opened file and its FD. Caller must close the file when done.
func OpenImageDir(checkpointPath string) (*os.File, int32, error) {
return common.OpenPathForCRIU(checkpointPath)
}
// OpenNetworkNamespace opens the target network namespace for restore.
// Returns the opened file and its FD. Caller must close the file when done.
func OpenNetworkNamespace(nsPath string) (*os.File, int32, error) {
return common.OpenPathForCRIU(nsPath)
}
// OpenWorkDir opens a work directory for CRIU and clears CLOEXEC.
// Returns the opened file and its FD, or nil/-1 if workDir is empty or fails.
func OpenWorkDir(workDir string, log *logrus.Entry) (*os.File, int32) {
if workDir == "" {
return nil, -1
}
// Ensure work directory exists
if err := os.MkdirAll(workDir, 0755); err != nil {
log.WithError(err).Warn("Failed to create CRIU work directory, using default")
return nil, -1
}
workDirFile, err := os.Open(workDir)
if err != nil {
log.WithError(err).Warn("Failed to open CRIU work directory, using default")
return nil, -1
}
if _, err := unix.FcntlInt(workDirFile.Fd(), unix.F_SETFD, 0); err != nil {
log.WithError(err).Warn("Failed to clear CLOEXEC on work dir")
workDirFile.Close()
return nil, -1
}
log.WithField("path", workDir).Info("Using custom CRIU work directory")
return workDirFile, int32(workDirFile.Fd())
}
// BuildCRIURestoreOptions creates CRIU options for restore from a runtime plan.
//
// Options from CheckpointManifest.CRIUDump.CRIU (saved at checkpoint time):
// - ShellJob, TcpClose, FileLocks, ExtUnixSk, LinkRemap, ManageCgroupsMode
//
// Hardcoded restore-specific options:
// - RstSibling: restore in detached mode
// - MntnsCompatMode: cross-container restore
// - EvasiveDevices, ForceIrmap: device/inode handling
func BuildCRIURestoreOptions(plan CRIURestorePlan) *criurpc.CriuOpts {
// Map cgroup management mode from plan.
var cgMode criurpc.CriuCgMode
switch plan.ManageCgroupsMode {
case "soft":
cgMode = criurpc.CriuCgMode_SOFT
case "full":
cgMode = criurpc.CriuCgMode_FULL
case "strict":
cgMode = criurpc.CriuCgMode_STRICT
case "ignore", "":
cgMode = criurpc.CriuCgMode_IGNORE
default:
cgMode = criurpc.CriuCgMode_IGNORE
}
criuOpts := &criurpc.CriuOpts{
ImagesDirFd: proto.Int32(plan.ImageDirFD),
LogLevel: proto.Int32(plan.LogLevel),
LogFile: proto.String(plan.LogFile),
// Root filesystem - use current container's root
Root: proto.String(plan.RootPath),
// Restore in detached mode - process runs in background (restore-specific)
RstSibling: proto.Bool(true),
// Mount namespace mode:
// - MntnsCompatMode=false (default): Uses mount-v2 with MOVE_MOUNT_SET_GROUP (kernel 5.15+)
// This is preferred as it doesn't create temp dirs in /tmp
// - MntnsCompatMode=true: Uses compat mode which creates /tmp/cr-tmpfs.XXX
// This can cause "Device or resource busy" errors on cleanup
// We explicitly set to false to use mount-v2 (requires kernel 5.15+)
MntnsCompatMode: proto.Bool(false),
// Options from saved CheckpointManifest.CRIUDump.CRIU.
ShellJob: proto.Bool(plan.ShellJob),
TcpClose: proto.Bool(plan.TcpClose),
FileLocks: proto.Bool(plan.FileLocks),
ExtUnixSk: proto.Bool(plan.ExtUnixSk),
LinkRemap: proto.Bool(plan.LinkRemap),
// Cgroup management from saved settings.
ManageCgroups: proto.Bool(true),
ManageCgroupsMode: &cgMode,
// Device and inode handling (restore-specific)
EvasiveDevices: proto.Bool(true),
ForceIrmap: proto.Bool(true),
// External mount mappings
ExtMnt: plan.ExtMountMaps,
}
// Add network namespace inheritance if provided
if plan.NetNsFD >= 0 {
criuOpts.InheritFd = []*criurpc.InheritFd{
{
Key: proto.String("extNetNs"),
Fd: proto.Int32(plan.NetNsFD),
},
}
}
// Add work directory if specified
if plan.WorkDirFD >= 0 {
criuOpts.WorkDirFd = proto.Int32(plan.WorkDirFD)
}
// Add timeout if specified (required for CUDA restores)
if plan.Timeout > 0 {
criuOpts.Timeout = proto.Uint32(plan.Timeout)
}
return criuOpts
}
package restore
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// ApplyRootfsDiff extracts the rootfs-diff.tar from the checkpoint to the target root.
// This restores filesystem changes that were made in the original container.
func ApplyRootfsDiff(checkpointPath, targetRoot string, log *logrus.Entry) error {
rootfsDiffPath := filepath.Join(checkpointPath, checkpoint.RootfsDiffFilename)
// Check if rootfs-diff.tar exists
if _, err := os.Stat(rootfsDiffPath); os.IsNotExist(err) {
log.Info("No rootfs-diff.tar found, skipping filesystem restoration")
return nil
}
log.WithField("path", rootfsDiffPath).Info("Applying rootfs diff")
// Exclusions are already applied at checkpoint time (bind mounts, system dirs, etc.)
// so we just extract with --keep-old-files to avoid overwriting existing files.
cmd := exec.Command("tar",
"--keep-old-files",
"-C", targetRoot,
"-xf", rootfsDiffPath,
)
output, err := cmd.CombinedOutput()
if err != nil {
// Some failures are expected (read-only mounts, existing files)
// tar returns exit code 1 for "file exists" which is not fatal for us
if exitErr, ok := err.(*exec.ExitError); ok && exitErr.ExitCode() == 1 {
log.WithField("output", string(output)).Info("Rootfs diff applied (some files may have been skipped due to mounts)")
return nil
}
return fmt.Errorf("failed to extract rootfs diff: %w (output: %s)", err, string(output))
}
log.Info("Rootfs diff applied successfully")
return nil
}
// ApplyDeletedFiles removes files that were deleted in the original container.
// These are tracked via overlay whiteout markers (.wh.<filename>).
func ApplyDeletedFiles(checkpointPath, targetRoot string, log *logrus.Entry) error {
deletedFilesPath := filepath.Join(checkpointPath, checkpoint.DeletedFilesFilename)
// Check if deleted-files.json exists
data, err := os.ReadFile(deletedFilesPath)
if os.IsNotExist(err) {
log.Debug("No deleted-files.json found")
return nil
}
if err != nil {
return fmt.Errorf("failed to read deleted files list: %w", err)
}
log.Info("Applying deleted files from whiteout list")
// Parse JSON array of deleted file paths
var deletedFiles []string
if err := json.Unmarshal(data, &deletedFiles); err != nil {
return fmt.Errorf("failed to parse deleted files JSON: %w", err)
}
deletedCount := 0
for _, filePath := range deletedFiles {
if filePath == "" {
continue
}
targetPath := filepath.Join(targetRoot, filePath)
// Check if file exists before attempting deletion
if _, err := os.Stat(targetPath); os.IsNotExist(err) {
continue
}
if err := os.RemoveAll(targetPath); err != nil {
log.WithError(err).WithField("path", targetPath).Debug("Could not delete file")
continue
}
deletedCount++
}
log.WithField("count", deletedCount).Info("Deleted files applied")
return nil
}
// CheckpointFilesExist verifies that the checkpoint directory contains valid checkpoint files.
func CheckpointFilesExist(checkpointPath string) bool {
// Check for CRIU image files (core-*.img is always present)
matches, err := filepath.Glob(filepath.Join(checkpointPath, "core-*.img"))
return err == nil && len(matches) > 0
}
// Package restore provides CRIU restore operations.
package restore
import (
"encoding/binary"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"github.com/checkpoint-restore/go-criu/v7/crit"
"github.com/checkpoint-restore/go-criu/v7/crit/images/fdinfo"
"github.com/checkpoint-restore/go-criu/v7/crit/images/regfile"
remap_file_path "github.com/checkpoint-restore/go-criu/v7/crit/images/remap-file-path"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
)
// CreateLinkRemapStubs parses CRIU images to find remapped files and creates
// the link_remap stub files needed for CRIU restore.
//
// Background: When a file is unlink()'d but a process still has an open FD to it,
// CRIU handles this via "link remapping":
//
// - During dump: CRIU creates a hardlink link_remap.<id> -> original_file
// - During restore: CRIU does linkat(link_remap.<id>, original_path) to recreate it
//
// The link_remap file only exists on the original node's filesystem. For cross-node
// restore, we must create stub files so CRIU can hardlink from them.
//
// Without these stubs, CRIU fails with:
//
// "Can't link <path>/link_remap.X -> <path>/original: No such file or directory"
func CreateLinkRemapStubs(checkpointPath string, log *logrus.Entry) error {
// 1. Parse remap-fpath.img to find files that need remapping
remapPath := filepath.Join(checkpointPath, "remap-fpath.img")
remaps, err := parseRemapFpath(remapPath)
if err != nil {
if os.IsNotExist(err) {
log.Debug("No remap-fpath.img found, no link_remap stubs needed")
return nil
}
return fmt.Errorf("failed to parse remap-fpath.img: %w", err)
}
if len(remaps) == 0 {
log.Debug("No file remaps found in checkpoint")
return nil
}
// 2. Parse file info to build ID -> fileInfo mapping
// Try reg-files.img first (older CRIU format), fall back to files.img (newer format)
regFilesPath := filepath.Join(checkpointPath, "reg-files.img")
filesPath := filepath.Join(checkpointPath, "files.img")
var fileMap map[uint32]fileInfo
var parseErr error
// Try reg-files.img first (older CRIU format)
fileMap, parseErr = parseRegFilesWithMode(regFilesPath)
if parseErr != nil {
log.WithError(parseErr).Debug("Could not parse reg-files.img, trying files.img")
// Fall back to files.img (newer format)
fileMap, parseErr = parseFilesImgWithMode(filesPath)
if parseErr != nil {
log.WithError(parseErr).WithField("remap_count", len(remaps)).Warn(
"Found remap entries but could not parse reg-files.img or files.img — link_remap stubs will not be created")
return fmt.Errorf("found %d remap entries but could not build file map: %w", len(remaps), parseErr)
}
}
// 3. Create link_remap stub files for all remapped files
var created []string
for _, remap := range remaps {
// Look up the original file by ID
origInfo, ok := fileMap[remap.origID]
if !ok {
log.WithField("orig_id", remap.origID).Debug("Original file ID not found in file map, skipping")
continue
}
// Look up the remap file path by remap ID
// This is the link_remap.XXX file that CRIU will hardlink FROM
remapInfo, ok := fileMap[remap.remapID]
var remapName string
var mode os.FileMode
if ok {
remapName = remapInfo.name
mode = remapInfo.mode
} else {
// If we can't find the remap file in fileMap, construct it
// CRIU creates link_remap files in the same directory as the original
// with format: link_remap.<remap_id>
dir := filepath.Dir(origInfo.name)
if !strings.HasPrefix(dir, "/") {
dir = "/" + dir
}
remapName = filepath.Join(dir, fmt.Sprintf("link_remap.%d", remap.remapID))
// Use original file's mode since we don't have the remap file's mode
mode = origInfo.mode
log.WithFields(logrus.Fields{
"orig_id": remap.origID,
"remap_id": remap.remapID,
"orig_path": origInfo.name,
"remap_path": remapName,
"mode": fmt.Sprintf("%04o", mode),
}).Debug("Constructed link_remap path from remap ID")
}
// Normalize path
if !strings.HasPrefix(remapName, "/") {
remapName = "/" + remapName
}
// Check if the link_remap file already exists
if _, err := os.Stat(remapName); err == nil {
log.WithField("remap_file", remapName).Debug("Link remap file already exists")
continue
}
// Create the link_remap stub file with correct permissions
// CRIU will hardlink FROM this file TO the original path
if err := createLinkRemapStub(remapName, mode); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"remap_file": remapName,
"target": origInfo.name,
"mode": fmt.Sprintf("%04o", mode),
}).Warn("Failed to create link_remap stub")
continue
}
created = append(created, filepath.Base(remapName))
log.WithFields(logrus.Fields{
"remap_file": remapName,
"target": origInfo.name,
"mode": fmt.Sprintf("%04o", mode),
}).Debug("Created link_remap stub file")
}
if len(created) > 0 {
log.WithFields(logrus.Fields{
"count": len(created),
"remap_files": created,
}).Info("Created link_remap stub files for CRIU restore")
} else {
log.Debug("No link_remap stubs needed")
}
return nil
}
// fileInfo holds file metadata from CRIU checkpoint images
type fileInfo struct {
name string
mode os.FileMode
}
// remapEntry represents a file remap entry from CRIU
type remapEntry struct {
origID uint32
remapID uint32
remapType int32
}
// parseRemapFpath parses the remap-fpath.img file
func parseRemapFpath(path string) ([]remapEntry, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
// Read and validate magic number using go-criu's ReadMagic
magic, err := crit.ReadMagic(f)
if err != nil {
return nil, fmt.Errorf("failed to read magic: %w", err)
}
if magic != "REMAP_FPATH" {
return nil, fmt.Errorf("unexpected magic: %s (expected REMAP_FPATH)", magic)
}
var entries []remapEntry
sizeBuf := make([]byte, 4)
for {
// Read entry size
_, err := io.ReadFull(f, sizeBuf)
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return nil, fmt.Errorf("failed to read entry size: %w", err)
}
entrySize := binary.LittleEndian.Uint32(sizeBuf)
entryBuf := make([]byte, entrySize)
if _, err := io.ReadFull(f, entryBuf); err != nil {
return nil, fmt.Errorf("failed to read entry data: %w", err)
}
// Parse protobuf
entry := &remap_file_path.RemapFilePathEntry{}
if err := proto.Unmarshal(entryBuf, entry); err != nil {
return nil, fmt.Errorf("failed to unmarshal entry: %w", err)
}
entries = append(entries, remapEntry{
origID: entry.GetOrigId(),
remapID: entry.GetRemapId(),
remapType: int32(entry.GetRemapType()),
})
}
return entries, nil
}
// parseRegFilesWithMode parses the reg-files.img file and returns a map of ID -> fileInfo
func parseRegFilesWithMode(path string) (map[uint32]fileInfo, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
// Read and validate magic number using go-criu's ReadMagic
magic, err := crit.ReadMagic(f)
if err != nil {
return nil, fmt.Errorf("failed to read magic: %w", err)
}
if magic != "REG_FILES" {
return nil, fmt.Errorf("unexpected magic: %s (expected REG_FILES)", magic)
}
fileMap := make(map[uint32]fileInfo)
sizeBuf := make([]byte, 4)
for {
// Read entry size
_, err := io.ReadFull(f, sizeBuf)
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return nil, fmt.Errorf("failed to read entry size: %w", err)
}
entrySize := binary.LittleEndian.Uint32(sizeBuf)
entryBuf := make([]byte, entrySize)
if _, err := io.ReadFull(f, entryBuf); err != nil {
return nil, fmt.Errorf("failed to read entry data: %w", err)
}
// Parse protobuf
entry := &regfile.RegFileEntry{}
if err := proto.Unmarshal(entryBuf, entry); err != nil {
return nil, fmt.Errorf("failed to unmarshal entry: %w", err)
}
// Convert CRIU mode (includes file type bits) to os.FileMode
// CRIU stores the full st_mode, we need just the permission bits
mode := os.FileMode(entry.GetMode() & 0777)
if mode == 0 {
mode = 0600 // Default to owner read/write if mode not set
}
fileMap[entry.GetId()] = fileInfo{
name: entry.GetName(),
mode: mode,
}
}
return fileMap, nil
}
// parseFilesImgWithMode parses the files.img file and returns a map of ID -> fileInfo
// This is the newer CRIU format where file info is embedded in FileEntry messages
func parseFilesImgWithMode(path string) (map[uint32]fileInfo, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
// Read and validate magic number using go-criu's ReadMagic
magic, err := crit.ReadMagic(f)
if err != nil {
return nil, fmt.Errorf("failed to read magic: %w", err)
}
if magic != "FILES" {
return nil, fmt.Errorf("unexpected magic: %s (expected FILES)", magic)
}
fileMap := make(map[uint32]fileInfo)
sizeBuf := make([]byte, 4)
for {
// Read entry size
_, err := io.ReadFull(f, sizeBuf)
if err == io.EOF || err == io.ErrUnexpectedEOF {
break
}
if err != nil {
return nil, fmt.Errorf("failed to read entry size: %w", err)
}
entrySize := binary.LittleEndian.Uint32(sizeBuf)
entryBuf := make([]byte, entrySize)
if _, err := io.ReadFull(f, entryBuf); err != nil {
return nil, fmt.Errorf("failed to read entry data: %w", err)
}
// Parse protobuf as FileEntry
entry := &fdinfo.FileEntry{}
if err := proto.Unmarshal(entryBuf, entry); err != nil {
return nil, fmt.Errorf("failed to unmarshal entry: %w", err)
}
// Extract fileinfo from embedded RegFileEntry if present
if entry.GetReg() != nil {
reg := entry.GetReg()
// Convert CRIU mode to os.FileMode (permission bits only)
mode := os.FileMode(reg.GetMode() & 0777)
if mode == 0 {
mode = 0600 // Default to owner read/write if mode not set
}
fileMap[entry.GetId()] = fileInfo{
name: reg.GetName(),
mode: mode,
}
}
}
return fileMap, nil
}
// createLinkRemapStub creates an empty stub file for CRIU link_remap.
// The file is created with the specified mode to match what CRIU expects.
func createLinkRemapStub(path string, mode os.FileMode) error {
// Ensure parent directory exists
dir := filepath.Dir(path)
if err := os.MkdirAll(dir, 0755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", dir, err)
}
// Create file with the specified mode
// CRIU validates the file mode matches what was recorded at checkpoint time
f, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create file: %w", err)
}
defer f.Close()
// Write 32 bytes of zeros as stub content
// This provides a minimal valid file for CRIU to hardlink from
stub := make([]byte, 32)
if _, err := f.Write(stub); err != nil {
return fmt.Errorf("failed to write stub data: %w", err)
}
return nil
}
package restore
import (
"fmt"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// GenerateExtMountMaps generates external mount mappings for CRIU restore.
// It reuses the exact dump-time ext-mount plan persisted in checkpoint manifest.
func GenerateExtMountMaps(data *checkpoint.CheckpointManifest) ([]*criurpc.ExtMountMap, error) {
if data == nil {
return nil, fmt.Errorf("checkpoint manifest is required")
}
if len(data.CRIUDump.ExtMnt) == 0 {
return nil, fmt.Errorf("checkpoint manifest is missing criuDump.extMnt")
}
maps := []*criurpc.ExtMountMap{{
Key: proto.String("/"),
Val: proto.String("."),
}}
addedMounts := map[string]struct{}{"/": {}}
// Replay dump-time ext-mount plan exactly, with restore-specific root remap.
for _, mount := range data.CRIUDump.ExtMnt {
key := mount.Key
if key == "" || key == "/" {
continue
}
if _, exists := addedMounts[key]; exists {
continue
}
val := mount.Val
if val == "" {
val = key
}
maps = append(maps, &criurpc.ExtMountMap{
Key: proto.String(key),
Val: proto.String(val),
})
addedMounts[key] = struct{}{}
}
return maps, nil
}
package restore
import (
criu "github.com/checkpoint-restore/go-criu/v7"
"github.com/sirupsen/logrus"
)
// RestoreNotify implements criu.Notify for restore callbacks.
// It captures the restored process PID from the PostRestore callback.
type RestoreNotify struct {
criu.NoNotify // Embed no-op implementation for all methods
// RestoredPID is the PID of the restored process, set by PostRestore callback
RestoredPID int32
// log is the logger for notification events
log *logrus.Entry
}
// NewRestoreNotify creates a new RestoreNotify handler.
func NewRestoreNotify(log *logrus.Entry) *RestoreNotify {
return &RestoreNotify{
log: log,
}
}
// PreRestore is called before CRIU starts the restore operation.
func (n *RestoreNotify) PreRestore() error {
if n.log != nil {
n.log.Debug("CRIU pre-restore notification")
}
return nil
}
// PostRestore is called after CRIU completes the restore operation.
// The pid parameter contains the PID of the restored process.
func (n *RestoreNotify) PostRestore(pid int32) error {
n.RestoredPID = pid
if n.log != nil {
n.log.WithField("pid", pid).Info("CRIU post-restore notification: process restored")
}
return nil
}
// PostResume is called after the restored process has resumed execution.
func (n *RestoreNotify) PostResume() error {
if n.log != nil {
n.log.Debug("CRIU post-resume notification")
}
return nil
}
package restore
import (
"errors"
"fmt"
"io"
"os"
"os/exec"
"os/signal"
"strconv"
"strings"
"sync"
"syscall"
"time"
"github.com/sirupsen/logrus"
)
// MonitorProcess monitors the restored process and returns its exit code.
// It blocks until the process exits. Does not forward stdout/stderr.
// For output forwarding, use ForwardProcessOutput instead.
func MonitorProcess(pid int, log *logrus.Entry) int {
log.WithField("pid", pid).Info("Monitoring restored process")
for {
// Check if process still exists by sending signal 0
proc, err := os.FindProcess(pid)
if err != nil {
log.WithError(err).Error("Failed to find process")
return 1
}
err = proc.Signal(syscall.Signal(0))
if err != nil {
// Process has exited
log.WithField("pid", pid).Info("Restored process exited")
// Try to read exit status from /proc/<pid>/stat
// If process is gone, assume exit code 0
exitCode := getExitCode(pid)
log.WithField("exit_code", exitCode).Info("Restored process exit status")
return exitCode
}
time.Sleep(time.Second)
}
}
// ForwardProcessOutput forwards the stdout and stderr of a restored process
// to our own stdout/stderr via /proc/<pid>/fd/1 and /proc/<pid>/fd/2.
// This ensures logs from the restored process appear in kubectl logs.
// Returns the exit code of the process.
func ForwardProcessOutput(pid int, log *logrus.Entry) int {
log.WithField("pid", pid).Info("Forwarding output from restored process")
// Try to open the process's stdout and stderr via /proc
stdoutPath := fmt.Sprintf("/proc/%d/fd/1", pid)
stderrPath := fmt.Sprintf("/proc/%d/fd/2", pid)
var wg sync.WaitGroup
// Forward stdout
wg.Add(1)
go forwardFD(stdoutPath, os.Stdout, "stdout", log, &wg)
// Forward stderr
wg.Add(1)
go forwardFD(stderrPath, os.Stderr, "stderr", log, &wg)
// Wait for process to exit (and reap it if it's our child).
exitCode := waitForProcess(pid, log)
// Give copy goroutines a short window to flush/finish.
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(2 * time.Second):
log.WithField("pid", pid).Warn("Timed out waiting for output forwarding goroutines to finish")
}
return exitCode
}
// forwardFD copies data from a file descriptor path to a writer.
// It handles the case where the FD may not be readable.
func forwardFD(fdPath string, dst io.Writer, name string, log *logrus.Entry, wg *sync.WaitGroup) {
defer wg.Done()
// Try to open the FD path
src, err := os.Open(fdPath)
if err != nil {
log.WithError(err).WithField("path", fdPath).Debug("Could not open process FD for forwarding")
return
}
defer src.Close()
// Check what kind of file this is
stat, err := src.Stat()
if err != nil {
log.WithError(err).WithField("path", fdPath).Debug("Could not stat process FD")
return
}
log.WithFields(logrus.Fields{
"name": name,
"mode": stat.Mode().String(),
"path": fdPath,
}).Debug("Forwarding process output")
_, err = io.Copy(dst, src)
if err != nil && !errors.Is(err, io.EOF) {
log.WithError(err).WithField("name", name).Debug("Error reading from process FD")
}
}
// waitForProcess waits for a process to exit and returns its exit code.
func waitForProcess(pid int, log *logrus.Entry) int {
// Preferred path: restored process is typically our direct child.
// Use wait4() so zombies are reaped and exit status is reliable.
var status syscall.WaitStatus
for {
wpid, err := syscall.Wait4(pid, &status, 0, nil)
if errors.Is(err, syscall.EINTR) {
continue
}
if err != nil {
if errors.Is(err, syscall.ECHILD) {
log.WithField("pid", pid).Warn("Restored process is not a child; falling back to signal-based monitoring")
return waitForProcessBySignal(pid, log)
}
log.WithError(err).WithField("pid", pid).Error("Wait4 failed for restored process")
return 1
}
if wpid != pid {
continue
}
if status.Exited() {
exitCode := status.ExitStatus()
log.WithFields(logrus.Fields{
"pid": pid,
"exit_code": exitCode,
}).Info("Restored process exited")
return exitCode
}
if status.Signaled() {
exitCode := 128 + int(status.Signal())
log.WithFields(logrus.Fields{
"pid": pid,
"signal": status.Signal().String(),
"exit_code": exitCode,
}).Warn("Restored process terminated by signal")
return exitCode
}
log.WithField("pid", pid).Warn("Restored process exited with unexpected wait status")
return 1
}
}
func waitForProcessBySignal(pid int, log *logrus.Entry) int {
for {
proc, err := os.FindProcess(pid)
if err != nil {
log.WithError(err).WithField("pid", pid).Error("Failed to find restored process")
return 1
}
if err := proc.Signal(syscall.Signal(0)); err != nil {
log.WithField("pid", pid).Info("Restored process no longer exists")
return 0
}
// Detect zombie state when wait4 is unavailable.
if state, err := readProcState(pid); err == nil && state == "Z" {
log.WithField("pid", pid).Warn("Restored process is zombie while not reaped by this process")
return 1
}
time.Sleep(100 * time.Millisecond)
}
}
// getExitCode attempts to get the exit code of a process.
// Returns 0 if unable to determine the exit code.
func getExitCode(pid int) int {
// Try to wait for the process (only works if we're the parent)
proc, err := os.FindProcess(pid)
if err != nil {
return 0
}
// Try waitpid with WNOHANG - this may not work for non-child processes
var wstatus syscall.WaitStatus
wpid, err := syscall.Wait4(pid, &wstatus, syscall.WNOHANG, nil)
if err == nil && wpid == pid {
if wstatus.Exited() {
return wstatus.ExitStatus()
}
if wstatus.Signaled() {
return 128 + int(wstatus.Signal())
}
}
// If we can't wait on it, check if it's still running
if proc.Signal(syscall.Signal(0)) != nil {
// Process is gone, assume clean exit
return 0
}
return 0
}
func readProcState(pid int) (string, error) {
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/status", pid))
if err != nil {
return "", err
}
for _, line := range strings.Split(string(data), "\n") {
if strings.HasPrefix(line, "State:") {
fields := strings.Fields(line)
if len(fields) >= 2 {
return fields[1], nil
}
break
}
}
return "", fmt.Errorf("state field not found in /proc/%d/status", pid)
}
// SetupSignalForwarding sets up signal forwarding to the restored process.
// Returns a cleanup function that should be called when done.
func SetupSignalForwarding(pid int, log *logrus.Entry) func() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGQUIT)
done := make(chan struct{})
go func() {
select {
case sig := <-sigChan:
log.WithFields(logrus.Fields{
"signal": sig,
"pid": pid,
}).Info("Forwarding signal to restored process")
proc, err := os.FindProcess(pid)
if err == nil {
proc.Signal(sig)
}
case <-done:
return
}
}()
return func() {
signal.Stop(sigChan)
close(done)
}
}
// WaitForPidFile waits for the CRIU PID file to be created and returns the PID.
func WaitForPidFile(pidFile string, timeout time.Duration, log *logrus.Entry) (int, error) {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
data, err := os.ReadFile(pidFile)
if err == nil {
pidStr := strings.TrimSpace(string(data))
pid, err := strconv.Atoi(pidStr)
if err == nil && pid > 0 {
return pid, nil
}
}
time.Sleep(100 * time.Millisecond)
}
return 0, fmt.Errorf("timeout waiting for PID file %s after %v", pidFile, timeout)
}
// ExecColdStart execs the cold start command (ColdStartArgs), replacing the current process.
// If no args are provided, falls back to sleep infinity.
func ExecColdStart(cfg *RestoreRequest, log *logrus.Entry) error {
if len(cfg.ColdStartArgs) == 0 {
log.Warn("No cold start command provided, sleeping indefinitely")
return ExecArgs([]string{"sleep", "infinity"}, log)
}
log.WithField("cmd", cfg.ColdStartArgs).Info("Executing cold start command")
return ExecArgs(cfg.ColdStartArgs, log)
}
// ExecArgs replaces the current process with the given command and arguments.
// Uses syscall.Exec for proper PID 1 behavior in containers.
func ExecArgs(args []string, log *logrus.Entry) error {
if len(args) == 0 {
return fmt.Errorf("empty command")
}
// Find the executable path
path, err := exec.LookPath(args[0])
if err != nil {
return fmt.Errorf("command not found: %s: %w", args[0], err)
}
log.WithFields(logrus.Fields{
"path": path,
"args": args,
}).Debug("Replacing process via syscall.Exec")
// Replace current process with the command
return syscall.Exec(path, args, os.Environ())
}
package restore
import (
"bufio"
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"sort"
"strings"
"syscall"
"time"
criu "github.com/checkpoint-restore/go-criu/v7"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// LogGPUDiagnostics logs nvidia-smi and /dev/nvidia* for debugging GPU visibility.
func LogGPUDiagnostics(label string, log *logrus.Entry) {
log.Infof("=== GPU DIAGNOSTICS [%s] ===", label)
diagCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if out, err := exec.CommandContext(diagCtx, "nvidia-smi", "-L").CombinedOutput(); err != nil {
log.Infof("nvidia-smi -L: error: %v", err)
} else {
log.Infof("nvidia-smi -L:\n%s", string(out))
}
// Also log memory usage per GPU to detect OOM conditions
diagCtx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel2()
if out, err := exec.CommandContext(diagCtx2, "nvidia-smi", "--query-gpu=index,uuid,memory.used,memory.total,memory.free", "--format=csv,noheader").CombinedOutput(); err != nil {
log.Infof("nvidia-smi memory query: error: %v", err)
} else {
log.Infof("nvidia-smi memory:\n%s", string(out))
}
matches, _ := filepath.Glob("/dev/nvidia*")
log.Infof("/dev/nvidia* devices: %s", strings.Join(matches, ", "))
log.Infof("NVIDIA_VISIBLE_DEVICES=%s", os.Getenv("NVIDIA_VISIBLE_DEVICES"))
log.Infof("=== END GPU DIAGNOSTICS [%s] ===", label)
}
func processSnapshotPIDs(restoredPID int) []int {
pidSet := map[int]struct{}{
1: {},
os.Getpid(): {},
}
if restoredPID > 0 {
pidSet[restoredPID] = struct{}{}
}
pids := make([]int, 0, len(pidSet))
for pid := range pidSet {
pids = append(pids, pid)
}
sort.Ints(pids)
return pids
}
func logProcessNamespaces(pid int, log *logrus.Entry) {
for _, ns := range []string{"mnt", "pid", "ipc", "net", "uts", "cgroup"} {
nsPath := fmt.Sprintf("/proc/%d/ns/%s", pid, ns)
link, err := os.Readlink(nsPath)
if err != nil {
log.WithError(err).WithFields(logrus.Fields{
"pid": pid,
"path": nsPath,
}).Warn("Failed to read namespace symlink")
continue
}
log.WithFields(logrus.Fields{
"pid": pid,
"namespace": ns,
"value": link,
}).Info("Namespace snapshot")
}
}
func logProcessCgroupPath(pid int, log *logrus.Entry) {
path := fmt.Sprintf("/proc/%d/cgroup", pid)
data, err := os.ReadFile(path)
if err != nil {
log.WithError(err).WithFields(logrus.Fields{
"pid": pid,
"path": path,
}).Warn("Failed to read cgroup path")
return
}
log.WithFields(logrus.Fields{
"pid": pid,
"path": path,
"contents": strings.TrimSpace(string(data)),
}).Info("Cgroup membership snapshot")
}
func logProcessFilteredMountInfo(pid int, log *logrus.Entry) {
// Mountinfo dumps are very large; only emit them in DEBUG mode.
if !log.Logger.IsLevelEnabled(logrus.DebugLevel) {
return
}
path := fmt.Sprintf("/proc/%d/mountinfo", pid)
f, err := os.Open(path)
if err != nil {
log.WithError(err).WithFields(logrus.Fields{
"pid": pid,
"path": path,
}).Warn("Failed to open mountinfo")
return
}
defer f.Close()
var selected []string
scanner := bufio.NewScanner(f)
scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, " /dev ") ||
strings.Contains(line, "/dev/") ||
strings.Contains(line, "nvidia") ||
strings.Contains(line, "cgroup2") {
selected = append(selected, line)
}
}
if err := scanner.Err(); err != nil {
log.WithError(err).WithFields(logrus.Fields{
"pid": pid,
"path": path,
}).Warn("Failed while scanning mountinfo")
return
}
log.WithFields(logrus.Fields{
"pid": pid,
"path": path,
"count": len(selected),
}).Debug("Filtered mountinfo snapshot count")
if len(selected) > 0 {
for i, line := range selected {
log.WithFields(logrus.Fields{
"pid": pid,
"index": i + 1,
"total": len(selected),
}).Debugf("Filtered mountinfo: %s", line)
}
}
}
func logNvidiaDeviceNodeMetadata(log *logrus.Entry) {
devices, err := filepath.Glob("/dev/nvidia*")
if err != nil {
log.WithError(err).Warn("Failed to glob /dev/nvidia*")
return
}
if len(devices) == 0 {
log.Info("No /dev/nvidia* entries found")
return
}
for _, path := range devices {
fi, err := os.Lstat(path)
if err != nil {
log.WithError(err).WithField("path", path).Warn("Failed to stat NVIDIA device entry")
continue
}
stat, ok := fi.Sys().(*syscall.Stat_t)
if !ok {
log.WithFields(logrus.Fields{
"path": path,
"mode": fi.Mode().String(),
}).Warn("Unexpected stat type for NVIDIA device entry")
continue
}
log.WithFields(logrus.Fields{
"path": path,
"mode": fi.Mode().String(),
"inode": stat.Ino,
"rdev": fmt.Sprintf("0x%x", stat.Rdev),
}).Info("NVIDIA device entry metadata")
}
}
func logCgroupV2HostInfo(log *logrus.Entry) {
const controllersPath = "/sys/fs/cgroup/cgroup.controllers"
data, err := os.ReadFile(controllersPath)
if err != nil {
log.WithError(err).WithField("path", controllersPath).Warn("Failed to read cgroup v2 controllers")
return
}
log.WithFields(logrus.Fields{
"path": controllersPath,
"controllers": strings.TrimSpace(string(data)),
}).Info("cgroup v2 controllers")
}
// LogRestoreBoundaryDiagnostics captures cgroup and namespace state around CRIU restore.
func LogRestoreBoundaryDiagnostics(label string, restoredPID int, log *logrus.Entry) {
log.Infof("=== RESTORE BOUNDARY DIAGNOSTICS [%s] ===", label)
for _, pid := range processSnapshotPIDs(restoredPID) {
logProcessNamespaces(pid, log)
logProcessCgroupPath(pid, log)
logProcessFilteredMountInfo(pid, log)
}
logCgroupV2HostInfo(log)
logNvidiaDeviceNodeMetadata(log)
log.Infof("=== END RESTORE BOUNDARY DIAGNOSTICS [%s] ===", label)
}
// Restore performs the CRIU restore operation using go-criu.
// All CRIU options are read from the saved CheckpointManifest - no hardcoding.
// Returns the PID of the restored process.
func Restore(ctx context.Context, checkpointPath string, data *checkpoint.CheckpointManifest, log *logrus.Entry) (int, error) {
if data == nil {
return 0, fmt.Errorf("checkpoint manifest is required")
}
// Hardcoded restore constants
const (
rootPath = "/"
pidFile = "/tmp/restored.pid"
logFile = RestoreLogFilename
)
log.WithField("checkpoint", checkpointPath).Info("Starting CRIU restore")
// 1. Open checkpoint directory
imageDir, imageDirFD, err := OpenImageDir(checkpointPath)
if err != nil {
return 0, err
}
defer imageDir.Close()
// 2. Generate external mount mappings from saved CheckpointManifest
extMounts, err := GenerateExtMountMaps(data)
if err != nil {
return 0, fmt.Errorf("failed to generate mount maps: %w", err)
}
// 3. Open target network namespace
netNsFile, netNsFD, err := OpenNetworkNamespace("/proc/1/ns/net")
if err != nil {
return 0, err
}
defer netNsFile.Close()
// 4. Open work directory if specified in checkpoint dump settings.
var workDirFile *os.File
var workDirFD int32 = -1
if data.CRIUDump.CRIU.WorkDir != "" {
workDirFile, workDirFD = OpenWorkDir(data.CRIUDump.CRIU.WorkDir, log)
if workDirFile != nil {
defer workDirFile.Close()
}
}
// 5. Build CRIU options from saved checkpoint manifest.
plan := CRIURestorePlan{
// File descriptors
ImageDirFD: imageDirFD,
WorkDirFD: workDirFD,
NetNsFD: netNsFD,
// Paths
RootPath: rootPath,
LogFile: logFile,
// Options from CheckpointManifest.CRIUDump.CRIU
LogLevel: data.CRIUDump.CRIU.LogLevel,
Timeout: data.CRIUDump.CRIU.Timeout,
ShellJob: data.CRIUDump.CRIU.ShellJob,
TcpClose: data.CRIUDump.CRIU.TcpClose,
FileLocks: data.CRIUDump.CRIU.FileLocks,
ExtUnixSk: data.CRIUDump.CRIU.ExtUnixSk,
LinkRemap: data.CRIUDump.CRIU.LinkRemap,
ManageCgroupsMode: data.CRIUDump.CRIU.ManageCgroupsMode,
// External mounts
ExtMountMaps: extMounts,
}
criuOpts := BuildCRIURestoreOptions(plan)
// 6. Reuse criu.conf from checkpoint time if it exists.
criuConfPath := filepath.Join(checkpointPath, checkpoint.CheckpointCRIUConfFilename)
if _, err := os.Stat(criuConfPath); err == nil {
criuOpts.ConfigFile = proto.String(criuConfPath)
}
// 7. Execute CRIU restore
c := criu.MakeCriu()
notify := NewRestoreNotify(log)
log.Info("Executing CRIU restore")
criuExecStart := time.Now()
if err := c.Restore(criuOpts, notify); err != nil {
log.WithField("duration", time.Since(criuExecStart)).Error("CRIU c.Restore failed")
logCRIUErrors(checkpointPath, logFile, log)
return 0, fmt.Errorf("CRIU restore failed: %w", err)
}
log.WithFields(logrus.Fields{
"pid": notify.RestoredPID,
"duration": time.Since(criuExecStart),
}).Info("CRIU c.Restore completed successfully")
// 8. Get restored PID
if notify.RestoredPID > 0 {
return int(notify.RestoredPID), nil
}
// Fallback: try to read from PID file
pid, err := WaitForPidFile(pidFile, 10*time.Second, log)
if err != nil {
return 0, fmt.Errorf("failed to get restored PID: %w", err)
}
return pid, nil
}
// logCRIUErrors reads CRIU log file and logs errors.
func logCRIUErrors(checkpointPath, logFile string, log *logrus.Entry) {
logPath := filepath.Join(checkpointPath, logFile)
data, err := os.ReadFile(logPath)
if err != nil {
log.WithError(err).Warn("Could not read CRIU log file")
return
}
log.Error("=== CRIU RESTORE LOG START ===")
for _, line := range strings.Split(string(data), "\n") {
if line != "" {
log.Error(line)
}
}
log.Error("=== CRIU RESTORE LOG END ===")
// Copy log to shared directory for debugging
if err := os.MkdirAll(CRIULogDir, 0755); err == nil {
destPath := filepath.Join(CRIULogDir, fmt.Sprintf("restore-%d.log", time.Now().Unix()))
if err := os.WriteFile(destPath, data, 0644); err == nil {
log.WithField("path", destPath).Info("CRIU log copied to shared directory")
}
}
}
// Run is the main entry point for the restore entrypoint.
// It orchestrates the entire restore process.
func Run(ctx context.Context, cfg *RestoreRequest, log *logrus.Entry) error {
log.Info("=== Restore Entrypoint ===")
log.WithFields(logrus.Fields{
"checkpoint_path": cfg.CheckpointPath,
"checkpoint_hash": cfg.CheckpointHash,
"checkpoint_location": cfg.CheckpointLocation,
"skip_wait_for_checkpoint": cfg.SkipWaitForCheckpoint,
"cold_start_args": cfg.ColdStartArgs,
}).Debug("Configuration")
// Check CRIU availability
c := criu.MakeCriu()
if _, err := c.GetCriuVersion(); err != nil {
log.WithError(err).Error("CRIU is not available")
return ExecColdStart(cfg, log)
}
// Determine checkpoint path based on mode
var checkpointPath string
if cfg.SkipWaitForCheckpoint {
// Operator path: check once, restore if ready, otherwise cold start
var ready bool
checkpointPath, ready = ShouldRestore(cfg, log)
if !ready {
log.Info("No checkpoint ready, executing cold start command")
return ExecColdStart(cfg, log)
}
} else {
// Standalone/DaemonSet path: check first, then poll if needed
var ready bool
checkpointPath, ready = ShouldRestore(cfg, log)
if !ready {
log.Info("Waiting for checkpoint...")
var err error
checkpointPath, err = WaitForCheckpoint(ctx, cfg, log)
if err != nil {
log.WithError(err).Info("No checkpoint received")
return ExecColdStart(cfg, log)
}
}
}
// Perform restore
log.WithField("checkpoint", checkpointPath).Info("Checkpoint available, starting restore")
restoreStart := time.Now()
// Apply filesystem changes
if err := ApplyRootfsDiff(checkpointPath, "/", log); err != nil {
log.WithError(err).Error("Failed to apply rootfs diff")
}
if err := ApplyDeletedFiles(checkpointPath, "/", log); err != nil {
log.WithError(err).Error("Failed to apply deleted files")
}
// Load checkpoint manifest (contains CRIU settings + mounts + namespaces).
data, err := checkpoint.ReadCheckpointManifest(checkpointPath)
if err != nil {
log.WithError(err).Error("Failed to load checkpoint manifest")
return ExecColdStart(cfg, log)
}
// Write restore marker file before CRIU restore
restoreMarkerFile := cfg.RestoreMarkerFilePath
if err := os.MkdirAll(filepath.Dir(restoreMarkerFile), 0755); err != nil {
log.WithError(err).Warn("Failed to create restore marker directory")
}
if err := os.WriteFile(restoreMarkerFile, []byte("restored"), 0644); err != nil {
log.WithError(err).Warn("Failed to write restore marker file")
}
// Restore /dev/shm contents before CRIU restore
if err := RestoreDevShm(checkpointPath, log); err != nil {
log.WithError(err).Error("Failed to restore /dev/shm contents - CRIU restore may fail with missing FD errors")
}
// Create link_remap stub files for unlinked files referenced in CRIU images
if err := CreateLinkRemapStubs(checkpointPath, log); err != nil {
log.WithError(err).Warn("Failed to create link_remap stubs")
}
// Log GPU diagnostics right before CRIU restore to track device visibility changes
LogGPUDiagnostics("PRE-CRIU-RESTORE", log)
LogRestoreBoundaryDiagnostics("PRE-CRIU-RESTORE", 0, log)
// Perform CRIU restore (CUDA plugin handles CUDA state automatically)
criuRestoreStart := time.Now()
pid, err := Restore(ctx, checkpointPath, data, log)
if err != nil {
log.WithField("duration", time.Since(criuRestoreStart)).WithError(err).Error("Restore failed, falling back to default command")
if cfg.Debug {
log.Info("DEBUG mode: sleeping 300s to allow log collection...")
time.Sleep(300 * time.Second)
}
return ExecColdStart(cfg, log)
}
criuRestoreDuration := time.Since(criuRestoreStart)
log.WithField("duration", criuRestoreDuration).Info("CRIU Restore completed (CUDA state restored by plugin)")
// Log GPU diagnostics AFTER restore to compare with pre-restore
LogGPUDiagnostics("POST-RESTORE", log)
LogRestoreBoundaryDiagnostics("POST-RESTORE", pid, log)
totalDuration := time.Since(restoreStart)
log.WithFields(logrus.Fields{
"total_duration": totalDuration,
"criu_restore_duration": criuRestoreDuration,
}).Info("=== Restore operation completed ===")
// Set up signal forwarding and forward stdout/stderr from restored process
cleanup := SetupSignalForwarding(pid, log)
defer cleanup()
// Use ForwardProcessOutput to ensure restored process logs appear in kubectl logs
exitCode := ForwardProcessOutput(pid, log)
os.Exit(exitCode)
return nil
}
// Package restore provides CRIU restore operations.
package restore
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// RestoreDevShm restores files from the checkpoint's dev-shm directory to /dev/shm.
// This must be called BEFORE CRIU restore so that the shared memory files exist
// when CRIU tries to restore file descriptors pointing to them.
func RestoreDevShm(checkpointPath string, log *logrus.Entry) error {
srcDir := filepath.Join(checkpointPath, checkpoint.DevShmDirName)
// Check if dev-shm directory exists in checkpoint
entries, err := os.ReadDir(srcDir)
if err != nil {
if os.IsNotExist(err) {
log.Debug("No dev-shm directory in checkpoint, skipping restore")
return nil
}
return fmt.Errorf("failed to read checkpoint dev-shm directory: %w", err)
}
if len(entries) == 0 {
log.Debug("Checkpoint dev-shm directory is empty")
return nil
}
// Ensure /dev/shm exists and is writable
destDir := "/dev/shm"
if err := os.MkdirAll(destDir, 0777); err != nil {
return fmt.Errorf("failed to ensure /dev/shm exists: %w", err)
}
var restored []string
var totalSize int64
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
srcPath := filepath.Join(srcDir, name)
destPath := filepath.Join(destDir, name)
info, err := entry.Info()
if err != nil {
log.WithError(err).WithField("file", name).Warn("Failed to get file info, skipping")
continue
}
size := info.Size()
// Copy the file to /dev/shm
if err := copyFileToShm(srcPath, destPath, info.Mode()); err != nil {
log.WithError(err).WithField("file", name).Warn("Failed to restore file, skipping")
continue
}
restored = append(restored, name)
totalSize += size
log.WithFields(logrus.Fields{
"file": name,
"size": size,
}).Debug("Restored /dev/shm file")
}
if len(restored) > 0 {
log.WithFields(logrus.Fields{
"count": len(restored),
"total_size": totalSize,
"files": restored,
}).Info("Restored /dev/shm files from checkpoint")
}
return nil
}
// copyFileToShm copies a file from src to dest in /dev/shm.
// Uses mode 0666 as default when mode is 0, otherwise preserves the original mode.
func copyFileToShm(src, dest string, mode os.FileMode) error {
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source: %w", err)
}
defer srcFile.Close()
// Default to 0666 when mode is not set (mode == 0)
if mode == 0 {
mode = 0666
}
destFile, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create destination: %w", err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return fmt.Errorf("failed to copy contents: %w", err)
}
return nil
}
// Package types defines shared data types used across chrek packages.
package types
import (
"fmt"
"os"
"strings"
"time"
)
// AgentConfig holds the full agent configuration: static checkpoint settings
// from the ConfigMap YAML, plus runtime fields from environment variables.
type AgentConfig struct {
NodeName string `yaml:"-"`
RestrictedNamespace string `yaml:"-"`
BasePath string `yaml:"basePath"`
Overlay OverlaySettings `yaml:"overlay"`
Restore RestoreSpec `yaml:"restore"`
CRIU CRIUSettings `yaml:"criu"`
}
func (c *AgentConfig) LoadEnvOverrides() {
if v := os.Getenv("NODE_NAME"); v != "" {
c.NodeName = v
}
if v := os.Getenv("RESTRICTED_NAMESPACE"); v != "" {
c.RestrictedNamespace = v
}
}
func (c *AgentConfig) Validate() error {
if strings.TrimSpace(c.BasePath) == "" {
return &ConfigError{Field: "basePath", Message: "basePath is required"}
}
return c.Restore.Validate()
}
// RestoreSpec holds settings for the CRIU restore process.
type RestoreSpec struct {
NSRestorePath string `yaml:"nsRestorePath"`
RestoreReadyTimeoutSeconds int `yaml:"restoreReadyTimeoutSeconds"`
}
func (c *RestoreSpec) RestoreReadyTimeout() time.Duration {
if c.RestoreReadyTimeoutSeconds <= 0 {
return 0
}
return time.Duration(c.RestoreReadyTimeoutSeconds) * time.Second
}
func (c *RestoreSpec) Validate() error {
if c.NSRestorePath == "" {
return &ConfigError{Field: "nsRestorePath", Message: "nsRestorePath is required"}
}
return nil
}
// CRIUSettings holds CRIU-specific configuration options.
type CRIUSettings struct {
GhostLimit uint32 `yaml:"ghostLimit"`
LogLevel int32 `yaml:"logLevel"`
WorkDir string `yaml:"workDir"`
AutoDedup bool `yaml:"autoDedup"`
LazyPages bool `yaml:"lazyPages"`
LeaveRunning bool `yaml:"leaveRunning"`
ShellJob bool `yaml:"shellJob"`
TcpClose bool `yaml:"tcpClose"`
FileLocks bool `yaml:"fileLocks"`
OrphanPtsMaster bool `yaml:"orphanPtsMaster"`
ExtUnixSk bool `yaml:"extUnixSk"`
LinkRemap bool `yaml:"linkRemap"`
ExtMasters bool `yaml:"extMasters"`
ManageCgroupsMode string `yaml:"manageCgroupsMode"`
RstSibling bool `yaml:"rstSibling"`
MntnsCompatMode bool `yaml:"mntnsCompatMode"`
EvasiveDevices bool `yaml:"evasiveDevices"`
ForceIrmap bool `yaml:"forceIrmap"`
BinaryPath string `yaml:"binaryPath"`
LibDir string `yaml:"libDir"`
AllowUprobes bool `yaml:"allowUprobes"`
SkipInFlight bool `yaml:"skipInFlight"`
}
// OverlaySettings is the static config for rootfs exclusions.
type OverlaySettings struct {
SystemDirs []string `yaml:"systemDirs"`
CacheDirs []string `yaml:"cacheDirs"`
AdditionalExclusions []string `yaml:"additionalExclusions"`
}
// ConfigError represents a configuration validation error.
type ConfigError struct {
Field string
Message string
}
func (e *ConfigError) Error() string {
return fmt.Sprintf("config error: %s: %s", e.Field, e.Message)
}
package types
import (
specs "github.com/opencontainers/runtime-spec/specs-go"
)
// MountInfo holds parsed mount information from /proc/pid/mountinfo.
type MountInfo struct {
MountPoint string
FSType string
VFSOptions string // superblock options (e.g. "upperdir=...")
// IsOCIManaged is true when the mount destination matches an OCI spec entry
// (including /run/ ↔ /var/run/ aliasing). Set by ClassifyMounts.
IsOCIManaged bool
}
// CheckpointContainerSnapshot holds runtime container state collected during checkpoint inspection.
type CheckpointContainerSnapshot struct {
PID int
RootFS string
UpperDir string
OCISpec *specs.Spec
Mounts []MountInfo
NetNSInode uint64
StdioFDs []string // readlink targets for FDs 0, 1, 2 (e.g. "pipe:[12345]")
HostCgroupPath string // host filesystem path for CRIU's --freeze-cgroup
CUDAPIDs []int // PIDs with CUDA state in the container
GPUUUIDs []string // source GPU UUIDs from kubelet PodResources API
}
// RestoreContainerSnapshot holds inspected state for the restore target.
type RestoreContainerSnapshot struct {
CheckpointPath string
PlaceholderPID int
TargetRoot string
CgroupRoot string
CUDADeviceMap string
}
package types
import (
"fmt"
"os"
"path/filepath"
"time"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"gopkg.in/yaml.v3"
)
const manifestFilename = "manifest.yaml"
// CheckpointManifest is saved as manifest.yaml at checkpoint time and loaded at restore.
type CheckpointManifest struct {
CheckpointHash string `yaml:"checkpointHash"`
CreatedAt time.Time `yaml:"createdAt"`
CRIUDump CRIUDumpManifest `yaml:"criuDump"`
K8s SourcePodManifest `yaml:"k8s"`
Overlay OverlayManifest `yaml:"overlay"`
CUDA CUDAManifest `yaml:"cudaRestore,omitempty"`
}
func NewCheckpointManifest(
checkpointHash string,
criuDump CRIUDumpManifest,
k8s SourcePodManifest,
overlay OverlayManifest,
) *CheckpointManifest {
return &CheckpointManifest{
CheckpointHash: checkpointHash,
CreatedAt: time.Now().UTC(),
CRIUDump: criuDump,
K8s: k8s,
Overlay: overlay,
}
}
// CRIUDumpManifest stores the resolved dump-time CRIU mount plan used for restore.
type CRIUDumpManifest struct {
CRIU CRIUSettings `yaml:"criu"`
ExtMnt map[string]string `yaml:"extMnt,omitempty"`
External []string `yaml:"external,omitempty"`
SkipMnt []string `yaml:"skipMnt,omitempty"`
}
func NewCRIUDumpManifest(criuOpts *criurpc.CriuOpts, settings CRIUSettings) CRIUDumpManifest {
m := CRIUDumpManifest{CRIU: settings}
if criuOpts == nil {
return m
}
m.ExtMnt = make(map[string]string, len(criuOpts.ExtMnt))
for _, mount := range criuOpts.ExtMnt {
if mount == nil || mount.GetKey() == "" {
continue
}
m.ExtMnt[mount.GetKey()] = mount.GetVal()
}
if len(m.ExtMnt) == 0 {
m.ExtMnt = nil
}
m.External = append([]string(nil), criuOpts.External...)
m.SkipMnt = append([]string(nil), criuOpts.SkipMnt...)
return m
}
// SourcePodManifest records the source pod identity at checkpoint time.
type SourcePodManifest struct {
ContainerID string `yaml:"containerId"`
PID int `yaml:"pid"`
SourceNode string `yaml:"sourceNode"`
PodName string `yaml:"podName"`
PodNamespace string `yaml:"podNamespace"`
// StdioFDs holds readlink targets for FDs 0, 1, 2 (e.g. "pipe:[12345]").
StdioFDs []string `yaml:"stdioFDs,omitempty"`
}
func NewSourcePodManifest(containerID string, pid int, sourceNode, podName, podNamespace string, stdioFDs []string) SourcePodManifest {
return SourcePodManifest{
ContainerID: containerID,
PID: pid,
SourceNode: sourceNode,
PodName: podName,
PodNamespace: podNamespace,
StdioFDs: append([]string(nil), stdioFDs...),
}
}
// OverlayManifest holds runtime overlay state captured at checkpoint time.
type OverlayManifest struct {
Exclusions OverlaySettings `yaml:"exclusions"`
UpperDir string `yaml:"upperDir,omitempty"`
ExternalPaths []string `yaml:"externalPaths,omitempty"`
BindMountDests []string `yaml:"bindMountDests,omitempty"`
}
func NewOverlayManifest(exclusions OverlaySettings, upperDir string, ociSpec *specs.Spec) OverlayManifest {
meta := OverlayManifest{
Exclusions: exclusions,
UpperDir: upperDir,
}
if ociSpec == nil {
return meta
}
if ociSpec.Linux != nil {
meta.ExternalPaths = make([]string, 0, len(ociSpec.Linux.MaskedPaths)+len(ociSpec.Linux.ReadonlyPaths))
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.MaskedPaths...)
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.ReadonlyPaths...)
}
for _, m := range ociSpec.Mounts {
if m.Type == "bind" {
meta.BindMountDests = append(meta.BindMountDests, m.Destination)
}
}
return meta
}
// CUDAManifest captures CUDA state from checkpoint time for restore.
type CUDAManifest struct {
PIDs []int `yaml:"pids"`
SourceGPUUUIDs []string `yaml:"sourceGpuUuids"`
}
func NewCUDAManifest(pids []int, sourceGPUUUIDs []string) CUDAManifest {
return CUDAManifest{
PIDs: append([]int(nil), pids...),
SourceGPUUUIDs: append([]string(nil), sourceGPUUUIDs...),
}
}
func (m CUDAManifest) IsEmpty() bool {
return len(m.PIDs) == 0
}
// WriteManifest writes a checkpoint manifest file in the checkpoint directory.
func WriteManifest(checkpointDir string, data *CheckpointManifest) error {
content, err := yaml.Marshal(data)
if err != nil {
return fmt.Errorf("failed to marshal checkpoint manifest: %w", err)
}
manifestPath := filepath.Join(checkpointDir, manifestFilename)
if err := os.WriteFile(manifestPath, content, 0600); err != nil {
return fmt.Errorf("failed to write checkpoint manifest: %w", err)
}
return nil
}
// ReadManifest reads checkpoint manifest from a checkpoint directory.
func ReadManifest(checkpointDir string) (*CheckpointManifest, error) {
manifestPath := filepath.Join(checkpointDir, manifestFilename)
content, err := os.ReadFile(manifestPath)
if err != nil {
return nil, fmt.Errorf("failed to read checkpoint manifest: %w", err)
}
var data CheckpointManifest
if err := yaml.Unmarshal(content, &data); err != nil {
return nil, fmt.Errorf("failed to unmarshal checkpoint manifest: %w", err)
}
return &data, nil
}
package types
import (
"testing"
criurpc "github.com/checkpoint-restore/go-criu/v8/rpc"
"google.golang.org/protobuf/proto"
)
func TestManifestRoundTrip(t *testing.T) {
dir := t.TempDir()
original := NewCheckpointManifest(
"sha256:abc123",
CRIUDumpManifest{
CRIU: CRIUSettings{
LogLevel: 4,
ShellJob: true,
LibDir: "/usr/lib/criu",
},
ExtMnt: map[string]string{"/etc/hostname": "/etc/hostname", "/proc/acpi": "/dev/null"},
External: []string{"net[12345]:extNetNs"},
SkipMnt: []string{"/proc/kcore"},
},
NewSourcePodManifest("ctr-abc", 42, "node-1", "my-pod", "default", []string{"pipe:[111]", "pipe:[222]", "pipe:[333]"}),
OverlayManifest{
Exclusions: OverlaySettings{SystemDirs: []string{"/proc", "/sys"}},
UpperDir: "/var/lib/containerd/upper",
ExternalPaths: []string{"/proc/acpi"},
BindMountDests: []string{"/data"},
},
)
original.CUDA = NewCUDAManifest([]int{42, 43}, []string{"GPU-aaa", "GPU-bbb"})
if err := WriteManifest(dir, original); err != nil {
t.Fatalf("WriteManifest: %v", err)
}
loaded, err := ReadManifest(dir)
if err != nil {
t.Fatalf("ReadManifest: %v", err)
}
// Verify key fields survived the round-trip
if loaded.CheckpointHash != original.CheckpointHash {
t.Errorf("CheckpointHash = %q, want %q", loaded.CheckpointHash, original.CheckpointHash)
}
if loaded.CRIUDump.CRIU.LogLevel != 4 {
t.Errorf("CRIU.LogLevel = %d, want 4", loaded.CRIUDump.CRIU.LogLevel)
}
if loaded.CRIUDump.CRIU.ShellJob != true {
t.Error("CRIU.ShellJob should be true")
}
if len(loaded.CRIUDump.ExtMnt) != 2 {
t.Errorf("ExtMnt count = %d, want 2", len(loaded.CRIUDump.ExtMnt))
}
if loaded.CRIUDump.ExtMnt["/etc/hostname"] != "/etc/hostname" {
t.Errorf("ExtMnt[/etc/hostname] = %q", loaded.CRIUDump.ExtMnt["/etc/hostname"])
}
if len(loaded.CRIUDump.External) != 1 || loaded.CRIUDump.External[0] != "net[12345]:extNetNs" {
t.Errorf("External = %v", loaded.CRIUDump.External)
}
if len(loaded.CRIUDump.SkipMnt) != 1 || loaded.CRIUDump.SkipMnt[0] != "/proc/kcore" {
t.Errorf("SkipMnt = %v", loaded.CRIUDump.SkipMnt)
}
if loaded.K8s.PodName != "my-pod" {
t.Errorf("K8s.PodName = %q", loaded.K8s.PodName)
}
if len(loaded.K8s.StdioFDs) != 3 {
t.Errorf("StdioFDs count = %d, want 3", len(loaded.K8s.StdioFDs))
}
if loaded.Overlay.UpperDir != "/var/lib/containerd/upper" {
t.Errorf("Overlay.UpperDir = %q", loaded.Overlay.UpperDir)
}
if len(loaded.Overlay.BindMountDests) != 1 || loaded.Overlay.BindMountDests[0] != "/data" {
t.Errorf("Overlay.BindMountDests = %v", loaded.Overlay.BindMountDests)
}
if len(loaded.CUDA.PIDs) != 2 || loaded.CUDA.PIDs[0] != 42 {
t.Errorf("CUDA.PIDs = %v", loaded.CUDA.PIDs)
}
if len(loaded.CUDA.SourceGPUUUIDs) != 2 || loaded.CUDA.SourceGPUUUIDs[0] != "GPU-aaa" {
t.Errorf("CUDA.SourceGPUUUIDs = %v", loaded.CUDA.SourceGPUUUIDs)
}
}
func TestNewCRIUDumpManifest(t *testing.T) {
t.Run("nil CriuOpts does not panic", func(t *testing.T) {
m := NewCRIUDumpManifest(nil, CRIUSettings{LogLevel: 2})
if m.CRIU.LogLevel != 2 {
t.Errorf("LogLevel = %d, want 2", m.CRIU.LogLevel)
}
if m.ExtMnt != nil {
t.Errorf("ExtMnt should be nil, got %v", m.ExtMnt)
}
})
t.Run("extracts ExtMnt from protobuf correctly", func(t *testing.T) {
opts := &criurpc.CriuOpts{
ExtMnt: []*criurpc.ExtMountMap{
{Key: proto.String("/etc/hosts"), Val: proto.String("/etc/hosts")},
{Key: proto.String("/proc/acpi"), Val: proto.String("/dev/null")},
// nil entry and empty key should be skipped
nil,
{Key: proto.String(""), Val: proto.String("ignored")},
},
External: []string{"net[1234]:extNetNs"},
SkipMnt: []string{"/proc/kcore", "/sys/firmware"},
}
m := NewCRIUDumpManifest(opts, CRIUSettings{})
if len(m.ExtMnt) != 2 {
t.Fatalf("ExtMnt count = %d, want 2; got %v", len(m.ExtMnt), m.ExtMnt)
}
if m.ExtMnt["/etc/hosts"] != "/etc/hosts" {
t.Errorf("ExtMnt[/etc/hosts] = %q", m.ExtMnt["/etc/hosts"])
}
if m.ExtMnt["/proc/acpi"] != "/dev/null" {
t.Errorf("ExtMnt[/proc/acpi] = %q", m.ExtMnt["/proc/acpi"])
}
if len(m.External) != 1 {
t.Errorf("External = %v", m.External)
}
if len(m.SkipMnt) != 2 {
t.Errorf("SkipMnt = %v", m.SkipMnt)
}
})
t.Run("empty ExtMnt entries results in nil map", func(t *testing.T) {
opts := &criurpc.CriuOpts{
ExtMnt: []*criurpc.ExtMountMap{
nil,
{Key: proto.String(""), Val: proto.String("x")},
},
}
m := NewCRIUDumpManifest(opts, CRIUSettings{})
if m.ExtMnt != nil {
t.Errorf("expected nil ExtMnt when all entries are empty/nil, got %v", m.ExtMnt)
}
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment