"lib/engines/vscode:/vscode.git/clone" did not exist on "f8213242d768dd4dda051a150785eb4824e50212"
Unverified Commit d381e6ff authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(chrek): config refactor, /dev/shm support, and mount-policy rewrite (#5946)

parent b6824ae0
......@@ -90,9 +90,6 @@ class Config:
# Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False
# sleep mode support (enable_sleep_mode comes from vLLM's engine_args)
sleep_mode_level: int = 1
# Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args)
use_kv_events: bool = False
......@@ -301,13 +298,6 @@ def parse_args() -> Config:
default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
)
parser.add_argument(
"--sleep-mode-level",
type=int,
default=1,
choices=[1, 2, 3],
help="Sleep mode level (1=offload to CPU, 2=discard weights, 3=discard all). Default: 1",
)
add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser)
......@@ -454,7 +444,6 @@ def parse_args() -> Config:
config.enable_local_indexer = not args.durable_kv_events
# For omni mode, use vLLM (AsyncOmni) tokenizer on backend
config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni
config.sleep_mode_level = args.sleep_mode_level
# use_kv_events is set later in overwrite_args() based on kv_events_config
# Validate custom Jinja template file exists if provided
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Checkpoint/restore (chrek) integration for vLLM workers.
Handles the checkpoint job pod lifecycle:
1. Early exit if a checkpoint already exists (idempotency)
2. Sleep model for CRIU-friendly GPU state
3. Signal readiness for DaemonSet to begin checkpoint
4. Poll for checkpoint completion or CRIU restore detection
5. Wake model after restore
Environment variables (all required in checkpoint mode, no fallbacks):
- DYN_CHECKPOINT_SIGNAL_FILE: Path where DaemonSet writes completion signal
- DYN_READY_FOR_CHECKPOINT_FILE: Path where this worker writes readiness marker
- DYN_CHECKPOINT_STORAGE_TYPE: Storage backend (pvc, s3, oci)
- DYN_CHECKPOINT_LOCATION: Full checkpoint path (for idempotency check)
- DYN_RESTORE_MARKER_FILE: Path written by restore-entrypoint before CRIU restore
"""
import asyncio
import json
import logging
import os
from typing import Optional
logger = logging.getLogger(__name__)
# Environment variables that must all be present once checkpoint mode is
# active. get_checkpoint_config() reports any that are missing, and
# CheckpointConfig.__init__ reads each of them by direct os.environ indexing.
_REQUIRED_ENV_VARS = [
"DYN_CHECKPOINT_SIGNAL_FILE",
"DYN_READY_FOR_CHECKPOINT_FILE",
"DYN_CHECKPOINT_STORAGE_TYPE",
"DYN_CHECKPOINT_LOCATION",
"DYN_RESTORE_MARKER_FILE",
]
class CheckpointConfig:
    """Checkpoint settings read from the environment, plus lifecycle helpers.

    Attributes are populated from the ``DYN_*`` environment variables:

    - ``signal_file``: status file whose appearance ends the wait loop
    - ``ready_file``: marker this worker writes once the model is asleep
    - ``storage_type``: storage backend name (only ``"pvc"`` is handled by
      ``checkpoint_exists``)
    - ``location``: checkpoint directory, used for the idempotency check
    - ``restore_marker``: file whose presence means the process was restored
    """

    def __init__(self):
        # Direct indexing (rather than .get) makes a missing variable fail
        # loudly with KeyError instead of silently carrying None around.
        self.signal_file = os.environ["DYN_CHECKPOINT_SIGNAL_FILE"]
        self.ready_file = os.environ["DYN_READY_FOR_CHECKPOINT_FILE"]
        self.storage_type = os.environ["DYN_CHECKPOINT_STORAGE_TYPE"]
        self.location = os.environ["DYN_CHECKPOINT_LOCATION"]
        self.restore_marker = os.environ["DYN_RESTORE_MARKER_FILE"]

    def _read_status_file(self, path: str) -> dict:
        """Load a JSON status file and validate its shape.

        Args:
            path: File to read; must contain a JSON object with a boolean
                ``success`` field.

        Returns:
            The parsed JSON object.

        Raises:
            OSError: The file could not be opened or read.
            ValueError: The content is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass), is not a JSON object, or has a
                missing / non-boolean ``success`` field.
        """
        with open(path, encoding="utf-8") as f:
            status = json.load(f)
        # Valid JSON is not necessarily an object: "[]", "null" or a bare
        # string would make .get() raise AttributeError, which no caller
        # handles. Report it as ValueError like every other malformed payload.
        if not isinstance(status, dict):
            raise ValueError(
                f"expected a JSON object in {path}, got {type(status).__name__}"
            )
        success = status.get("success")
        if not isinstance(success, bool):
            raise ValueError(f"missing or invalid success field in {path}")
        return status

    def checkpoint_exists(self) -> bool:
        """Check if a completed checkpoint already exists (idempotency).

        For PVC storage, looks for a ``checkpoint.done`` marker under
        ``self.location`` and reads it as a status file.

        Returns:
            True only when the marker exists, parses, and reports success, in
            which case the job should exit without loading the model. An
            unreadable or invalid marker, or one reporting failure, is logged
            and treated as "no checkpoint" (False).

        Raises:
            AssertionError: ``storage_type`` is not ``"pvc"``.
        """
        # NOTE: this assert is stripped under `python -O`; the condition below
        # re-checks the storage type so that case falls through to False.
        assert (
            self.storage_type == "pvc"
        ), "Checkpoint existence check is only implemented for PVC storage"
        if self.storage_type == "pvc" and self.location:
            done_marker = f"{self.location}/checkpoint.done"
            if os.path.exists(done_marker):
                try:
                    status = self._read_status_file(done_marker)
                except (OSError, ValueError, json.JSONDecodeError) as exc:
                    logger.warning(
                        f"Invalid checkpoint.done marker at {done_marker}, ignoring stale checkpoint: {exc}"
                    )
                    return False
                if status["success"]:
                    logger.info(
                        f"Existing successful checkpoint found at {self.location}, skipping"
                    )
                    return True
                logger.warning(
                    f"Existing checkpoint marker reports failure at {self.location}: "
                    f"{status.get('error', 'unknown error')}"
                )
                return False
        logger.info(f"No checkpoint at {self.location}, creating new one")
        return False

    async def run_lifecycle(self, engine_client, sleep_level: int) -> bool:
        """Run the full checkpoint lifecycle after the engine is loaded.

        1. Put model to sleep (CRIU-friendly GPU state)
        2. Write ready file (triggers DaemonSet checkpoint via readiness probe)
        3. Poll for signal file (checkpoint done) or restore marker (CRIU restored us)
        4. If restored: wake model and return True (caller proceeds with registration)
        5. If checkpoint done: return False (caller should exit)

        Args:
            engine_client: Object exposing awaitable ``sleep(level=...)`` and
                ``wake_up()`` methods.
            sleep_level: Level forwarded to ``engine_client.sleep``.

        Returns:
            True if the restore marker was seen (the model has been woken up),
            False if the signal file reported a successful checkpoint.

        Raises:
            RuntimeError: The signal file is unreadable/invalid, or it reports
                ``success: false``.
        """
        # Sleep model for checkpoint
        logger.info(f"Putting model to sleep (level={sleep_level})")
        await engine_client.sleep(level=sleep_level)

        # Signal readiness
        with open(self.ready_file, "w") as f:
            f.write("ready")
        logger.info(
            f"Ready for checkpoint. Waiting for signal: {self.signal_file} "
            f"or restore marker: {self.restore_marker}"
        )

        # Poll for signal or restore. The restore marker is checked first, so
        # when both files are present the restore path wins.
        while True:
            if os.path.exists(self.restore_marker):
                logger.info(f"Restore detected (marker: {self.restore_marker})")
                logger.info("Waking up model after restore")
                await engine_client.wake_up()
                return True
            if os.path.exists(self.signal_file):
                try:
                    signal = self._read_status_file(self.signal_file)
                except (OSError, ValueError, json.JSONDecodeError) as exc:
                    raise RuntimeError(
                        f"Invalid checkpoint signal file {self.signal_file}: {exc}"
                    ) from exc
                if signal["success"]:
                    logger.info(f"Checkpoint complete (signal: {self.signal_file})")
                    return False
                raise RuntimeError(
                    f"Checkpoint failed (signal: {self.signal_file}): "
                    f"{signal.get('error', 'unknown error')}"
                )
            await asyncio.sleep(1)
def get_checkpoint_config() -> Optional[CheckpointConfig]:
    """Build the checkpoint configuration when checkpoint mode is active.

    The presence of DYN_CHECKPOINT_SIGNAL_FILE in the environment is what
    switches checkpoint mode on. Once it is on, every name listed in
    _REQUIRED_ENV_VARS has to be set as well.

    Returns:
        A CheckpointConfig in checkpoint mode, otherwise None.

    Raises:
        EnvironmentError: Checkpoint mode is active but one or more required
            variables are absent; the message lists them.
    """
    env = os.environ
    if "DYN_CHECKPOINT_SIGNAL_FILE" in env:
        missing = [name for name in _REQUIRED_ENV_VARS if name not in env]
        if not missing:
            return CheckpointConfig()
        raise EnvironmentError(
            f"Checkpoint mode requires these environment variables: {', '.join(missing)}"
        )
    return None
......@@ -56,6 +56,7 @@ from dynamo.vllm.multimodal_handlers import (
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
from .args import Config, overwrite_args, parse_args
from .chrek import get_checkpoint_config
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import (
VllmHealthCheckPayload,
......@@ -66,6 +67,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
configure_dynamo_logging()
logger = logging.getLogger(__name__)
CHECKPOINT_SLEEP_MODE_LEVEL = 1
async def _handle_non_leader_node(dp_rank: int) -> None:
......@@ -81,46 +83,17 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait()
async def await_checkpoint_and_was_restored(signal_file: str) -> bool:
async def graceful_shutdown(runtime, shutdown_event):
"""
Wait for checkpoint signal file OR restore marker file.
In checkpoint creation mode, poll until either:
1. The signal file exists (checkpoint complete, should exit)
2. The restore marker file exists (restored by CRIU, should proceed)
The restore marker file is created by the restore-entrypoint before CRIU restore,
so the restored process can detect it was restored even though os.environ is
restored from the checkpoint and doesn't contain new container env vars.
Args:
signal_file: Path to the checkpoint signal file
Returns:
True if restored (should proceed with registration)
False if signal file detected (should exit)
Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
For endpoints served with graceful_shutdown=False, the serving function will return immediately.
"""
# Get restore marker file path (created by restore entrypoint before CRIU restore)
restore_marker = os.environ.get("DYN_RESTORE_MARKER_FILE", "/tmp/dynamo-restored")
logger.info(
f"CHECKPOINT_READY: Model loaded, ready for container checkpoint. Waiting for signal file: {signal_file} or restore marker file: {restore_marker}"
)
while True:
# Check if we've been restored (marker file created by restore entrypoint)
if os.path.exists(restore_marker):
logger.info(
f"Detected restore from checkpoint (marker file exists: {restore_marker})"
)
return True # Restored - proceed with registration
# Check if checkpoint is complete (signal file exists)
if os.path.exists(signal_file):
logger.info(f"Checkpoint signal file detected: {signal_file}")
return False # Checkpoint done - exit
await asyncio.sleep(1)
logging.info("Received shutdown signal, shutting down DistributedRuntime")
shutdown_event.set()
runtime.shutdown()
logging.info("DistributedRuntime shutdown complete")
async def worker():
......@@ -134,29 +107,10 @@ async def worker():
if not config.served_model_name:
config.served_model_name = config.engine_args.served_model_name = config.model
# Check checkpoint-related environment variables EARLY
signal_file = os.environ.get("DYN_CHECKPOINT_SIGNAL_FILE")
ready_file = os.environ.get("DYN_CHECKPOINT_READY_FILE")
is_checkpoint_mode = signal_file is not None
# EARLY EXIT: Check if checkpoint already exists (before downloading model!)
if is_checkpoint_mode:
storage_type = os.environ.get("DYN_CHECKPOINT_STORAGE_TYPE")
checkpoint_location = os.environ.get("DYN_CHECKPOINT_LOCATION")
if storage_type == "pvc" and checkpoint_location:
done_marker = f"{checkpoint_location}/checkpoint.done"
if os.path.exists(done_marker):
logger.info(
f"Found existing checkpoint at {checkpoint_location}. Storage type: {storage_type}"
)
return
else:
logger.info(
f"Checkpoint not found at: {checkpoint_location}. creating new checkpoint"
)
# Check checkpoint mode and validate env vars EARLY (fail fast if misconfigured)
checkpoint_cfg = get_checkpoint_config()
if checkpoint_cfg and checkpoint_cfg.checkpoint_exists():
return
# Download the model if necessary using modelexpress.
# We want it on disk before we start vllm to avoid downloading from HuggingFace.
......@@ -173,39 +127,20 @@ async def worker():
# CHECKPOINT MODE: Load engine BEFORE runtime creation
# This allows checkpointing GPU state before runtime connections are established
pre_created_engine = None
is_restored = False
if is_checkpoint_mode:
if checkpoint_cfg is not None:
logger.info(
f"Checkpoint mode enabled (DYN_CHECKPOINT_SIGNAL_FILE={signal_file})"
f"Checkpoint mode enabled (signal_file={checkpoint_cfg.signal_file})"
)
# CHECKPOINT MODE: Load model, sleep, wait for signal file or restore
# Checkpoint mode requires sleep mode — enable before engine init
config.engine_args.enable_sleep_mode = True
pre_created_engine = setup_vllm_engine(config)
engine_client = pre_created_engine[0]
# Put model to sleep before checkpoint (if sleep mode enabled)
if config.engine_args.enable_sleep_mode:
logger.info(f"Putting model to sleep (level={config.sleep_mode_level})")
await engine_client.sleep(level=config.sleep_mode_level)
# Write ready file to signal that we're ready for checkpointing
if ready_file:
with open(ready_file, "w") as f:
f.write("ready")
logger.info(f"Wrote checkpoint ready file: {ready_file}")
# Wait for checkpoint signal file OR restore detection
is_restored = await await_checkpoint_and_was_restored(signal_file)
if is_restored:
# Wake up model and proceed with registration
if config.engine_args.enable_sleep_mode:
logger.info("Waking up model after checkpoint restore")
await engine_client.wake_up()
logger.info("Proceeding with endpoint registration after restore")
else:
# Checkpoint complete, exit
logger.info("Exiting after checkpoint completion")
if not await checkpoint_cfg.run_lifecycle(
engine_client, CHECKPOINT_SLEEP_MODE_LEVEL
):
return
shutdown_event = asyncio.Event()
......
......@@ -142,14 +142,6 @@ COPY --from=builder /restore-entrypoint /restore-entrypoint
# Create checkpoint directory
RUN mkdir -p /checkpoints
# Set environment variables
ENV HOST_PROC=/host/proc \
CONTAINERD_SOCKET=/run/containerd/containerd.sock \
CHECKPOINT_DIR=/checkpoints \
LISTEN_ADDR=:8080
EXPOSE 8080
USER root
ENTRYPOINT ["/usr/local/bin/chrek-agent"]
......@@ -172,7 +164,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libnl-3-200 \
libnl-route-3-200 \
libprotobuf-c1 \
libgnutls30 \
libgnutls30t64 \
libnftables1 \
iproute2 \
iptables \
......@@ -191,23 +183,11 @@ COPY --from=criu-builder /tmp/cuda-checkpoint/bin/x86_64_Linux/cuda-checkpoint /
RUN chmod +x /usr/local/sbin/cuda-checkpoint
# Create directories
RUN mkdir -p /checkpoints /var/run/criu /tmp /var/criu-work
RUN mkdir -p /checkpoints /var/run/criu /var/criu-work
# Copy restore binaries
COPY --from=builder /restore-entrypoint /restore-entrypoint
RUN chmod +x /restore-entrypoint
COPY scripts/smart-entrypoint.sh /smart-entrypoint.sh
RUN chmod +x /smart-entrypoint.sh
# Set environment variables
ENV DYN_CHECKPOINT_PATH=/checkpoints \
RESTORE_TRIGGER=/tmp/restore-trigger \
RESTORE_WAIT_TIMEOUT=300 \
CRIU_LOG_LEVEL=4 \
WAIT_FOR_CHECKPOINT=0 \
CUDA_PLUGIN_DIR=/usr/local/lib/criu \
DEBUG=0
ENTRYPOINT ["/smart-entrypoint.sh"]
ENTRYPOINT ["/restore-entrypoint"]
CMD []
// config.go provides configuration loading for the checkpoint agent.
package main
import (
	"errors"
	"fmt"
	"os"

	"gopkg.in/yaml.v3"

	"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// ConfigMapPath is the default path where the ConfigMap is mounted.
// main passes this path to LoadConfigOrDefault when the agent starts.
const ConfigMapPath = "/etc/chrek/config.yaml"
// CheckpointSignalSource determines how checkpoint operations are triggered.
// It is a string type, so the raw AgentConfig.SignalSource value converts to it
// directly (see AgentConfig.GetSignalSource).
type CheckpointSignalSource string
const (
// SignalFromHTTP triggers checkpoints via HTTP API requests.
SignalFromHTTP CheckpointSignalSource = "http"
// SignalFromWatcher triggers checkpoints automatically when pods become Ready.
SignalFromWatcher CheckpointSignalSource = "watcher"
)
// FullConfig is the root configuration structure loaded from the ConfigMap.
// The YAML document has two top-level keys, "agent" and "checkpoint";
// FullConfig.Validate checks both sections.
type FullConfig struct {
Agent AgentConfig `yaml:"agent"`
Checkpoint checkpoint.CheckpointSpec `yaml:"checkpoint"`
}
// AgentConfig holds the runtime configuration for the checkpoint agent daemon.
// SignalSource and ListenAddr are read from YAML; NodeName and
// RestrictedNamespace carry the `yaml:"-"` tag and are filled only by
// loadEnvOverrides from the process environment.
type AgentConfig struct {
// SignalSource determines how checkpoints are triggered: "http" or "watcher"
SignalSource string `yaml:"signalSource"`
// ListenAddr is the HTTP server address for health checks and API
// (Validate requires it to be non-empty when SignalSource is "http")
ListenAddr string `yaml:"listenAddr"`
// NodeName is the Kubernetes node name (from NODE_NAME env, downward API)
NodeName string `yaml:"-"`
// RestrictedNamespace restricts pod watching to this namespace (optional)
// (from RESTRICTED_NAMESPACE env)
RestrictedNamespace string `yaml:"-"`
}
// ConfigError describes a single configuration validation failure: the
// offending field and a human-readable explanation.
type ConfigError struct {
	Field   string
	Message string
}

// Error implements the error interface, rendering the failure as
// "config error: <field>: <message>".
func (e *ConfigError) Error() string {
	field, reason := e.Field, e.Message
	return fmt.Sprintf("config error: %s: %s", field, reason)
}
// LoadConfig reads the YAML file at path, decodes it into a FullConfig and
// then applies the environment-variable overrides to the agent section.
// Read and parse failures are returned wrapped (%w) with the file path.
func LoadConfig(path string) (*FullConfig, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read config file %s: %w", path, readErr)
	}

	loaded := new(FullConfig)
	parseErr := yaml.Unmarshal(raw, loaded)
	if parseErr != nil {
		return nil, fmt.Errorf("failed to parse config file %s: %w", path, parseErr)
	}

	// Environment values are layered on top of whatever the file provided.
	loaded.Agent.loadEnvOverrides()
	return loaded, nil
}
// LoadConfigOrDefault loads configuration from a file, falling back to zero
// values (plus environment overrides) if the file doesn't exist. Any other
// failure, such as an unreadable file or invalid YAML, is returned unchanged.
func LoadConfigOrDefault(path string) (*FullConfig, error) {
	cfg, err := LoadConfig(path)
	if err == nil {
		return cfg, nil
	}
	// LoadConfig wraps the read error with %w. os.IsNotExist does not look
	// through that wrapper, so the missing-file case must be detected with
	// errors.Is, which walks the whole error chain.
	if errors.Is(err, os.ErrNotExist) {
		fallback := &FullConfig{}
		fallback.Agent.loadEnvOverrides()
		return fallback, nil
	}
	return nil, err
}
// loadEnvOverrides copies NODE_NAME and RESTRICTED_NAMESPACE from the process
// environment into the config. An unset or empty variable leaves the
// corresponding field untouched.
func (c *AgentConfig) loadEnvOverrides() {
	nodeName := os.Getenv("NODE_NAME")
	namespace := os.Getenv("RESTRICTED_NAMESPACE")

	if nodeName != "" {
		c.NodeName = nodeName
	}
	if namespace != "" {
		c.RestrictedNamespace = namespace
	}
}
// GetSignalSource converts the raw SignalSource string into the typed
// CheckpointSignalSource. No validation happens here; see Validate.
func (c *AgentConfig) GetSignalSource() CheckpointSignalSource {
	source := CheckpointSignalSource(c.SignalSource)
	return source
}
// Validate reports the first problem found in the agent configuration as a
// *ConfigError, or nil when the configuration is acceptable. SignalSource
// must be exactly "http" or "watcher", and "http" additionally requires a
// non-empty ListenAddr.
func (c *AgentConfig) Validate() error {
	switch CheckpointSignalSource(c.SignalSource) {
	case SignalFromHTTP:
		// Only the HTTP mode needs an address to listen on.
		if c.ListenAddr == "" {
			return &ConfigError{
				Field:   "listenAddr",
				Message: "cannot be empty when signalSource is 'http'",
			}
		}
	case SignalFromWatcher:
		// No additional requirements for watcher mode.
	default:
		return &ConfigError{
			Field:   "signalSource",
			Message: "must be 'http' or 'watcher'",
		}
	}
	return nil
}
// Validate checks the agent section first and the checkpoint section second,
// returning the first error encountered, or nil when both pass.
func (c *FullConfig) Validate() error {
	agentErr := c.Agent.Validate()
	if agentErr != nil {
		return agentErr
	}

	// Keep the explicit nil check rather than returning the call directly, so
	// the result does not depend on the declared return type of the checkpoint
	// validator.
	if specErr := c.Checkpoint.Validate(); specErr != nil {
		return specErr
	}
	return nil
}
......@@ -6,174 +6,39 @@ package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
httpApiServer "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/http_api_server"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/watcher"
)
// CheckpointSignalSource determines how checkpoint operations are triggered
type CheckpointSignalSource string
const (
// SignalFromHTTP triggers checkpoints via HTTP API requests
SignalFromHTTP CheckpointSignalSource = "http"
// SignalFromWatcher triggers checkpoints automatically when pods become Ready
SignalFromWatcher CheckpointSignalSource = "watcher"
)
// Config holds the agent configuration
type Config struct {
// Common settings
ContainerdSocket string
CheckpointDir string
HostProc string
NodeName string
RestrictedNamespace string // Optional: restrict pod watching to this namespace
// Mode selection
SignalSource CheckpointSignalSource // "http" or "watcher"
// HTTP API mode settings (used when SignalSource = "http")
ListenAddr string
// CRIU settings (configurable options only; LeaveRunning, ShellJob, etc. are hardcoded in pkg/checkpoint/criu.go)
CUDAPluginDir string // Path to CRIU CUDA plugin directory
GhostLimit uint32 // CRIU ghost limit in bytes
Timeout uint32 // CRIU timeout in seconds
ExternalMounts []string // External mount mappings
}
// Server is the HTTP API server
type Server struct {
config Config
discoveryClient *checkpointk8s.DiscoveryClient
checkpointer *checkpoint.Checkpointer
}
// CheckpointRequest is the request body for checkpoint operations
type CheckpointRequest struct {
ContainerID string `json:"container_id"`
CheckpointID string `json:"checkpoint_id"`
PodName string `json:"pod_name,omitempty"`
PodNamespace string `json:"pod_namespace,omitempty"`
DisableCUDA bool `json:"disable_cuda,omitempty"` // Disable CUDA plugin for non-GPU workloads
}
// TriggerRestoreRequest is the request body for Option A self-restoring trigger
type TriggerRestoreRequest struct {
CheckpointID string `json:"checkpoint_id"`
PlaceholderContainerID string `json:"placeholder_container_id"`
SkipImageValidation bool `json:"skip_image_validation,omitempty"` // Skip image matching check
}
// TriggerRestoreResponse is the response for trigger restore operations
type TriggerRestoreResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
TriggerPath string `json:"trigger_path,omitempty"`
CheckpointImage string `json:"checkpoint_image,omitempty"`
PlaceholderImage string `json:"placeholder_image,omitempty"`
}
// CheckpointResponse is the response for checkpoint operations
type CheckpointResponse struct {
Success bool `json:"success"`
CheckpointID string `json:"checkpoint_id,omitempty"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
// CheckpointInfo represents information about a checkpoint
type CheckpointInfo struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
SourceNode string `json:"source_node"`
ContainerID string `json:"container_id"`
PodName string `json:"pod_name"`
PodNamespace string `json:"pod_namespace"`
Image string `json:"image"`
}
// ListCheckpointsResponse is the response for list checkpoints
type ListCheckpointsResponse struct {
Checkpoints []CheckpointInfo `json:"checkpoints"`
}
// HealthResponse is the response for health check
type HealthResponse struct {
Status string `json:"status"`
NodeName string `json:"node_name"`
}
func main() {
// Parse signal source - default to HTTP for backward compatibility
signalSource := CheckpointSignalSource(strings.ToLower(getEnv("CHECKPOINT_SIGNAL_FROM", "http")))
if signalSource != SignalFromHTTP && signalSource != SignalFromWatcher {
log.Fatalf("Invalid CHECKPOINT_SIGNAL_FROM value: %q (must be 'http' or 'watcher')", signalSource)
}
// Parse CRIU settings
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
}
// Parse external mounts (comma-separated)
var externalMounts []string
if v := os.Getenv("EXTERNAL_MOUNTS"); v != "" {
externalMounts = strings.Split(v, ",")
// Load configuration from ConfigMap (or use defaults if not found)
cfg, err := LoadConfigOrDefault(ConfigMapPath)
if err != nil {
log.Fatalf("Failed to load configuration: %v", err)
}
config := Config{
// Common settings
ContainerdSocket: getEnv("CONTAINERD_SOCKET", "/run/containerd/containerd.sock"),
CheckpointDir: getEnv("CHECKPOINT_DIR", "/checkpoints"),
HostProc: getEnv("HOST_PROC", "/host/proc"),
NodeName: getEnv("NODE_NAME", "unknown"),
RestrictedNamespace: os.Getenv("RESTRICTED_NAMESPACE"), // Optional: empty = cluster-wide watching
// Mode selection
SignalSource: signalSource,
// HTTP API settings
ListenAddr: getEnv("LISTEN_ADDR", ":8080"),
// CRIU settings
CUDAPluginDir: getEnv("CUDA_PLUGIN_DIR", ""),
GhostLimit: ghostLimit,
Timeout: timeout,
ExternalMounts: externalMounts,
// Validate configuration
if err := cfg.Agent.Validate(); err != nil {
log.Fatalf("Invalid configuration: %v", err)
}
// Create discovery client
discoveryClient, err := checkpointk8s.NewDiscoveryClient(config.ContainerdSocket)
discoveryClient, err := checkpoint.NewDiscoveryClient()
if err != nil {
log.Fatalf("Failed to create discovery client: %v", err)
}
defer discoveryClient.Close()
// Create checkpointer
checkpointer := checkpoint.NewCheckpointer(discoveryClient, config.HostProc)
checkpointer := checkpoint.NewCheckpointer(discoveryClient)
// Context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
......@@ -183,60 +48,39 @@ func main() {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Printf("CRIU Node Agent starting (node: %s)", config.NodeName)
log.Printf("Checkpoint directory: %s", config.CheckpointDir)
log.Printf("Signal source: %s", config.SignalSource)
log.Printf("CRIU Node Agent starting (node: %s)", cfg.Agent.NodeName)
log.Printf("Checkpoint directory: %s", cfg.Checkpoint.BasePath)
log.Printf("Signal source: %s", cfg.Agent.SignalSource)
switch config.SignalSource {
switch cfg.Agent.GetSignalSource() {
case SignalFromHTTP:
server := &Server{
config: config,
discoveryClient: discoveryClient,
checkpointer: checkpointer,
}
// Setup routes
mux := http.NewServeMux()
mux.HandleFunc("/health", server.handleHealth)
mux.HandleFunc("/checkpoint", server.handleCheckpoint)
mux.HandleFunc("/restore/trigger", server.handleTriggerRestore)
mux.HandleFunc("/checkpoints", server.handleListCheckpoints)
httpServer := &http.Server{
Addr: config.ListenAddr,
Handler: loggingMiddleware(mux),
ReadTimeout: 30 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
serverCfg := httpApiServer.ServerConfig{
ListenAddr: cfg.Agent.ListenAddr,
NodeName: cfg.Agent.NodeName,
CheckpointSpec: &cfg.Checkpoint,
}
srv := httpApiServer.NewServer(serverCfg, checkpointer)
// Handle graceful shutdown
go func() {
<-sigChan
log.Println("Shutting down HTTP server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err)
}
}()
log.Printf("HTTP API server listening on %s", config.ListenAddr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
if err := srv.Start(); err != http.ErrServerClosed {
log.Fatalf("HTTP server error: %v", err)
}
case SignalFromWatcher:
watcherConfig := watcher.Config{
NodeName: config.NodeName,
CheckpointDir: config.CheckpointDir,
HostProc: config.HostProc,
ListenAddr: config.ListenAddr, // For health check endpoint
RestrictedNamespace: config.RestrictedNamespace,
CUDAPluginDir: config.CUDAPluginDir,
GhostLimit: config.GhostLimit,
Timeout: config.Timeout,
ExternalMounts: config.ExternalMounts,
watcherConfig := watcher.WatcherConfig{
NodeName: cfg.Agent.NodeName,
ListenAddr: cfg.Agent.ListenAddr,
RestrictedNamespace: cfg.Agent.RestrictedNamespace,
CheckpointSpec: &cfg.Checkpoint,
}
podWatcher, err := watcher.NewWatcher(watcherConfig, discoveryClient, checkpointer)
......@@ -251,304 +95,15 @@ func main() {
cancel()
}()
log.Printf("Pod watcher started (watching for label: nvidia.com/checkpoint-source=true)")
log.Printf("Health check endpoint: http://0.0.0.0%s/health", config.ListenAddr)
log.Printf("Pod watcher started (watching for label: %s=true)", checkpoint.KubeLabelCheckpointSource)
log.Printf("Health check endpoint: http://0.0.0.0%s/health", cfg.Agent.ListenAddr)
if err := podWatcher.Start(ctx); err != nil {
log.Printf("Pod watcher error: %v", err)
}
}
log.Println("Agent stopped")
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
resp := HealthResponse{
Status: "healthy",
NodeName: s.config.NodeName,
}
writeJSON(w, http.StatusOK, resp)
}
func (s *Server) handleCheckpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req CheckpointRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: fmt.Sprintf("Invalid request body: %v", err),
})
return
}
if req.ContainerID == "" {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: "container_id is required",
})
return
}
if req.CheckpointID == "" {
req.CheckpointID = fmt.Sprintf("ckpt-%d", time.Now().UnixNano())
}
// Determine CUDA plugin directory - only use if not explicitly disabled
cudaPluginDir := s.config.CUDAPluginDir
if req.DisableCUDA {
cudaPluginDir = ""
}
// Parse optional CRIU settings from environment
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
default:
log.Fatalf("Unknown signal source: %s", cfg.Agent.SignalSource)
}
opts := checkpoint.Options{
ContainerID: req.ContainerID,
CheckpointID: req.CheckpointID,
CheckpointDir: s.config.CheckpointDir,
NodeName: s.config.NodeName,
PodName: req.PodName,
PodNamespace: req.PodNamespace,
GhostLimit: ghostLimit,
Timeout: timeout,
CUDAPluginDir: cudaPluginDir,
}
ctx := r.Context()
result, err := s.checkpointer.Checkpoint(ctx, opts)
if err != nil {
log.Printf("Checkpoint failed: %v", err)
writeJSON(w, http.StatusInternalServerError, CheckpointResponse{
Success: false,
Error: err.Error(),
})
return
}
log.Printf("Checkpoint successful: %s", result.CheckpointID)
writeJSON(w, http.StatusOK, CheckpointResponse{
Success: true,
CheckpointID: result.CheckpointID,
Message: fmt.Sprintf("Checkpoint created successfully at %s", result.CheckpointDir),
})
}
// handleTriggerRestore implements Option A from RESTORE_ANALYSIS.md
// It triggers a self-restoring placeholder container to start CRIU restore.
// The agent writes a trigger file to the placeholder's filesystem, which
// the placeholder's entrypoint script detects and uses to start restoration.
func (s *Server) handleTriggerRestore(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req TriggerRestoreRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Invalid request body: %v", err),
})
return
}
if req.CheckpointID == "" {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: "checkpoint_id is required",
})
return
}
if req.PlaceholderContainerID == "" {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: "placeholder_container_id is required",
})
return
}
// Verify checkpoint exists and load metadata
checkpointPath := common.GetCheckpointDir(s.config.CheckpointDir, req.CheckpointID)
checkpointMeta, err := common.LoadMetadata(checkpointPath)
if err != nil {
writeJSON(w, http.StatusNotFound, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Checkpoint not found: %v", err),
})
return
}
// Resolve placeholder container to get PID and image
ctx := r.Context()
containerInfo, err := s.discoveryClient.ResolveContainer(ctx, req.PlaceholderContainerID)
if err != nil {
writeJSON(w, http.StatusInternalServerError, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Failed to resolve placeholder container: %v", err),
})
return
}
// Validate that placeholder image matches checkpoint's original image
// This is critical because rootfs-diff.tar only contains upperdir modifications,
// not the base image layers (lowerdir). If images differ, CRIU restore will fail.
if !req.SkipImageValidation && checkpointMeta.Image != "" {
if !imagesCompatible(checkpointMeta.Image, containerInfo.Image) {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Image mismatch: checkpoint was from '%s' but placeholder uses '%s'. The placeholder must use the same base image. Use skip_image_validation=true to override.", checkpointMeta.Image, containerInfo.Image),
CheckpointImage: checkpointMeta.Image,
PlaceholderImage: containerInfo.Image,
})
return
}
log.Printf("Image validation passed: checkpoint=%s, placeholder=%s", checkpointMeta.Image, containerInfo.Image)
}
// Write trigger file to placeholder's filesystem via /proc/<pid>/root
// The trigger file contains the checkpoint path
triggerPath := fmt.Sprintf("%s/%d/root/tmp/restore-trigger", s.config.HostProc, containerInfo.PID)
// Write the checkpoint path to the trigger file
if err := os.WriteFile(triggerPath, []byte(checkpointPath), 0644); err != nil {
writeJSON(w, http.StatusInternalServerError, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Failed to write trigger file: %v", err),
})
return
}
log.Printf("Triggered restore for placeholder %s (PID %d) from checkpoint %s",
req.PlaceholderContainerID, containerInfo.PID, req.CheckpointID)
writeJSON(w, http.StatusOK, TriggerRestoreResponse{
Success: true,
Message: fmt.Sprintf("Restore triggered for checkpoint %s", req.CheckpointID),
TriggerPath: triggerPath,
CheckpointImage: checkpointMeta.Image,
PlaceholderImage: containerInfo.Image,
})
}
// handleListCheckpoints serves GET requests with the metadata of every
// checkpoint found under the configured checkpoint directory.
//
// Any other HTTP method is answered with 405. If the checkpoint directory
// cannot be enumerated the handler answers 500 with a JSON {"error": ...}
// body. Checkpoints whose metadata cannot be loaded are skipped, and the
// failure is logged so a broken entry is visible to operators instead of
// silently disappearing from the listing.
func (s *Server) handleListCheckpoints(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	checkpointIDs, err := common.ListCheckpoints(s.config.CheckpointDir)
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{
			"error": err.Error(),
		})
		return
	}

	// Start from an empty, non-nil slice: encoding/json renders a nil slice
	// as null, which would turn "no checkpoints" into a null list.
	checkpoints := make([]CheckpointInfo, 0, len(checkpointIDs))
	for _, id := range checkpointIDs {
		meta, err := common.GetCheckpointInfo(s.config.CheckpointDir, id)
		if err != nil {
			// Best effort: one unreadable checkpoint must not hide the others.
			log.Printf("Skipping checkpoint %v: failed to load metadata: %v", id, err)
			continue
		}
		checkpoints = append(checkpoints, CheckpointInfo{
			ID:           meta.CheckpointID,
			CreatedAt:    meta.CreatedAt,
			SourceNode:   meta.SourceNode,
			ContainerID:  meta.ContainerID,
			PodName:      meta.PodName,
			PodNamespace: meta.PodNamespace,
			Image:        meta.Image,
		})
	}

	writeJSON(w, http.StatusOK, ListCheckpointsResponse{
		Checkpoints: checkpoints,
	})
}
// writeJSON sends data as a JSON response body with the given HTTP status.
//
// The Content-Type header and the status line are written before the body is
// encoded, so an encoding or write failure can no longer be reported to the
// client; it is logged instead of being silently discarded.
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("Failed to write JSON response (status %d): %v", status, err)
	}
}
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
log.Printf("Started %s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
log.Printf("Completed %s %s in %v", r.Method, r.URL.Path, time.Since(start))
})
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// imagesCompatible reports whether placeholderImage can host a CRIU restore of
// a checkpoint that was taken from checkpointImage.
//
// The two image references are considered compatible when any of these holds:
//   - they are identical strings (nginx:alpine vs nginx:alpine);
//   - they are identical after a leading "docker.io/library/" or "docker.io/"
//     has been stripped (docker.io/library/nginx:alpine vs nginx:alpine);
//   - the placeholder follows the criu-placeholder-<image> naming convention,
//     where <image> is the normalized checkpoint image with ':' and '/'
//     replaced by '-', optionally followed by ":<tag>"
//     (criu-placeholder-nginx-alpine[:tag] vs nginx:alpine).
//
// Anything else is reported as incompatible.
func imagesCompatible(checkpointImage, placeholderImage string) bool {
	// Exact match
	if checkpointImage == placeholderImage {
		return true
	}

	// Normalize images by removing common registry prefixes. The longer
	// prefix is removed first so "docker.io/library/x" does not become
	// "library/x".
	normalize := func(img string) string {
		img = strings.TrimPrefix(img, "docker.io/library/")
		img = strings.TrimPrefix(img, "docker.io/")
		return img
	}
	normalizedCheckpoint := normalize(checkpointImage)
	normalizedPlaceholder := normalize(placeholderImage)
	if normalizedCheckpoint == normalizedPlaceholder {
		return true
	}

	// Check if placeholder follows criu-placeholder-<image> naming convention
	// e.g., criu-placeholder-nginx-alpine should match nginx:alpine
	if strings.HasPrefix(normalizedPlaceholder, "criu-placeholder-") {
		// Convert nginx:alpine to nginx-alpine for comparison
		checkpointSanitized := strings.ReplaceAll(normalizedCheckpoint, ":", "-")
		checkpointSanitized = strings.ReplaceAll(checkpointSanitized, "/", "-")
		expectedPlaceholder := "criu-placeholder-" + checkpointSanitized
		// The placeholder may carry its own tag after the expected name.
		if normalizedPlaceholder == expectedPlaceholder ||
			strings.HasPrefix(normalizedPlaceholder, expectedPlaceholder+":") {
			return true
		}
	}

	return false
}
......@@ -5,8 +5,12 @@ package main
import (
"context"
"fmt"
"os"
"os/exec"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/sirupsen/logrus"
......@@ -14,7 +18,50 @@ import (
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/restore"
)
// logGPUDiagnostics prints a GPU-visibility report to stdout for debugging.
// The report is bracketed by banner lines carrying label and contains, in
// order: the output of two nvidia-smi invocations, the /dev/nvidia* device
// nodes, the NVIDIA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES variables, and
// the namespace links of PID 1. Failures are printed inline; nothing is
// returned.
func logGPUDiagnostics(label string) {
	fmt.Printf("=== GPU DIAGNOSTICS [%s] ===\n", label)

	// One entry per nvidia-smi invocation. A failing probe prints its error
	// and the remaining probes still run.
	probes := []struct {
		args    []string
		failFmt string
		okFmt   string
	}{
		{
			// nvidia-smi
			args:    []string{"-L"},
			failFmt: "nvidia-smi -L: error: %v\n",
			okFmt:   "nvidia-smi -L:\n%s",
		},
		{
			// GPU memory usage
			args:    []string{"--query-gpu=index,uuid,memory.used,memory.total,memory.free", "--format=csv,noheader"},
			failFmt: "nvidia-smi memory query: error: %v\n",
			okFmt:   "nvidia-smi memory:\n%s",
		},
	}
	for _, probe := range probes {
		output, err := exec.Command("nvidia-smi", probe.args...).CombinedOutput()
		if err != nil {
			fmt.Printf(probe.failFmt, err)
			continue
		}
		fmt.Printf(probe.okFmt, output)
	}

	// /dev/nvidia* devices (the glob error is deliberately ignored)
	devices, _ := filepath.Glob("/dev/nvidia*")
	deviceList := strings.Join(devices, ", ")
	fmt.Printf("/dev/nvidia* devices: %s\n", deviceList)

	// Device-visibility environment variables
	fmt.Printf("NVIDIA_VISIBLE_DEVICES=%s\n", os.Getenv("NVIDIA_VISIBLE_DEVICES"))
	fmt.Printf("CUDA_VISIBLE_DEVICES=%s\n", os.Getenv("CUDA_VISIBLE_DEVICES"))

	// Linux namespaces for PID 1; an unreadable link prints the error text
	// in place of the link target.
	nsNames := []string{"mnt", "pid", "ipc", "net", "uts", "cgroup"}
	for _, ns := range nsNames {
		target, err := os.Readlink(fmt.Sprintf("/proc/1/ns/%s", ns))
		if err != nil {
			target = err.Error()
		}
		fmt.Printf("ns/%s: %s\n", ns, target)
	}

	fmt.Printf("=== END GPU DIAGNOSTICS [%s] ===\n", label)
}
func main() {
// Log GPU diagnostics BEFORE anything else (gated on DEBUG for production quietness)
if os.Getenv("DEBUG") == "1" {
logGPUDiagnostics("PRE-RESTORE")
}
// Set up logging
log := logrus.New()
log.SetOutput(os.Stdout)
......@@ -23,8 +70,12 @@ func main() {
TimestampFormat: "2006-01-02 15:04:05",
})
// Load configuration from environment
cfg := restore.ConfigFromEnv()
// Load configuration from hardcoded defaults + operator-injected env vars.
// os.Args[1:] are the cold start command args (passed by the operator via pod spec).
cfg, err := restore.NewRestoreRequest(os.Args[1:])
if err != nil {
log.WithError(err).Fatal("Failed to load restore configuration")
}
// Set log level based on DEBUG flag
if cfg.Debug {
......
......@@ -6,19 +6,57 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
criu "github.com/checkpoint-restore/go-criu/v7"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// Options configures the checkpoint operation
type Options struct {
// ContainerInfoSnapshot holds runtime/container info needed for checkpointing.
// A snapshot is assembled once per checkpoint by the introspection step and is
// then read when CRIU dump options and the checkpoint manifest are built and
// when post-dump artifacts are captured.
type ContainerInfoSnapshot struct {
// PID is the process ID obtained from resolving the container ID.
PID int
// RootFS is the container root filesystem path as returned by GetRootFS;
// it is passed to CRIU as the dump root.
RootFS string
// UpperDir is the overlay upperdir as returned by GetOverlayUpperDir;
// it is the input for the rootfs diff capture.
UpperDir string
// OCISpec is the OCI runtime spec returned alongside the PID when the
// container is resolved.
OCISpec *specs.Spec
// MountInfo is the parsed mountinfo of the container process.
MountInfo []MountInfo
// Namespaces maps each namespace type to the namespace info gathered for
// the container process.
Namespaces map[NamespaceType]*NamespaceInfo
}
// CheckpointManifest is saved as manifest.yaml at checkpoint time and loaded at restore.
// It is composed of one section per module; each section is produced by that
// module's own builder and combined by NewCheckpointManifest.
type CheckpointManifest struct {
// CheckpointID identifies the checkpoint this manifest belongs to.
CheckpointID string `yaml:"checkpointId"`
// CreatedAt is the manifest creation time; NewCheckpointManifest sets it
// to the current time in UTC.
CreatedAt time.Time `yaml:"createdAt"`
// CRIUDump records the CRIU settings and resolved mount plan of the dump.
CRIUDump CRIUDumpManifest `yaml:"criuDump"`
// K8s describes the source pod — presumably pod/container identity taken
// from the checkpoint request; confirm against SourcePodManifest.
K8s SourcePodManifest `yaml:"k8s"`
// Filesystem is the filesystem section of the manifest (built from the
// rootfs exclusions, the overlay upperdir and the OCI spec).
Filesystem FilesystemManifest `yaml:"filesystem"`
// Namespaces lists the manifest entries derived from the container's namespaces.
Namespaces []NamespaceManifestEntry `yaml:"namespaces"`
}
// NewCheckpointManifest combines the per-module manifest sections into a new
// CheckpointManifest. The sections are stored as given; the only value
// computed here is CreatedAt, which is stamped with the current UTC time.
func NewCheckpointManifest(
	checkpointID string,
	criuDump CRIUDumpManifest,
	k8s SourcePodManifest,
	filesystem FilesystemManifest,
	namespaces []NamespaceManifestEntry,
) *CheckpointManifest {
	manifest := new(CheckpointManifest)

	manifest.CheckpointID = checkpointID
	manifest.CreatedAt = time.Now().UTC()

	// Module-provided sections, copied verbatim.
	manifest.CRIUDump = criuDump
	manifest.K8s = k8s
	manifest.Filesystem = filesystem
	manifest.Namespaces = namespaces

	return manifest
}
// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
ContainerID string
ContainerName string // K8s container name (for K8s API volume type lookup)
CheckpointID string
......@@ -26,237 +64,180 @@ type Options struct {
NodeName string
PodName string
PodNamespace string
// CRIU options (from environment variables)
GhostLimit uint32 // From CRIU_GHOST_LIMIT: ghost file size limit in bytes (0 = CRIU default)
Timeout uint32 // From CRIU_TIMEOUT: timeout in seconds (0 = no timeout)
// GPU/CUDA checkpoint options
CUDAPluginDir string // Path to CRIU CUDA plugin (e.g., /home/mmshin/work/criu/plugins/cuda)
ExternalMounts []string // Additional external mount mappings (e.g., "mnt[path]:path")
}
// Result contains the result of a checkpoint operation
type Result struct {
// CheckpointOutcome contains the result of a checkpoint operation.
type CheckpointOutcome struct {
CheckpointID string
CheckpointDir string
Metadata *common.CheckpointMetadata
Data *CheckpointManifest
}
// Checkpointer performs CRIU checkpoint operations
type Checkpointer struct {
discoveryClient *checkpointk8s.DiscoveryClient
k8sClient *checkpointk8s.K8sClient // Optional: for accurate volume type discovery from K8s API
hostProc string
discoveryClient *DiscoveryClient
log *logrus.Entry
}
// NewCheckpointer creates a new checkpointer
func NewCheckpointer(discoveryClient *checkpointk8s.DiscoveryClient, hostProc string) *Checkpointer {
if hostProc == "" {
hostProc = os.Getenv("HOST_PROC")
if hostProc == "" {
hostProc = "/proc"
}
}
func NewCheckpointer(discoveryClient *DiscoveryClient) *Checkpointer {
return &Checkpointer{
discoveryClient: discoveryClient,
hostProc: hostProc,
log: logrus.WithField("component", "checkpointer"),
}
}
// WithK8sClient sets an optional Kubernetes client for accurate volume type discovery.
// When set, volume types are fetched from the K8s API instead of being inferred from paths.
//
// The client is stored on the receiver itself (no copy is made) and the same
// receiver is returned, so the call can be chained onto a constructor.
func (c *Checkpointer) WithK8sClient(client *checkpointk8s.K8sClient) *Checkpointer {
c.k8sClient = client
return c
}
// Checkpoint performs a CRIU dump of a container
func (c *Checkpointer) Checkpoint(ctx context.Context, opts Options) (*Result, error) {
// Checkpoint performs a CRIU dump of a container.
// The operation has three phases: introspect, configure, capture.
func (c *Checkpointer) Checkpoint(ctx context.Context, req CheckpointRequest, spec *CheckpointSpec) (*CheckpointOutcome, error) {
if spec == nil {
return nil, fmt.Errorf("checkpoint spec is required")
}
checkpointStart := time.Now()
c.log.Info("=== Starting checkpoint operation ===")
// 1. Resolve container to get PID
resolveStart := time.Now()
containerInfo, err := c.discoveryClient.ResolveContainer(ctx, opts.ContainerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
}
pid := int(containerInfo.PID)
c.log.WithField("duration", time.Since(resolveStart)).Info("Container resolution completed")
// 2. Create checkpoint directory
checkpointDir := common.GetCheckpointDir(opts.CheckpointDir, opts.CheckpointID)
checkpointDir := filepath.Join(req.CheckpointDir, req.CheckpointID)
if err := os.MkdirAll(checkpointDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create checkpoint directory: %w", err)
}
// 3. Introspect container state
introspectStart := time.Now()
rootFS, err := GetRootFS(pid, c.hostProc)
// Open image directory FD for CRIU — must stay open through both configure and capture
// phases since CRIU's swrk child process inherits this FD.
imageDir, imageDirFD, err := common.OpenPathForCRIU(checkpointDir)
if err != nil {
return nil, fmt.Errorf("failed to get rootfs: %w", err)
return nil, fmt.Errorf("failed to open image directory: %w", err)
}
mounts, err := GetKubernetesVolumeMounts(pid, c.hostProc)
defer imageDir.Close()
// Phase 1: Introspect container state
state, err := c.introspect(ctx, req.ContainerID)
if err != nil {
return nil, fmt.Errorf("failed to get mounts: %w", err)
return nil, err
}
namespaces, err := GetAllNamespaces(pid, c.hostProc)
// Phase 2: Configure CRIU options and build checkpoint manifest.
criuOpts, data, err := c.configure(state, req, spec, checkpointDir, imageDirFD)
if err != nil {
return nil, fmt.Errorf("failed to get namespaces: %w", err)
return nil, err
}
c.log.WithField("duration", time.Since(introspectStart)).Info("Container introspection completed")
// 4. Open image directory FD
imageDir, imageDirFD, err := OpenImageDir(checkpointDir)
// Phase 3: Capture — CRIU dump, /dev/shm, rootfs diff
criuDumpDuration, err := c.capture(criuOpts, data, state, checkpointDir)
if err != nil {
return nil, err
}
defer imageDir.Close()
// 5. Build CRIU options
criuOpts := BuildCRIUOptsFromCheckpointOpts(opts, pid, imageDirFD, rootFS)
totalDuration := time.Since(checkpointStart)
c.log.WithFields(logrus.Fields{
"total_duration": totalDuration,
"criu_dump_duration": criuDumpDuration,
}).Info("=== Checkpoint operation completed ===")
// 6. Create CRIU config file for CUDA plugin (libdir is not available via RPC)
if opts.CUDAPluginDir != "" {
if opts.Timeout == 0 {
return nil, fmt.Errorf("CRIU_TIMEOUT environment variable must be set for CUDA checkpoints")
}
configPath := filepath.Join(checkpointDir, "criu.conf")
configContent := fmt.Sprintf(`enable-external-masters
libdir %s
tcp-close
link-remap
timeout %d
allow-uprobes
skip-in-flight
`, opts.CUDAPluginDir, opts.Timeout)
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
return nil, fmt.Errorf("failed to write CRIU config file: %w", err)
}
criuOpts.ConfigFile = proto.String(configPath)
c.log.WithFields(logrus.Fields{
"config_path": configPath,
"plugin_dir": opts.CUDAPluginDir,
}).Info("Created CRIU config file for CUDA plugin")
}
return &CheckpointOutcome{
CheckpointID: req.CheckpointID,
CheckpointDir: checkpointDir,
Data: data,
}, nil
}
// 7. Configure external mounts and namespaces
if err := ConfigureExternalMounts(criuOpts, pid, c.hostProc, containerInfo); err != nil {
return nil, err
}
netNsInode := ConfigureExternalNamespaces(criuOpts, namespaces, opts.ExternalMounts)
if netNsInode > 0 {
c.log.WithField("inode", netNsInode).Debug("Marked network namespace as external")
}
for _, extMount := range opts.ExternalMounts {
c.log.WithField("external", extMount).Debug("Added external mount mapping")
// introspect resolves the container and gathers all runtime state from containerd and /proc.
func (c *Checkpointer) introspect(ctx context.Context, containerID string) (*ContainerInfoSnapshot, error) {
pid, ociSpec, err := c.discoveryClient.ResolveContainer(ctx, containerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
}
// 8. Get overlay upperdir for rootfs diff capture
upperDir, upperDirErr := GetOverlayUpperDir(pid, c.hostProc)
if upperDirErr != nil {
c.log.WithError(upperDirErr).Warn("Could not get overlay upperdir - rootfs diff will not be captured")
} else {
c.log.WithField("upperdir", upperDir).Debug("Found overlay upperdir")
rootFS, err := GetRootFS(pid)
if err != nil {
return nil, fmt.Errorf("failed to get rootfs: %w", err)
}
// 9. Build and save initial metadata before dump
metaCfg := MetadataBuilderConfig{
CheckpointID: opts.CheckpointID,
NodeName: opts.NodeName,
ContainerID: opts.ContainerID,
ContainerName: opts.ContainerName,
PodName: opts.PodName,
PodNamespace: opts.PodNamespace,
PID: pid,
CUDAPluginDir: opts.CUDAPluginDir,
upperDir, err := GetOverlayUpperDir(pid)
if err != nil {
return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
}
meta := BuildCheckpointMetadata(ctx, metaCfg, containerInfo, mounts, namespaces, c.k8sClient, c.log)
if upperDir != "" {
meta.UpperDir = upperDir
mountInfo, err := ReadMountInfoFromHostProcPath(pid)
if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
}
if err := common.SaveMetadata(checkpointDir, meta); err != nil {
return nil, fmt.Errorf("failed to save metadata: %w", err)
namespaces, err := GetAllNamespaces(pid)
if err != nil {
return nil, fmt.Errorf("failed to get namespaces: %w", err)
}
// 10. Remove semaphores from /dev/shm before checkpoint
// Semaphores cause CRIU restore to fail with "Can't link dev/shm/link_remap.X -> dev/shm/sem.Y"
if err := c.removeSemaphores(pid); err != nil {
return nil, fmt.Errorf("failed to remove semaphores: %w", err)
return &ContainerInfoSnapshot{
PID: pid,
RootFS: rootFS,
UpperDir: upperDir,
OCISpec: ociSpec,
MountInfo: mountInfo,
Namespaces: namespaces,
}, nil
}
// configure builds CRIU options and checkpoint manifest from runtime snapshot and spec.
func (c *Checkpointer) configure(
state *ContainerInfoSnapshot,
req CheckpointRequest,
spec *CheckpointSpec,
checkpointDir string,
imageDirFD int32,
) (*criurpc.CriuOpts, *CheckpointManifest, error) {
criuOpts, err := BuildCRIUDumpOptions(
&spec.CRIU,
state.PID,
imageDirFD,
state.RootFS,
state.MountInfo,
state.OCISpec,
state.Namespaces,
)
if err != nil {
return nil, nil, err
}
// 11. Execute CRIU dump via go-criu
criuDumpStart := time.Now()
criuClient := criu.MakeCriu()
if err := criuClient.Dump(criuOpts, nil); err != nil {
c.log.WithField("duration", time.Since(criuDumpStart)).Error("CRIU dump failed")
return nil, fmt.Errorf("CRIU dump failed: %w", err)
// Write CRIU config file (for options unavailable via RPC)
configPath := filepath.Join(checkpointDir, CheckpointCRIUConfFilename)
if err := os.WriteFile(configPath, []byte(spec.CRIU.GenerateCRIUConfContent()), 0644); err != nil {
return nil, nil, fmt.Errorf("failed to write CRIU config file: %w", err)
}
criuDumpDuration := time.Since(criuDumpStart)
c.log.WithField("duration", criuDumpDuration).Info("CRIU dump completed successfully")
criuOpts.ConfigFile = proto.String(configPath)
// 12. Capture rootfs diff and deleted files
rootfsCaptureStart := time.Now()
CaptureRootfsState(upperDir, checkpointDir, meta, c.log)
c.log.WithField("duration", time.Since(rootfsCaptureStart)).Info("Rootfs capture completed")
// Build and save the checkpoint manifest.
manifest := NewCheckpointManifest(
req.CheckpointID,
NewCRIUDumpManifest(criuOpts, spec.CRIU),
NewSourcePodManifest(req, state.PID),
NewFilesystemManifest(spec.RootfsExclusions, state.UpperDir, state.OCISpec),
NewNamespaceManifestEntries(state.Namespaces),
)
totalDuration := time.Since(checkpointStart)
c.log.WithFields(logrus.Fields{
"total_duration": totalDuration,
"criu_dump_duration": criuDumpDuration,
}).Info("=== Checkpoint operation completed ===")
if err := WriteCheckpointManifest(checkpointDir, manifest); err != nil {
return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", err)
}
return &Result{
CheckpointID: opts.CheckpointID,
CheckpointDir: checkpointDir,
Metadata: meta,
}, nil
return criuOpts, manifest, nil
}
// removeSemaphores removes POSIX semaphores from the container's /dev/shm.
// Semaphores can cause issues during CRIU checkpoint/restore because they
// maintain kernel state that may not transfer correctly between processes.
// This accesses the container's filesystem via /proc/<pid>/root/dev/shm/.
func (c *Checkpointer) removeSemaphores(pid int) error {
shmPath := filepath.Join(c.hostProc, fmt.Sprintf("%d/root/dev/shm", pid))
entries, err := os.ReadDir(shmPath)
// capture executes the CRIU dump and post-dump captures (/dev/shm, rootfs diff).
// Returns the CRIU dump duration for timing reporting.
func (c *Checkpointer) capture(
criuOpts *criurpc.CriuOpts,
data *CheckpointManifest,
state *ContainerInfoSnapshot,
checkpointDir string,
) (time.Duration, error) {
criuDumpDuration, err := ExecuteCRIUDump(criuOpts, checkpointDir, c.log)
if err != nil {
// It's okay if /dev/shm doesn't exist (container may not have it)
c.log.WithError(err).Debug("Could not read container /dev/shm (may not exist)")
return nil
return 0, err
}
var removed []string
var errors []error
for _, entry := range entries {
name := entry.Name()
if strings.HasPrefix(name, "sem.") {
semPath := filepath.Join(shmPath, name)
if err := os.Remove(semPath); err != nil {
c.log.WithError(err).WithField("semaphore", name).Error("Failed to remove semaphore")
errors = append(errors, fmt.Errorf("failed to remove semaphore %s: %w", name, err))
} else {
removed = append(removed, name)
}
}
// Capture /dev/shm contents (must happen after dump for final process state)
if err := CaptureDevShm(state.PID, checkpointDir, c.log); err != nil {
c.log.WithError(err).Warn("Failed to capture /dev/shm contents")
}
if len(errors) > 0 {
return fmt.Errorf("failed to remove %d semaphore(s): %v", len(errors), errors)
}
if len(removed) > 0 {
c.log.WithFields(logrus.Fields{
"count": len(removed),
"semaphores": removed,
}).Info("Removed semaphores from container /dev/shm before checkpoint")
} else {
c.log.Debug("No semaphores found in container /dev/shm")
}
// Capture rootfs diff and deleted files
CaptureRootfsState(state.UpperDir, checkpointDir, data, c.log)
return nil
return criuDumpDuration, nil
}
// config.go defines the static checkpoint spec loaded from ConfigMap YAML.
package checkpoint
import "fmt"
// CheckpointSpec is the static checkpoint spec loaded from ConfigMap YAML.
// It carries settings that do not change between checkpoints; per-checkpoint
// identifiers are supplied separately by the caller.
type CheckpointSpec struct {
// BasePath is the base directory for checkpoint storage (PVC mount point).
BasePath string `yaml:"basePath"`
// CRIU options for dump operations
// (covers both the options sent over RPC and those written to the CRIU
// config file).
CRIU CRIUSettings `yaml:"criu"`
// RootfsExclusions defines paths to exclude from rootfs diff capture
// This is the only section checked by Validate.
RootfsExclusions FilesystemConfig `yaml:"rootfsExclusions"`
}
// Validate checks that the CheckpointSpec has valid values.
//
// Only RootfsExclusions is validated; whatever its Validate method returns is
// passed through unchanged. A nil receiver is reported as a *ConfigError
// rather than causing a nil-pointer dereference, so callers that load an
// optional spec can validate it without a separate nil check.
func (c *CheckpointSpec) Validate() error {
	if c == nil {
		return &ConfigError{Field: "spec", Message: "checkpoint spec is nil"}
	}
	return c.RootfsExclusions.Validate()
}
// ConfigError is the error type for a configuration value that failed
// validation.
type ConfigError struct {
	// Field is the name of the offending configuration field.
	Field string
	// Message describes what is wrong with that field.
	Message string
}

// Error implements the error interface. The text has the form
// "config error: <Field>: <Message>".
func (e *ConfigError) Error() string {
	const layout = "config error: %s: %s"
	return fmt.Sprintf(layout, e.Field, e.Message)
}
// constants.go defines shared constants used across checkpoint and restore packages.
package checkpoint
const (
// HostProcPath is the mount point for the host's /proc in DaemonSet pods.
HostProcPath = "/host/proc"
// DevShmDirName is the directory name for captured /dev/shm contents.
DevShmDirName = "dev-shm"
// KubeLabelCheckpointSource is the pod label that triggers automatic checkpointing.
// Set by the operator on checkpoint-eligible pods.
KubeLabelCheckpointSource = "nvidia.com/checkpoint-source"
// KubeLabelCheckpointHash is the pod label specifying the checkpoint identity hash.
// Set by the operator on checkpoint-eligible pods.
KubeLabelCheckpointHash = "nvidia.com/checkpoint-hash"
// DumpLogFilename is the CRIU dump (checkpoint) log filename.
// It is handed to CRIU as the log file option when dump options are built.
DumpLogFilename = "dump.log"
// CheckpointCRIUConfFilename is the CRIU config file written at checkpoint time.
// The file is created inside the checkpoint directory and carries the
// options that cannot be passed to CRIU over RPC.
CheckpointCRIUConfFilename = "criu.conf"
// CheckpointDoneFilename is the marker file written to the checkpoint directory
// after all checkpoint artifacts are complete. Used to detect checkpoint readiness.
// Also hard-coded in vLLM for early-exit when checkpoint already exists.
// NOTE(review): because the name is duplicated outside this package, renaming
// it here requires the same change on the vLLM side.
CheckpointDoneFilename = "checkpoint.done"
// CheckpointManifestFilename is the name of the manifest file in checkpoint directories.
CheckpointManifestFilename = "manifest.yaml"
// DescriptorsFilename is the name of the file descriptors file.
DescriptorsFilename = "descriptors.yaml"
// RootfsDiffFilename is the name of the rootfs diff tar in checkpoint directories.
RootfsDiffFilename = "rootfs-diff.tar"
// DeletedFilesFilename is the name of the deleted files JSON in checkpoint directories.
DeletedFilesFilename = "deleted-files.json"
)
......@@ -3,176 +3,256 @@ package checkpoint
import (
"fmt"
"os"
"time"
criu "github.com/checkpoint-restore/go-criu/v7"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// CRIUConfig holds configuration for CRIU dump operations.
// Most options are always-on with safe defaults for K8s environments.
// Only the per-dump values below vary; BuildCRIUOpts turns them into the
// CRIU RPC options.
type CRIUConfig struct {
// PID is the process to dump (sent to CRIU as the Pid option).
PID int
// ImageDirFD is the file descriptor of the image directory (sent to CRIU
// as the ImagesDirFd option).
ImageDirFD int32
// RootFS is the container root filesystem path (sent to CRIU as the Root option).
RootFS string
// GhostLimit and Timeout are only forwarded to CRIU when greater than zero.
GhostLimit uint32 // From env CRIU_GHOST_LIMIT: max ghost file size (0 = CRIU default)
Timeout uint32 // From env CRIU_TIMEOUT: checkpoint timeout in seconds (0 = no timeout)
}
// CRIUSettings holds CRIU-specific configuration options.
// Options are categorized by how they are passed to CRIU:
// - RPC options: Passed via go-criu CriuOpts protobuf
// - CRIU conf file options: Written to criu.conf (NOT available via RPC)
type CRIUSettings struct {
// === RPC Options (passed via go-criu CriuOpts) ===
// GhostLimit is the maximum ghost file size in bytes.
// Ghost files are deleted-but-open files that CRIU needs to checkpoint.
// 512MB is recommended for GPU workloads with large memory allocations.
GhostLimit uint32 `yaml:"ghostLimit"`
// Timeout is the CRIU operation timeout in seconds.
// 6 hours (21600s) is recommended for large GPU model checkpoints.
Timeout uint32 `yaml:"timeout"`
// LogLevel is the CRIU logging verbosity (0-4).
LogLevel int32 `yaml:"logLevel"`
// WorkDir is the CRIU work directory for temporary files.
WorkDir string `yaml:"workDir"`
// AutoDedup enables auto-deduplication of memory pages.
AutoDedup bool `yaml:"autoDedup"`
// LazyPages enables lazy page migration (experimental).
LazyPages bool `yaml:"lazyPages"`
// LeaveRunning keeps the process running after checkpoint (dump only).
LeaveRunning bool `yaml:"leaveRunning"`
// ShellJob allows checkpointing session leaders (containers are often session leaders).
ShellJob bool `yaml:"shellJob"`
// TcpClose closes TCP connections instead of preserving them (pod IPs change on restore).
TcpClose bool `yaml:"tcpClose"`
// FileLocks allows checkpointing processes with file locks.
FileLocks bool `yaml:"fileLocks"`
// OrphanPtsMaster allows checkpointing containers with TTYs.
OrphanPtsMaster bool `yaml:"orphanPtsMaster"`
// ExtUnixSk allows external Unix sockets.
ExtUnixSk bool `yaml:"extUnixSk"`
// OpenImageDir opens a checkpoint directory and prepares it for CRIU.
// Returns the opened file and its FD. The caller must close the file when done.
// The file descriptor has CLOEXEC cleared so it can be inherited by CRIU.
func OpenImageDir(checkpointDir string) (*os.File, int32, error) {
return common.OpenDirForCRIU(checkpointDir)
// LinkRemap handles deleted-but-open files.
LinkRemap bool `yaml:"linkRemap"`
// ExtMasters allows external bind mount masters.
ExtMasters bool `yaml:"extMasters"`
// ManageCgroupsMode controls cgroup handling: "ignore" lets K8s manage cgroups.
ManageCgroupsMode string `yaml:"manageCgroupsMode"`
// === CRIU Conf File Options (NOT available via RPC - written to criu.conf) ===
// LibDir is the path to CRIU plugin directory (e.g., /usr/local/lib/criu).
// Required for CUDA checkpoint/restore.
LibDir string `yaml:"libDir"`
// AllowUprobes allows user-space probes (required for CUDA checkpoints).
AllowUprobes bool `yaml:"allowUprobes"`
// SkipInFlight skips in-flight TCP connections during checkpoint/restore.
SkipInFlight bool `yaml:"skipInFlight"`
}
// BuildCRIUOpts creates CRIU options from a config struct.
// This sets up the base options; external mounts and namespaces are added separately.
//
// Always-on options for K8s:
// - LeaveRunning: always keep process running after checkpoint
// - ShellJob: containers are often session leaders
// - TcpClose: pod IPs change on restore/migration
// - FileLocks: applications use file locks
// - OrphanPtsMaster: containers with TTYs
// - ExtUnixSk: containers have external Unix sockets
// - ManageCgroups (IGNORE): let K8s manage cgroups
// - LinkRemap: handle deleted-but-open files (safe for all workloads)
// - ExtMasters: external bind mount masters (safe for all workloads)
func BuildCRIUOpts(cfg CRIUConfig) *criurpc.CriuOpts {
cgMode := criurpc.CriuCgMode_IGNORE
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(cfg.PID)),
ImagesDirFd: proto.Int32(cfg.ImageDirFD),
LogLevel: proto.Int32(4),
LogFile: proto.String("dump.log"),
Root: proto.String(cfg.RootFS),
ManageCgroups: proto.Bool(true),
ManageCgroupsMode: &cgMode,
// Always-on for K8s environments
LeaveRunning: proto.Bool(true),
ShellJob: proto.Bool(true),
TcpClose: proto.Bool(true),
FileLocks: proto.Bool(true),
OrphanPtsMaster: proto.Bool(true),
ExtUnixSk: proto.Bool(true),
LinkRemap: proto.Bool(true),
ExtMasters: proto.Bool(true),
}
// GenerateCRIUConfContent generates the criu.conf file content for options
// that cannot be passed via RPC.
func (c *CRIUSettings) GenerateCRIUConfContent() string {
var content string
// Optional: ghost limit from env (0 = use CRIU default)
if cfg.GhostLimit > 0 {
criuOpts.GhostLimit = proto.Uint32(cfg.GhostLimit)
if c.LibDir != "" {
content += "libdir " + c.LibDir + "\n"
}
// Optional: timeout from env (0 = no timeout)
if cfg.Timeout > 0 {
criuOpts.Timeout = proto.Uint32(cfg.Timeout)
if c.AllowUprobes {
content += "allow-uprobes\n"
}
if c.SkipInFlight {
content += "skip-in-flight\n"
}
return criuOpts
return content
}
// AddExternalMounts adds mount points as external mounts to CRIU options.
// CRIU requires all mounts to be marked as external for successful restore.
func AddExternalMounts(criuOpts *criurpc.CriuOpts, mounts []AllMountInfo) {
addedMounts := make(map[string]bool)
// ExternalMountManifestEntry is a serializable CRIU ext-mount entry in checkpoint manifests.
// It mirrors one key/value pair of the CRIU external mount map used for the dump.
type ExternalMountManifestEntry struct {
// Key is the key of the CRIU external mount map entry.
Key string `yaml:"key"`
// Val is the value of the CRIU external mount map entry.
Val string `yaml:"val"`
}
for _, m := range mounts {
if addedMounts[m.MountPoint] {
continue
}
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{
Key: proto.String(m.MountPoint),
Val: proto.String(m.MountPoint),
})
addedMounts[m.MountPoint] = true
}
// CRIUDumpManifest stores the resolved dump-time CRIU mount plan used for restore.
// Besides the mount plan it keeps a copy of the CRIU settings the dump ran with.
type CRIUDumpManifest struct {
// CRIU is the settings value the dump was configured with.
CRIU CRIUSettings `yaml:"criu"`
// ExtMnt holds the external mount map entries of the dump options;
// omitted from the YAML when empty.
ExtMnt []ExternalMountManifestEntry `yaml:"extMnt,omitempty"`
// External holds the raw "external" strings of the dump options;
// omitted from the YAML when empty.
External []string `yaml:"external,omitempty"`
// SkipMnt holds the mounts the dump options told CRIU to skip;
// omitted from the YAML when empty.
SkipMnt []string `yaml:"skipMnt,omitempty"`
}
// AddExternalPaths adds additional paths (masked/readonly) as external mounts.
// These may not appear in mountinfo but CRIU still needs them marked as external.
func AddExternalPaths(criuOpts *criurpc.CriuOpts, paths []string) {
// Build set of existing mount points
existing := make(map[string]bool)
for _, m := range criuOpts.ExtMnt {
existing[m.GetKey()] = true
// NewCRIUDumpManifest serializes resolved dump options for restore.
func NewCRIUDumpManifest(criuOpts *criurpc.CriuOpts, settings CRIUSettings) CRIUDumpManifest {
manifest := CRIUDumpManifest{CRIU: settings}
if criuOpts == nil {
return manifest
}
for _, path := range paths {
if existing[path] {
for _, mount := range criuOpts.ExtMnt {
if mount == nil || mount.GetKey() == "" {
continue
}
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{
Key: proto.String(path),
Val: proto.String(path),
manifest.ExtMnt = append(manifest.ExtMnt, ExternalMountManifestEntry{
Key: mount.GetKey(),
Val: mount.GetVal(),
})
existing[path] = true
}
manifest.External = append([]string(nil), criuOpts.External...)
manifest.SkipMnt = append([]string(nil), criuOpts.SkipMnt...)
return manifest
}
// AddExternalNamespace adds a namespace as external to CRIU options.
// Format: "<type>[<inode>]:<key>"
func AddExternalNamespace(criuOpts *criurpc.CriuOpts, nsType NamespaceType, inode uint64, key string) {
extNs := fmt.Sprintf("%s[%d]:%s", nsType, inode, key)
criuOpts.External = append(criuOpts.External, extNs)
}
// BuildCRIUDumpOptions creates CRIU options directly from spec settings and runtime state.
func BuildCRIUDumpOptions(
settings *CRIUSettings,
pid int,
imageDirFD int32,
rootFS string,
mountInfo []MountInfo,
ociSpec *specs.Spec,
namespaces map[NamespaceType]*NamespaceInfo,
) (*criurpc.CriuOpts, error) {
mountPolicy := BuildMountPolicy(mountInfo, ociSpec, rootFS)
// AddExternalStrings adds raw external strings to CRIU options.
// Used for additional external mount mappings (e.g., NVIDIA firmware files).
func AddExternalStrings(criuOpts *criurpc.CriuOpts, externals []string) {
criuOpts.External = append(criuOpts.External, externals...)
}
extMnt := buildExternalMountMaps(mountPolicy.Externalized)
skipMnt := mountPolicy.Skipped
external := buildExternalNamespaces(namespaces)
logrus.WithFields(logrus.Fields{
"externalized_count": len(mountPolicy.Externalized),
"skipped_count": len(mountPolicy.Skipped),
}).Debug("Resolved mount policy for CRIU dump")
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(pid)),
ImagesDirFd: proto.Int32(imageDirFD),
Root: proto.String(rootFS),
LogFile: proto.String(DumpLogFilename),
}
criuOpts.ExtMnt = extMnt
criuOpts.External = external
criuOpts.SkipMnt = skipMnt
if settings == nil {
return criuOpts, nil
}
// RPC options from spec.
criuOpts.LogLevel = proto.Int32(settings.LogLevel)
criuOpts.LeaveRunning = proto.Bool(settings.LeaveRunning)
criuOpts.ShellJob = proto.Bool(settings.ShellJob)
criuOpts.TcpClose = proto.Bool(settings.TcpClose)
criuOpts.FileLocks = proto.Bool(settings.FileLocks)
criuOpts.OrphanPtsMaster = proto.Bool(settings.OrphanPtsMaster)
criuOpts.ExtUnixSk = proto.Bool(settings.ExtUnixSk)
criuOpts.LinkRemap = proto.Bool(settings.LinkRemap)
criuOpts.ExtMasters = proto.Bool(settings.ExtMasters)
criuOpts.AutoDedup = proto.Bool(settings.AutoDedup)
criuOpts.LazyPages = proto.Bool(settings.LazyPages)
// Cgroup management mode
criuOpts.ManageCgroups = proto.Bool(true)
cgMode := criurpc.CriuCgMode_IGNORE
switch settings.ManageCgroupsMode {
case "soft":
cgMode = criurpc.CriuCgMode_SOFT
case "full":
cgMode = criurpc.CriuCgMode_FULL
case "strict":
cgMode = criurpc.CriuCgMode_STRICT
}
criuOpts.ManageCgroupsMode = &cgMode
// ConfigureExternalMounts adds all required external mounts to CRIU options.
// This includes mounts from /proc/pid/mountinfo plus masked/readonly paths from OCI spec.
func ConfigureExternalMounts(criuOpts *criurpc.CriuOpts, pid int, hostProc string, containerInfo *checkpointk8s.ContainerInfo) error {
// Get all mounts from mountinfo - CRIU needs every mount marked as external
allMounts, err := GetAllMountsFromMountinfo(pid, hostProc)
if err != nil {
return fmt.Errorf("failed to get all mounts from mountinfo: %w", err)
// Optional numeric options
if settings.GhostLimit > 0 {
criuOpts.GhostLimit = proto.Uint32(settings.GhostLimit)
}
if settings.Timeout > 0 {
criuOpts.Timeout = proto.Uint32(settings.Timeout)
}
// Add mounts from mountinfo
AddExternalMounts(criuOpts, allMounts)
return criuOpts, nil
}
// Add masked and readonly paths from OCI spec
AddExternalPaths(criuOpts, containerInfo.GetMaskedPaths())
AddExternalPaths(criuOpts, containerInfo.GetReadonlyPaths())
// buildExternalMountMaps serializes externalized mount paths into CRIU map entries.
//
// Every distinct, non-empty path produces one entry whose Key and Val are both
// the path itself. Empty strings are skipped and repeated paths are emitted
// only once, in first-seen order. The returned slice is never nil.
func buildExternalMountMaps(paths []string) []*criurpc.ExtMountMap {
	extMnt := make([]*criurpc.ExtMountMap, 0, len(paths))
	seen := make(map[string]struct{}, len(paths))
	for _, path := range paths {
		if path == "" {
			continue
		}
		if _, dup := seen[path]; dup {
			continue
		}
		seen[path] = struct{}{}
		extMnt = append(extMnt, &criurpc.ExtMountMap{
			Key: proto.String(path),
			Val: proto.String(path),
		})
	}
	return extMnt
}
// ConfigureExternalNamespaces adds external namespaces to CRIU options.
// Returns the network namespace inode if found, for logging purposes.
func ConfigureExternalNamespaces(criuOpts *criurpc.CriuOpts, namespaces map[NamespaceType]*NamespaceInfo, externalMounts []string) uint64 {
var netNsInode uint64
// buildExternalNamespaces builds external namespace/mount references.
//
// Only the network namespace is considered: when the map holds a non-nil
// NamespaceNet entry, a single "<type>[<inode>]:extNetNs" reference is
// returned. Otherwise the result is an empty, non-nil slice.
func buildExternalNamespaces(namespaces map[NamespaceType]*NamespaceInfo) []string {
	external := make([]string, 0, 1)

	// Mark network namespace as external for socket binding preservation.
	// A nil map value is treated like a missing entry instead of panicking.
	if netNs, ok := namespaces[NamespaceNet]; ok && netNs != nil {
		external = append(external, fmt.Sprintf("%s[%d]:%s", NamespaceNet, netNs.Inode, "extNetNs"))
		logrus.WithField("inode", netNs.Inode).Debug("Marked network namespace as external")
	}

	return external
}
// BuildCRIUOptsFromCheckpointOpts constructs CRIU options from checkpoint Options.
// Returns the configured CriuOpts ready for external mount/namespace configuration.
func BuildCRIUOptsFromCheckpointOpts(opts Options, pid int, imageDirFD int32, rootFS string) *criurpc.CriuOpts {
cfg := CRIUConfig{
PID: pid,
ImageDirFD: imageDirFD,
RootFS: rootFS,
GhostLimit: opts.GhostLimit,
Timeout: opts.Timeout,
// ExecuteCRIUDump runs the CRIU dump and logs timing plus dump-log location on failure.
//
// On success it logs and returns the elapsed wall-clock time of the dump.
// On failure it logs the elapsed time, the checkpoint directory and the path
// where the dump log is expected, then returns a zero duration together with
// the wrapped CRIU error.
func ExecuteCRIUDump(criuOpts *criurpc.CriuOpts, checkpointDir string, log *logrus.Entry) (time.Duration, error) {
	criuDumpStart := time.Now()

	criuClient := criu.MakeCriu()
	if err := criuClient.Dump(criuOpts, nil); err != nil {
		dumpDuration := time.Since(criuDumpStart)
		log.WithFields(logrus.Fields{
			"duration":       dumpDuration,
			"checkpoint_dir": checkpointDir,
			"dump_log_path":  fmt.Sprintf("%s/%s", checkpointDir, DumpLogFilename),
		}).Error("CRIU dump failed")
		return 0, fmt.Errorf("CRIU dump failed: %w", err)
	}

	criuDumpDuration := time.Since(criuDumpStart)
	log.WithField("duration", criuDumpDuration).Info("CRIU dump completed")
	return criuDumpDuration, nil
}
// rootfs provides container rootfs introspection via /proc for CRIU checkpoint.
// filesystem.go provides container rootfs introspection, filesystem config/metadata types,
// and rootfs diff capture for CRIU checkpoint.
package checkpoint
import (
"bufio"
"encoding/json"
"fmt"
"os"
......@@ -10,131 +10,152 @@ import (
"path/filepath"
"strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// GetRootFS returns the container's root filesystem path
// For containers using overlayfs, this extracts the upperdir
func GetRootFS(pid int, hostProc string) (string, error) {
if hostProc == "" {
hostProc = "/proc"
}
// FilesystemConfig is the static config for rootfs exclusions (from values.yaml).
// The three lists are concatenated by GetAllExclusions in declaration order,
// and Validate requires every entry to start with "./".
type FilesystemConfig struct {
// SystemDirs are system directories that should be excluded from rootfs diff.
// These directories are typically injected/bind-mounted by NVIDIA GPU Operator
// at container start time, so they already exist in the restore target.
// Excluding them prevents conflicts (especially socket files which cannot be overwritten).
// Default: ["./usr", "./etc", "./opt", "./var", "./run"]
SystemDirs []string `yaml:"systemDirs"`
// CacheDirs are cache directories that can safely be excluded to reduce checkpoint size.
// Model weights and other cached data are typically re-downloaded if needed.
// Default: ["./.cache/huggingface", "./.cache/torch"]
CacheDirs []string `yaml:"cacheDirs"`
// AdditionalExclusions are custom paths to exclude from the rootfs diff.
// Use this for application-specific exclusions.
// Paths should be relative with "./" prefix (e.g., "./data/temp").
AdditionalExclusions []string `yaml:"additionalExclusions"`
}
// The rootfs is accessible via /proc/<pid>/root
// But for CRIU, we need the actual filesystem path
rootPath := fmt.Sprintf("%s/%d/root", hostProc, pid)
// GetAllExclusions flattens SystemDirs, CacheDirs and AdditionalExclusions,
// in that order, into one newly allocated slice. The result feeds the tar
// --exclude arguments used for rootfs diff capture. A nil receiver yields nil.
func (c *FilesystemConfig) GetAllExclusions() []string {
	if c == nil {
		return nil
	}

	groups := [][]string{c.SystemDirs, c.CacheDirs, c.AdditionalExclusions}

	size := 0
	for _, group := range groups {
		size += len(group)
	}

	combined := make([]string, 0, size)
	for _, group := range groups {
		combined = append(combined, group...)
	}
	return combined
}
// Verify it exists
if _, err := os.Stat(rootPath); err != nil {
return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err)
// Validate reports the first exclusion path that does not begin with "./".
// tar handles the exclusions as relative paths, so every entry must carry
// that prefix. A nil receiver is considered valid.
func (c *FilesystemConfig) Validate() error {
	if c == nil {
		return nil
	}

	for _, exclusion := range c.GetAllExclusions() {
		if strings.HasPrefix(exclusion, "./") {
			continue
		}
		return &ConfigError{
			Field:   "rootfsExclusions",
			Message: "all exclusion paths must start with './' (got: " + exclusion + ")",
		}
	}
	return nil
}
return rootPath, nil
// FilesystemManifest holds runtime filesystem state captured at checkpoint time.
type FilesystemManifest struct {
// Exclusions is the exclusion config that was in effect for the rootfs diff.
Exclusions FilesystemConfig `yaml:"exclusions"`
// UpperDir is the overlay upperdir handed to NewFilesystemManifest.
UpperDir string `yaml:"upperDir,omitempty"`
// ExternalPaths is filled by NewFilesystemManifest with the OCI spec's
// masked paths followed by its readonly paths.
ExternalPaths []string `yaml:"externalPaths,omitempty"`
// BindMountDests is filled by NewFilesystemManifest with the destination of
// every OCI mount whose type is "bind".
BindMountDests []string `yaml:"bindMountDests,omitempty"`
// HasRootfsDiff is set by CaptureRootfsState once the rootfs diff was captured.
HasRootfsDiff bool `yaml:"hasRootfsDiff"`
// HasDeletedFiles is set by CaptureRootfsState when deleted files (whiteouts) were recorded.
HasDeletedFiles bool `yaml:"hasDeletedFiles"`
}
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo
// This is the writable layer of the container's filesystem
func GetOverlayUpperDir(pid int, hostProc string) (string, error) {
if hostProc == "" {
hostProc = "/proc"
// NewFilesystemManifest constructs FilesystemManifest from config, overlay state, and OCI spec.
func NewFilesystemManifest(exclusions FilesystemConfig, upperDir string, ociSpec *specs.Spec) FilesystemManifest {
meta := FilesystemManifest{
Exclusions: exclusions,
UpperDir: upperDir,
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return "", fmt.Errorf("failed to open mountinfo: %w", err)
if ociSpec == nil {
return meta
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
// Look for the root mount (mount point is /)
// mountinfo format: id parent major:minor root mount-point options ... - fstype source super-options
if len(fields) < 5 {
continue
if ociSpec.Linux != nil {
meta.ExternalPaths = make([]string, 0, len(ociSpec.Linux.MaskedPaths)+len(ociSpec.Linux.ReadonlyPaths))
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.MaskedPaths...)
meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.ReadonlyPaths...)
}
for _, m := range ociSpec.Mounts {
if m.Type == "bind" {
meta.BindMountDests = append(meta.BindMountDests, m.Destination)
}
}
return meta
}
mountPoint := fields[4]
if mountPoint != "/" {
continue
}
// GetRootFS returns the container's root filesystem path.
//
// The path is "<HostProcPath>/<pid>/root". It is stat'ed before being
// returned, so an inaccessible rootfs is reported as an error that wraps the
// underlying os.Stat failure.
func GetRootFS(pid int) (string, error) {
	rootPath := fmt.Sprintf("%s/%d/root", HostProcPath, pid)

	if _, err := os.Stat(rootPath); err != nil {
		return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err)
	}

	return rootPath, nil
}
fsType := fields[sepIdx+1]
if fsType != "overlay" {
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo.
// This is the writable layer of the container's filesystem.
func GetOverlayUpperDir(pid int) (string, error) {
mountInfo, err := ReadMountInfoFromHostProcPath(pid)
if err != nil {
return "", fmt.Errorf("failed to parse mountinfo: %w", err)
}
for _, mount := range mountInfo {
if mount.MountPoint != "/" || mount.FSType != "overlay" {
continue
}
// Parse super options to find upperdir
superOptions := fields[sepIdx+3]
for _, opt := range strings.Split(superOptions, ",") {
for _, opt := range strings.Split(mount.SuperOptions, ",") {
if strings.HasPrefix(opt, "upperdir=") {
return strings.TrimPrefix(opt, "upperdir="), nil
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("error reading mountinfo: %w", err)
}
return "", fmt.Errorf("overlay upperdir not found for pid %d", pid)
}
// DefaultRootfsDiffExclusions are paths excluded from the rootfs diff capture.
// These directories are injected/bind-mounted by NVIDIA GPU Operator at container
// start time, so they already exist in the restore target and cause conflicts
// (especially socket files which cannot be overwritten).
// Entries are relative ("./"-prefixed) because each one is passed to tar as
// "--exclude=<entry>".
var DefaultRootfsDiffExclusions = []string{
// NVIDIA GPU Operator injects drivers, libraries, and config here
"./usr",
"./etc",
"./opt",
"./var",
// NVIDIA GPU Operator creates runtime sockets and firmware mounts here
// Socket files cause fatal tar errors even with --keep-old-files
"./run",
}
// CaptureRootfsDiff captures the overlay upperdir to a tar file.
// The upperdir contains all filesystem modifications made by the container.
// Excludes bind mount destinations and system directories to avoid conflicts during restore.
// Excludes bind mount destinations and configured directories to avoid conflicts during restore.
// Returns the path to the tar file or empty string if capture failed.
func CaptureRootfsDiff(upperDir, checkpointDir string, excludePaths []string) (string, error) {
func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions *FilesystemConfig, bindMountDests []string) (string, error) {
if upperDir == "" {
return "", fmt.Errorf("upperdir is empty")
}
rootfsDiffPath := filepath.Join(checkpointDir, "rootfs-diff.tar")
rootfsDiffPath := filepath.Join(checkpointDir, RootfsDiffFilename)
// Build tar arguments with xattrs and exclusions
tarArgs := []string{"--xattrs"}
// Add default exclusions for system directories and caches
for _, excl := range DefaultRootfsDiffExclusions {
tarArgs = append(tarArgs, "--exclude="+excl)
// Add configured exclusions (systemDirs, cacheDirs, additionalExclusions from values.yaml)
if exclusions != nil {
for _, excl := range exclusions.GetAllExclusions() {
tarArgs = append(tarArgs, "--exclude="+excl)
}
}
// Add bind mount exclusions passed from caller
for _, dest := range excludePaths {
for _, dest := range bindMountDests {
// Convert absolute path to relative for tar (e.g., /etc/hosts -> ./etc/hosts)
tarArgs = append(tarArgs, "--exclude=."+dest)
}
......@@ -165,7 +186,7 @@ func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
return false, nil
}
deletedFilesPath := filepath.Join(checkpointDir, "deleted-files.json")
deletedFilesPath := filepath.Join(checkpointDir, DeletedFilesFilename)
data, err := json.Marshal(whiteouts)
if err != nil {
return false, fmt.Errorf("failed to marshal whiteouts: %w", err)
......@@ -195,11 +216,11 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
relPath, _ := filepath.Rel(upperDir, path)
dir := filepath.Dir(relPath)
deletedFile := strings.TrimPrefix(name, ".wh.")
if dir == "." {
whiteouts = append(whiteouts, deletedFile)
} else {
whiteouts = append(whiteouts, filepath.Join(dir, deletedFile))
deletedPath := deletedFile
if dir != "." {
deletedPath = filepath.Join(dir, deletedFile)
}
whiteouts = append(whiteouts, deletedPath)
}
return nil
})
......@@ -208,23 +229,23 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
}
// CaptureRootfsState captures the overlay upperdir and deleted files after CRIU dump.
// Updates the metadata with rootfs diff information and saves it.
func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointMetadata, log *logrus.Entry) {
if upperDir == "" {
// Updates the checkpoint manifest with rootfs diff information and saves it.
func CaptureRootfsState(upperDir, checkpointDir string, data *CheckpointManifest, log *logrus.Entry) {
if upperDir == "" || data == nil {
return
}
// Capture rootfs diff
// Capture rootfs diff using exclusions from the checkpoint manifest.
configuredExclusions := data.Filesystem.Exclusions.GetAllExclusions()
log.WithFields(logrus.Fields{
"default_exclusions": DefaultRootfsDiffExclusions,
"bind_mount_exclusions": meta.BindMountDests,
"configured_exclusions": configuredExclusions,
"bind_mount_exclusions": data.Filesystem.BindMountDests,
}).Debug("Rootfs diff exclusions")
rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, meta.BindMountDests)
rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, &data.Filesystem.Exclusions, data.Filesystem.BindMountDests)
if err != nil {
log.WithError(err).Warn("Failed to capture rootfs diff")
} else {
meta.RootfsDiffPath = rootfsDiffPath
meta.HasRootfsDiff = true
data.Filesystem.HasRootfsDiff = true
log.WithFields(logrus.Fields{
"upperdir": upperDir,
"tar_path": rootfsDiffPath,
......@@ -236,12 +257,12 @@ func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointM
if err != nil {
log.WithError(err).Warn("Failed to capture deleted files")
} else if hasDeletedFiles {
meta.HasDeletedFiles = true
data.Filesystem.HasDeletedFiles = true
log.Info("Recorded deleted files (whiteouts)")
}
// Update metadata with rootfs diff info
if err := common.SaveMetadata(checkpointDir, meta); err != nil {
log.WithError(err).Warn("Failed to update metadata with rootfs diff info")
// Update checkpoint manifest with rootfs diff info.
if err := WriteCheckpointManifest(checkpointDir, data); err != nil {
log.WithError(err).Warn("Failed to update checkpoint manifest with rootfs diff info")
}
}
// k8s contains containerd discovery and Kubernetes path classification helpers.
package checkpoint
import (
"context"
"fmt"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
const (
// K8sNamespace is the containerd namespace used by Kubernetes.
// ResolveContainer attaches it to the context before every containerd call.
K8sNamespace = "k8s.io"
// ContainerdSocket is the default containerd socket path.
// NewDiscoveryClient always connects to this address.
ContainerdSocket = "/run/containerd/containerd.sock"
)
// SourcePodManifest identifies the container, process, node and pod a
// checkpoint was taken from. It is serialized to YAML using the tags below.
type SourcePodManifest struct {
// ContainerID is copied from CheckpointRequest.ContainerID.
ContainerID string `yaml:"containerId"`
// PID is the process ID supplied to NewSourcePodManifest.
PID int `yaml:"pid"`
// SourceNode is copied from CheckpointRequest.NodeName.
SourceNode string `yaml:"sourceNode"`
// PodName is copied from CheckpointRequest.PodName.
PodName string `yaml:"podName"`
// PodNamespace is copied from CheckpointRequest.PodNamespace.
PodNamespace string `yaml:"podNamespace"`
}
// NewSourcePodManifest builds a SourcePodManifest from the identifying fields
// of a checkpoint request plus the resolved process ID.
func NewSourcePodManifest(params CheckpointRequest, pid int) SourcePodManifest {
	var manifest SourcePodManifest
	manifest.ContainerID = params.ContainerID
	manifest.PID = pid
	manifest.SourceNode = params.NodeName
	manifest.PodName = params.PodName
	manifest.PodNamespace = params.PodNamespace
	return manifest
}
// DiscoveryClient wraps a containerd client used to resolve container IDs
// to a PID and OCI spec.
type DiscoveryClient struct {
// client is the underlying containerd connection; Close tolerates it being nil.
client *containerd.Client
}
// NewDiscoveryClient connects to containerd at ContainerdSocket and wraps the
// connection. The connection error is wrapped and returned on failure.
func NewDiscoveryClient() (*DiscoveryClient, error) {
	conn, err := containerd.New(ContainerdSocket)
	if err == nil {
		return &DiscoveryClient{client: conn}, nil
	}
	return nil, fmt.Errorf("failed to connect to containerd at %s: %w", ContainerdSocket, err)
}
// Close releases the containerd connection, if one is held.
func (c *DiscoveryClient) Close() error {
	if c.client == nil {
		// Nothing was ever connected, so there is nothing to release.
		return nil
	}
	return c.client.Close()
}
// ResolveContainer looks the container up in the Kubernetes containerd
// namespace and returns the PID of its task together with its OCI spec.
// Any failing step returns a zero PID, a nil spec and a wrapped error.
func (c *DiscoveryClient) ResolveContainer(ctx context.Context, containerID string) (int, *specs.Spec, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, err := c.client.LoadContainer(nsCtx, containerID)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
	}

	task, err := ctr.Task(nsCtx, nil)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to get task for container %s: %w", containerID, err)
	}
	taskPID := task.Pid()

	ociSpec, err := ctr.Spec(nsCtx)
	if err != nil {
		return 0, nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, err)
	}

	return int(taskPID), ociSpec, nil
}
// discovery provides container information resolution via containerd.
// This prefers containerd RPCs for configuration over /proc inspection,
// following the principle that configuration should come from the container runtime
// while runtime state (like namespace inodes) requires /proc.
package k8s
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
const (
// K8sNamespace is the containerd namespace used by Kubernetes.
// It is also a path component of the bundle path built in ResolveContainer.
K8sNamespace = "k8s.io"
// DefaultSocket is the default containerd socket path, used by
// NewDiscoveryClient when it is given an empty socket argument.
DefaultSocket = "/run/containerd/containerd.sock"
)
// ContainerInfo holds resolved container information from containerd.
// It is populated by DiscoveryClient.ResolveContainer.
type ContainerInfo struct {
ContainerID string
PID uint32 // PID reported by the container's containerd task
RootFS string // Actual rootfs path (bundle path + spec.Root.Path, or spec.Root.Path itself when absolute)
BundlePath string // Path to container bundle directory
Image string // Name of the container's image
Spec *specs.Spec // OCI spec from containerd (mounts, namespaces config)
Labels map[string]string // Container labels; an empty map when they could not be read
}
// MountInfo represents a mount from the OCI spec.
// GetMounts copies each field straight from the corresponding spec mount.
type MountInfo struct {
Destination string // Mount point inside container
Source string // Source path on host
Type string // Filesystem type (bind, tmpfs, etc.)
Options []string // Mount options
}
// NamespaceConfig represents namespace configuration from OCI spec.
// GetNamespaces produces one value per entry in spec.Linux.Namespaces.
type NamespaceConfig struct {
Type string // Namespace type (network, pid, mount, etc.)
Path string // Path to namespace (empty for new namespace)
}
// DiscoveryClient wraps the containerd client for container discovery.
type DiscoveryClient struct {
// client is the underlying containerd connection; Close tolerates it being nil.
client *containerd.Client
// socket is the address the client was connected to.
socket string
}
// NewDiscoveryClient connects to containerd at the given socket, falling back
// to DefaultSocket when the argument is empty. The effective socket path is
// remembered on the returned client.
func NewDiscoveryClient(socket string) (*DiscoveryClient, error) {
	addr := socket
	if addr == "" {
		addr = DefaultSocket
	}

	conn, err := containerd.New(addr)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to containerd at %s: %w", addr, err)
	}

	dc := &DiscoveryClient{client: conn, socket: addr}
	return dc, nil
}
// Close shuts down the containerd client connection when one is present and
// returns whatever error the underlying client reports.
func (c *DiscoveryClient) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}
// ResolveContainer resolves a container ID to a ContainerInfo.
//
// The container, its task, image, OCI spec and labels are fetched from
// containerd in the Kubernetes namespace. A failure to read labels is
// tolerated (an empty map is used); every other failure is returned wrapped.
// The bundle path is derived from CONTAINERD_RUN_ROOT (default
// "/run/containerd"), and RootFS is spec.Root.Path resolved against that
// bundle path unless it is already absolute.
func (c *DiscoveryClient) ResolveContainer(ctx context.Context, containerID string) (*ContainerInfo, error) {
	// Every containerd call below is scoped to the Kubernetes namespace.
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, err := c.client.LoadContainer(nsCtx, containerID)
	if err != nil {
		return nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
	}

	// The task represents the running process; its PID is what callers need.
	task, err := ctr.Task(nsCtx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get task for container %s: %w", containerID, err)
	}
	taskPID := task.Pid()

	img, err := ctr.Image(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to get image for container %s: %w", containerID, err)
	}

	// The OCI spec carries mount and namespace configuration.
	ociSpec, err := ctr.Spec(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, err)
	}

	// Labels are optional: fall back to an empty map rather than failing.
	ctrLabels, labelErr := ctr.Labels(nsCtx)
	if labelErr != nil {
		ctrLabels = make(map[string]string)
	}

	// Bundle layout: <run root>/io.containerd.runtime.v2.task/<namespace>/<container_id>/
	runRoot := "/run/containerd"
	if fromEnv := os.Getenv("CONTAINERD_RUN_ROOT"); fromEnv != "" {
		runRoot = fromEnv
	}
	bundle := filepath.Join(runRoot, "io.containerd.runtime.v2.task", K8sNamespace, containerID)

	// Rootfs location from the spec, defaulting to "rootfs" inside the bundle.
	rootFS := "rootfs"
	if ociSpec.Root != nil && ociSpec.Root.Path != "" {
		rootFS = ociSpec.Root.Path
	}
	if !filepath.IsAbs(rootFS) {
		rootFS = filepath.Join(bundle, rootFS)
	}

	info := &ContainerInfo{
		ContainerID: containerID,
		PID:         taskPID,
		RootFS:      rootFS,
		BundlePath:  bundle,
		Image:       img.Name(),
		Spec:        ociSpec,
		Labels:      ctrLabels,
	}
	return info, nil
}
// GetMounts converts the OCI spec's mounts into MountInfo values, preserving
// their order. It returns nil when the spec or its mount list is absent.
func (info *ContainerInfo) GetMounts() []MountInfo {
	if info.Spec == nil || info.Spec.Mounts == nil {
		return nil
	}

	specMounts := info.Spec.Mounts
	converted := make([]MountInfo, 0, len(specMounts))
	for _, sm := range specMounts {
		converted = append(converted, MountInfo{
			Destination: sm.Destination,
			Source:      sm.Source,
			Type:        sm.Type,
			Options:     sm.Options,
		})
	}
	return converted
}
// GetNamespaces converts the OCI spec's Linux namespaces into NamespaceConfig
// values, preserving their order. It returns nil when the spec or its Linux
// section is absent.
func (info *ContainerInfo) GetNamespaces() []NamespaceConfig {
	if info.Spec == nil || info.Spec.Linux == nil {
		return nil
	}

	specNamespaces := info.Spec.Linux.Namespaces
	configs := make([]NamespaceConfig, 0, len(specNamespaces))
	for _, ns := range specNamespaces {
		configs = append(configs, NamespaceConfig{
			Type: string(ns.Type),
			Path: ns.Path,
		})
	}
	return configs
}
// GetMaskedPaths exposes spec.Linux.MaskedPaths, or nil when the spec or its
// Linux section is missing.
func (info *ContainerInfo) GetMaskedPaths() []string {
	if spec := info.Spec; spec != nil && spec.Linux != nil {
		return spec.Linux.MaskedPaths
	}
	return nil
}
// GetReadonlyPaths exposes spec.Linux.ReadonlyPaths, or nil when the spec or
// its Linux section is missing.
func (info *ContainerInfo) GetReadonlyPaths() []string {
	if spec := info.Spec; spec != nil && spec.Linux != nil {
		return spec.Linux.ReadonlyPaths
	}
	return nil
}
// GetRootfsPath returns the raw rootfs path recorded in the OCI spec
// (spec.Root.Path), or "" when the spec or its Root section is missing.
// Note: this value may be relative to the bundle directory. info.RootFS holds
// the resolved form (bundle path joined with spec.Root.Path, or the path
// itself when absolute) as computed by ResolveContainer.
func (info *ContainerInfo) GetRootfsPath() string {
	if info.Spec == nil || info.Spec.Root == nil {
		return ""
	}
	return info.Spec.Root.Path
}
// IsRootReadonly reports spec.Root.Readonly, treating a missing spec or a
// missing Root section as "not readonly".
func (info *ContainerInfo) IsRootReadonly() bool {
	return info.Spec != nil && info.Spec.Root != nil && info.Spec.Root.Readonly
}
// GetHostname exposes spec.Hostname, or "" when no spec is attached.
func (info *ContainerInfo) GetHostname() string {
	if spec := info.Spec; spec != nil {
		return spec.Hostname
	}
	return ""
}
// ListContainers returns the IDs of all containers that containerd reports in
// the Kubernetes namespace, in the order containerd returned them.
func (c *DiscoveryClient) ListContainers(ctx context.Context) ([]string, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	found, err := c.client.Containers(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to list containers: %w", err)
	}

	ids := make([]string, 0, len(found))
	for _, ctr := range found {
		ids = append(ids, ctr.ID())
	}
	return ids, nil
}
// GetContainerLabels loads the container in the Kubernetes namespace and
// returns its labels. Unlike ResolveContainer, a label read failure is
// returned to the caller as an error.
func (c *DiscoveryClient) GetContainerLabels(ctx context.Context, containerID string) (map[string]string, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, err := c.client.LoadContainer(nsCtx, containerID)
	if err != nil {
		return nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
	}

	ctrLabels, err := ctr.Labels(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to get labels for container %s: %w", containerID, err)
	}
	return ctrLabels, nil
}
// Package k8s provides Kubernetes-specific functionality for checkpoint operations.
// This includes volume type discovery via K8s API and containerd container discovery.
package k8s
import (
"context"
"fmt"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)
// VolumeInfo contains Kubernetes volume information for a mount.
// ExtractVolumeInfo builds one value per volumeMount of the target container.
type VolumeInfo struct {
VolumeName string // Name from pod.spec.volumes[].name
VolumeType string // Type: emptyDir, configMap, secret, persistentVolumeClaim, etc.; "unknown" if unrecognized
MountPath string // Container path from volumeMounts[].mountPath
SubPath string // SubPath if specified
ReadOnly bool // Whether mount is read-only
// Type-specific details; each is set only for its own volume type and empty otherwise
ConfigMapName string // For configMap volumes
SecretName string // For secret volumes
PVCName string // For persistentVolumeClaim volumes
}
// K8sClient wraps the Kubernetes clientset for volume discovery.
type K8sClient struct {
// clientset is used by GetPodVolumes to fetch Pod objects.
clientset *kubernetes.Clientset
}
// NewK8sClient builds a Kubernetes client, preferring in-cluster configuration
// and falling back to the kubeconfig at clientcmd.RecommendedHomeFile (useful
// for local development). When both sources fail, the kubeconfig error is the
// one that gets wrapped and returned.
func NewK8sClient() (*K8sClient, error) {
	cfg, err := rest.InClusterConfig()
	if err != nil {
		// Not running inside a cluster: try the user's kubeconfig instead.
		cfg, err = clientcmd.BuildConfigFromFlags("", clientcmd.RecommendedHomeFile)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to create k8s config: %w", err)
	}

	cs, err := kubernetes.NewForConfig(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create k8s clientset: %w", err)
	}
	return &K8sClient{clientset: cs}, nil
}
// NewK8sClientWithConfig builds a Kubernetes client from a caller-provided
// REST configuration.
func NewK8sClientWithConfig(config *rest.Config) (*K8sClient, error) {
	cs, err := kubernetes.NewForConfig(config)
	if err == nil {
		return &K8sClient{clientset: cs}, nil
	}
	return nil, fmt.Errorf("failed to create k8s clientset: %w", err)
}
// GetPodVolumes fetches the named pod and returns, for the given container,
// a map from mount path to VolumeInfo (see ExtractVolumeInfo).
func (c *K8sClient) GetPodVolumes(ctx context.Context, namespace, podName, containerName string) (map[string]*VolumeInfo, error) {
	pods := c.clientset.CoreV1().Pods(namespace)

	pod, err := pods.Get(ctx, podName, metav1.GetOptions{})
	if err != nil {
		return nil, fmt.Errorf("failed to get pod %s/%s: %w", namespace, podName, err)
	}

	return ExtractVolumeInfo(pod, containerName)
}
// ExtractVolumeInfo extracts volume information from a Pod spec.
// This is the core logic that maps volumeMounts to volumes and determines types.
//
// The container is looked up by name among the regular containers first and
// the init containers second. The result maps each volumeMount's mount path
// to its VolumeInfo. Mounts that reference a volume not declared in
// pod.spec.volumes are skipped. An error is returned when pod is nil or when
// no container with the given name exists.
func ExtractVolumeInfo(pod *corev1.Pod, containerName string) (map[string]*VolumeInfo, error) {
	if pod == nil {
		return nil, fmt.Errorf("pod is nil")
	}

	// Build volume name -> type mapping from pod.spec.volumes.
	// Index into the slice so no pointer to the range variable is taken.
	volumeTypes := make(map[string]*volumeDetails, len(pod.Spec.Volumes))
	for i := range pod.Spec.Volumes {
		vol := &pod.Spec.Volumes[i]
		volumeTypes[vol.Name] = getVolumeDetails(vol)
	}

	// findByName returns the first container in list with the requested name.
	findByName := func(list []corev1.Container) *corev1.Container {
		for i := range list {
			if list[i].Name == containerName {
				return &list[i]
			}
		}
		return nil
	}

	// Regular containers take precedence over init containers.
	container := findByName(pod.Spec.Containers)
	if container == nil {
		container = findByName(pod.Spec.InitContainers)
	}
	if container == nil {
		return nil, fmt.Errorf("container %s not found in pod", containerName)
	}

	// Build mount path -> volume info mapping
	result := make(map[string]*VolumeInfo, len(container.VolumeMounts))
	for _, mount := range container.VolumeMounts {
		details, ok := volumeTypes[mount.Name]
		if !ok {
			continue // Mount references unknown volume
		}
		result[mount.MountPath] = &VolumeInfo{
			VolumeName:    mount.Name,
			VolumeType:    details.volumeType,
			MountPath:     mount.MountPath,
			SubPath:       mount.SubPath,
			ReadOnly:      mount.ReadOnly,
			ConfigMapName: details.configMapName,
			SecretName:    details.secretName,
			PVCName:       details.pvcName,
		}
	}

	return result, nil
}
// volumeDetails holds extracted volume type information.
// It is produced by getVolumeDetails and copied into VolumeInfo by ExtractVolumeInfo.
type volumeDetails struct {
// volumeType is the type label, "unknown" when no known source is set.
volumeType string
// configMapName is set only for configMap volumes.
configMapName string
// secretName is set only for secret volumes.
secretName string
// pvcName is set only for persistentVolumeClaim volumes.
pvcName string
}
// getVolumeDetails classifies a Volume by the first populated volume source,
// checked in the fixed order below. For configMap, secret and
// persistentVolumeClaim volumes the referenced object name is captured as
// well. A volume with none of the listed sources is reported as "unknown".
func getVolumeDetails(vol *corev1.Volume) *volumeDetails {
	// typed builds a result that carries only the type label.
	typed := func(kind string) *volumeDetails {
		return &volumeDetails{volumeType: kind}
	}

	switch {
	case vol.EmptyDir != nil:
		return typed("emptyDir")
	case vol.ConfigMap != nil:
		return &volumeDetails{volumeType: "configMap", configMapName: vol.ConfigMap.Name}
	case vol.Secret != nil:
		return &volumeDetails{volumeType: "secret", secretName: vol.Secret.SecretName}
	case vol.PersistentVolumeClaim != nil:
		return &volumeDetails{volumeType: "persistentVolumeClaim", pvcName: vol.PersistentVolumeClaim.ClaimName}
	case vol.HostPath != nil:
		return typed("hostPath")
	case vol.Projected != nil:
		return typed("projected")
	case vol.DownwardAPI != nil:
		return typed("downwardAPI")
	case vol.CSI != nil:
		return typed("csi")
	case vol.NFS != nil:
		return typed("nfs")
	case vol.ISCSI != nil:
		return typed("iscsi")
	case vol.GCEPersistentDisk != nil:
		return typed("gcePersistentDisk")
	case vol.AWSElasticBlockStore != nil:
		return typed("awsElasticBlockStore")
	case vol.AzureDisk != nil:
		return typed("azureDisk")
	case vol.AzureFile != nil:
		return typed("azureFile")
	case vol.CephFS != nil:
		return typed("cephfs")
	case vol.Cinder != nil:
		return typed("cinder")
	case vol.FC != nil:
		return typed("fc")
	case vol.FlexVolume != nil:
		return typed("flexVolume")
	case vol.Flocker != nil:
		return typed("flocker")
	case vol.GitRepo != nil:
		return typed("gitRepo")
	case vol.Glusterfs != nil:
		return typed("glusterfs")
	case vol.PhotonPersistentDisk != nil:
		return typed("photonPersistentDisk")
	case vol.PortworxVolume != nil:
		return typed("portworxVolume")
	case vol.Quobyte != nil:
		return typed("quobyte")
	case vol.RBD != nil:
		return typed("rbd")
	case vol.ScaleIO != nil:
		return typed("scaleIO")
	case vol.StorageOS != nil:
		return typed("storageos")
	case vol.VsphereVolume != nil:
		return typed("vsphereVolume")
	case vol.Ephemeral != nil:
		return typed("ephemeral")
	default:
		return typed("unknown")
	}
}
// DetectVolumeTypeFromPath attempts to identify volume type from kubelet path patterns.
// This is a best-effort fallback; accurate volume types require K8s API access via GetPodVolumes.
//
// The path is searched for the known "/kubernetes.io~<plugin>/" markers. When
// several markers occur, the one that appears earliest in the path wins, so
// the result is deterministic. volumeName is the path segment directly after
// the winning marker (possibly empty). Without any marker the result is
// ("unknown", "").
func DetectVolumeTypeFromPath(hostPath string) (volumeType, volumeName string) {
	// An ordered table instead of a map: Go map iteration order is random,
	// which would make multi-marker paths classify differently between runs.
	patterns := []struct {
		marker     string
		volumeType string
	}{
		{"/kubernetes.io~empty-dir/", "emptyDir"},
		{"/kubernetes.io~configmap/", "configMap"},
		{"/kubernetes.io~secret/", "secret"},
		{"/kubernetes.io~projected/", "projected"},
		{"/kubernetes.io~downward-api/", "downwardAPI"},
		{"/kubernetes.io~persistentvolumeclaim/", "persistentVolumeClaim"},
		{"/kubernetes.io~hostpath/", "hostPath"},
	}

	volumeType = "unknown"
	bestIdx, bestLen := -1, 0
	for _, p := range patterns {
		idx := strings.Index(hostPath, p.marker)
		if idx < 0 || (bestIdx >= 0 && idx >= bestIdx) {
			continue
		}
		bestIdx, bestLen = idx, len(p.marker)
		volumeType = p.volumeType
	}
	if bestIdx < 0 {
		return volumeType, ""
	}

	// The volume name is the first path segment after the marker.
	rest := hostPath[bestIdx+bestLen:]
	if slash := strings.IndexByte(rest, '/'); slash >= 0 {
		rest = rest[:slash]
	}
	return volumeType, rest
}
// metadata_builder provides checkpoint metadata construction.
package checkpoint
import (
"context"
"strings"
"github.com/sirupsen/logrus"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// MetadataBuilderConfig holds configuration for building checkpoint metadata.
//
// BuildCheckpointMetadata copies the identity fields below into the resulting
// metadata. PodNamespace, PodName and ContainerName are additionally used to
// query pod volumes from the K8s API; that lookup is skipped when any of the
// three is empty.
type MetadataBuilderConfig struct {
// CheckpointID is passed to common.NewCheckpointMetadata.
CheckpointID string
// NodeName is stored as the metadata's SourceNode.
NodeName string
// ContainerID is stored as the metadata's ContainerID.
ContainerID string
// ContainerName selects the container for the K8s volume lookup.
ContainerName string
// PodName is stored in the metadata and used for the K8s volume lookup.
PodName string
// PodNamespace is stored in the metadata and used for the K8s volume lookup.
PodNamespace string
// PID is stored as the metadata's PID.
PID int
// CUDAPluginDir is not read by the metadata builder functions in this file.
// NOTE(review): presumably consumed elsewhere — confirm.
CUDAPluginDir string
}
// BuildCheckpointMetadata assembles the metadata record for one checkpoint.
//
// Identity fields come from cfg; the image and the masked/read-only path lists
// come from containerInfo. One mount entry is recorded per element of mounts,
// enriched with K8s volume information when k8sClient can provide it and with
// the OCI mount that shares the same destination. One namespace entry is
// recorded per element of namespaces (in map iteration order). The CRIU option
// set is fixed and stored for compatibility.
func BuildCheckpointMetadata(
	ctx context.Context,
	cfg MetadataBuilderConfig,
	containerInfo *checkpointk8s.ContainerInfo,
	mounts []MountMapping,
	namespaces map[NamespaceType]*NamespaceInfo,
	k8sClient *checkpointk8s.K8sClient,
	log *logrus.Entry,
) *common.CheckpointMetadata {
	md := common.NewCheckpointMetadata(cfg.CheckpointID)

	// Identity of the checkpointed container.
	md.SourceNode = cfg.NodeName
	md.ContainerID = cfg.ContainerID
	md.PodName = cfg.PodName
	md.PodNamespace = cfg.PodNamespace
	md.PID = cfg.PID
	md.Image = containerInfo.Image

	// Paths derived from the OCI spec.
	md.MaskedPaths = containerInfo.GetMaskedPaths()
	md.ReadonlyPaths = containerInfo.GetReadonlyPaths()

	// The OCI lookup also records bind mount destinations on md.
	ociByDest := buildOCIMountLookup(containerInfo, md)

	// K8s volume types are optional; nil means path-based detection is used.
	volumesByPath := getK8sVolumes(ctx, k8sClient, cfg, log)

	for i := range mounts {
		md.Mounts = append(md.Mounts, buildMountMetadata(mounts[i], volumesByPath, ociByDest))
	}

	for kind, ns := range namespaces {
		entry := common.NamespaceMetadata{
			Type:       string(kind),
			Inode:      ns.Inode,
			IsExternal: ns.IsExternal,
		}
		md.Namespaces = append(md.Namespaces, entry)
	}

	// CRIU options are hardcoded for K8s and stored for compatibility.
	md.CRIUOptions = common.CRIUOptionsMetadata{
		TcpEstablished: false, // TCP connections are closed rather than preserved
		TcpClose:       true,  // pod IPs change on restore
		ShellJob:       true,  // containers are session leaders
		FileLocks:      true,  // apps use file locks
		LeaveRunning:   true,  // keep the process running after checkpoint
		LinkRemap:      true,  // handle deleted-but-open files
		ExtMasters:     true,  // external bind mount masters
	}

	return md
}
// buildOCIMountLookup indexes the container's OCI mounts by destination path.
//
// As a side effect, the destination of every mount whose type is "bind" is
// appended to meta.BindMountDests. When two mounts share a destination, the
// later one wins in the returned map.
func buildOCIMountLookup(containerInfo *checkpointk8s.ContainerInfo, meta *common.CheckpointMetadata) map[string]checkpointk8s.MountInfo {
	byDest := make(map[string]checkpointk8s.MountInfo)
	for _, ociMount := range containerInfo.GetMounts() {
		if ociMount.Type == "bind" {
			meta.BindMountDests = append(meta.BindMountDests, ociMount.Destination)
		}
		byDest[ociMount.Destination] = ociMount
	}
	return byDest
}
// getK8sVolumes asks the K8s API for the pod's volume information.
//
// It returns nil, without calling the API, when there is no client or when the
// pod namespace, pod name or container name is missing from cfg. It also
// returns nil (after logging a warning) when the API call fails, so callers
// fall back to path-based detection.
func getK8sVolumes(ctx context.Context, k8sClient *checkpointk8s.K8sClient, cfg MetadataBuilderConfig, log *logrus.Entry) map[string]*checkpointk8s.VolumeInfo {
	targetKnown := cfg.PodNamespace != "" && cfg.PodName != "" && cfg.ContainerName != ""
	if k8sClient == nil || !targetKnown {
		return nil
	}

	volumes, lookupErr := k8sClient.GetPodVolumes(ctx, cfg.PodNamespace, cfg.PodName, cfg.ContainerName)
	if lookupErr == nil {
		log.WithField("volume_count", len(volumes)).Debug("Got volume types from K8s API")
		return volumes
	}

	log.WithError(lookupErr).Warn("Failed to get volume types from K8s API, falling back to path-based detection")
	return nil
}
// buildMountMetadata constructs metadata for a single mount.
//
// Volume type and name are taken from k8sVolumes (keyed by the in-container
// path) when an entry exists, otherwise from path-based detection on the
// host-side path. If ociMountByDest has an entry for the same in-container
// path, its source, type and options are copied into the OCI* fields.
//
// ReadOnly is true only when "ro" appears as a complete comma-separated option
// in mount.Options. A substring test would wrongly flag read-write mounts that
// carry options such as "errors=remount-ro".
func buildMountMetadata(mount MountMapping, k8sVolumes map[string]*checkpointk8s.VolumeInfo, ociMountByDest map[string]checkpointk8s.MountInfo) common.MountMetadata {
	var volumeType, volumeName string

	// Try K8s API first for accurate volume types. Indexing a nil map is safe
	// and simply reports "not found".
	if volInfo, ok := k8sVolumes[mount.InsidePath]; ok && volInfo != nil {
		volumeType = volInfo.VolumeType
		volumeName = volInfo.VolumeName
	}

	// Fall back to path-based detection if K8s API didn't provide info.
	if volumeType == "" {
		volumeType, volumeName = checkpointk8s.DetectVolumeTypeFromPath(mount.OutsidePath)
	}

	// Match the "ro" flag as a whole option token, not as a substring.
	readOnly := false
	for _, opt := range strings.Split(mount.Options, ",") {
		if strings.TrimSpace(opt) == "ro" {
			readOnly = true
			break
		}
	}

	mountMeta := common.MountMetadata{
		ContainerPath: mount.InsidePath,
		HostPath:      mount.OutsidePath,
		VolumeType:    volumeType,
		VolumeName:    volumeName,
		FSType:        mount.FSType,
		ReadOnly:      readOnly,
	}

	// Cross-reference with OCI spec mount if available.
	if ociMount, ok := ociMountByDest[mount.InsidePath]; ok {
		mountMeta.OCISource = ociMount.Source
		mountMeta.OCIType = ociMount.Type
		mountMeta.OCIOptions = ociMount.Options
	}

	return mountMeta
}
// mounts provides mount parsing from /proc for CRIU checkpoint.
// This is used for runtime mount state that requires /proc inspection.
// mounts parses runtime mount state from /proc.
package checkpoint
import (
"bufio"
"fmt"
"os"
"path"
"path/filepath"
"strings"
)
// MountMapping represents an external mount for CRIU.
//
// parseMountInfoLine builds one value per accepted mountinfo line: InsidePath
// is the mount point field, OutsidePath is the mount root (or the mount source
// when the root is "/"), and Options is the per-mount options and the super
// options joined with a comma.
type MountMapping struct {
InsidePath string // Path inside container (mount point)
OutsidePath string // Path on host (source)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
}
specs "github.com/opencontainers/runtime-spec/specs-go"
// System mount types that should be filtered out.
// parseMountInfoLine skips every mountinfo entry whose filesystem type is a
// key of this set.
var systemMountTypes = map[string]bool{
"proc": true,
"sysfs": true,
"devpts": true,
"mqueue": true,
"tmpfs": true, // Note: some tmpfs mounts may need special handling
"cgroup": true,
"cgroup2": true,
"securityfs": true,
"debugfs": true,
"tracefs": true,
"fusectl": true,
"configfs": true,
"devtmpfs": true,
"hugetlbfs": true,
"pstore": true,
"bpf": true,
}
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// System mount paths that should always be filtered
var systemMountPaths = map[string]bool{
"/proc": true,
"/sys": true,
"/dev": true,
"/dev/pts": true,
"/dev/shm": true,
"/dev/mqueue": true,
"/run": true,
"/run/secrets": true,
// MountInfo is one parsed mountinfo entry as consumed by BuildMountPolicy.
// ReadMountInfoFromHostProcPath fills it field-for-field from the entries
// returned by common.ParseMountInfoFile.
type MountInfo struct {
// MountID and ParentID are kept as strings, as delivered by the parser.
MountID string
ParentID string
// MountPoint is the parsed entry's Path; BuildMountPolicy normalizes it and
// uses it as the key for its classification.
MountPoint string
// Root, FSType and Source are inspected by BuildMountPolicy to recognize
// bind-like mounts and runtime-owned /run mounts.
Root string
FSType string
Source string
// Options and SuperOptions are carried through from the parser
// (SuperOptions comes from the parsed entry's SuperOpts).
Options string
SuperOptions string
}
// ParseMountInfo parses /proc/<pid>/mountinfo and returns bind mounts
// that need to be handled by CRIU as external mounts
func ParseMountInfo(pid int, hostProc string) ([]MountMapping, error) {
if hostProc == "" {
hostProc = "/proc"
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
var mounts []MountMapping
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
mount, skip := parseMountInfoLine(line)
if skip {
continue
}
mounts = append(mounts, mount)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
}
return mounts, nil
// MountPolicy is the classified mount plan for CRIU dump options.
// Both lists are produced by BuildMountPolicy by ranging over a set, so their
// order is unspecified.
type MountPolicy struct {
// Externalized holds the mount points classified for CRIU's extMnt list.
Externalized []string
// Skipped holds the mount points classified for CRIU's skipMnt list.
Skipped []string
}
// parseMountInfoLine parses a single line from mountinfo
// Returns the mount mapping and whether to skip this mount
// BuildMountPolicy classifies mounts into CRIU extMnt and skipMnt lists.
//
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
// Rule order and precedence (top to bottom):
// 1. Skip non-OCI proc/sys submounts and non-OCI runtime /run submounts.
// These mounts are typically node/kernel/runtime specific and are the
// highest-risk source of cross-node restore failures, so skip wins.
// 2. Externalize mounts owned by runtime/OCI:
// - "/" (rootfs is recreated by runtime in OCI restore path)
// - OCI mount destinations
// - OCI masked/readonly paths
// 3. Externalize non-OCI bind-like mounts (mount root is not "/" or ".").
// This captures runtime-injected file mounts (for example driver files)
// so CRIU does not try to recreate them from checkpoint data.
// 4. Anything else is left unflagged and handled by CRIU default behavior.
//
// (1) mount ID
// (2) parent ID
// (3) major:minor
// (4) root: root of the mount within the filesystem (host-side path for bind mounts)
// (5) mount point: mount point relative to process's root
// (6) mount options
// (7) optional fields (terminated by single hyphen)
// (8) separator (hyphen)
// (9) filesystem type
// (10) mount source (device)
// (11) super options
func parseMountInfoLine(line string) (MountMapping, bool) {
fields := strings.Fields(line)
if len(fields) < 10 {
return MountMapping{}, true
}
// Precedence: skip > externalize. If a path is classified as skipped, it is
// removed from the externalized set.
func BuildMountPolicy(mountInfo []MountInfo, ociSpec *specs.Spec, rootFS string) *MountPolicy {
ociManagedSet := collectOCIManagedDestinations(ociSpec, rootFS)
root := fields[3] // Host-side path within the filesystem (important for bind mounts)
mountPoint := fields[4] // Container-side mount point
mountOptions := fields[5]
externalizedSet := make(map[string]struct{}, len(mountInfo)+len(ociManagedSet))
skippedSet := make(map[string]struct{}, len(mountInfo))
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
for _, mount := range mountInfo {
mp := normalizeMountPath(mount.MountPoint)
if mp == "" {
continue
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return MountMapping{}, true
}
source := path.Clean(strings.TrimSpace(mount.Source))
root := path.Clean(strings.TrimSpace(mount.Root))
isOCIManaged := false
if _, ok := ociManagedSet[mp]; ok {
isOCIManaged = true
}
if !isOCIManaged && strings.HasPrefix(mp, "/run/") {
if _, ok := ociManagedSet["/var"+mp]; ok {
isOCIManaged = true
}
}
if !isOCIManaged && strings.HasPrefix(mp, "/var/run/") {
if _, ok := ociManagedSet[strings.TrimPrefix(mp, "/var")]; ok {
isOCIManaged = true
}
}
fsType := fields[sepIdx+1]
source := fields[sepIdx+2]
superOptions := ""
if sepIdx+3 < len(fields) {
superOptions = fields[sepIdx+3]
}
// Runtime-owned /run mounts are usually ephemeral tmpfs/overlay mounts
// or bind-like mounts sourced from host runtime directories.
// We skip these unless OCI explicitly manages that destination.
isRunRuntimeMount := strings.HasPrefix(mp, "/run/") &&
(mount.FSType == "tmpfs" ||
mount.FSType == "overlay" ||
strings.HasPrefix(source, "/run/") ||
strings.HasPrefix(source, "/var/run/") ||
strings.HasPrefix(root, "/run/") ||
strings.HasPrefix(root, "/var/run/"))
if !isOCIManaged && (strings.HasPrefix(mp, "/proc/") || strings.HasPrefix(mp, "/sys/") || isRunRuntimeMount) {
skippedSet[mp] = struct{}{}
delete(externalizedSet, mp)
continue
}
// Skip system mount types
if systemMountTypes[fsType] {
return MountMapping{}, true
if mp == "/" || isOCIManaged || (root != "." && root != "/") {
externalizedSet[mp] = struct{}{}
continue
}
}
// Skip system mount paths
if systemMountPaths[mountPoint] {
return MountMapping{}, true
// Ensure OCI-managed destinations are externalized, even when mountinfo does not
// include a direct entry (e.g., runtime-managed masked/readonly paths).
for mp := range ociManagedSet {
if _, skipped := skippedSet[mp]; skipped {
continue
}
externalizedSet[mp] = struct{}{}
}
// Skip /sys and /proc prefixed paths
if strings.HasPrefix(mountPoint, "/sys/") || strings.HasPrefix(mountPoint, "/proc/") {
return MountMapping{}, true
externalized := make([]string, 0, len(externalizedSet))
for mp := range externalizedSet {
externalized = append(externalized, mp)
}
// Skip overlay (the root filesystem itself)
if fsType == "overlay" && mountPoint == "/" {
return MountMapping{}, true
skipped := make([]string, 0, len(skippedSet))
for mp := range skippedSet {
skipped = append(skipped, mp)
}
// For bind mounts, the root field contains the actual host path
// Use root as OutsidePath since it gives us the host-side path for volume mounts
outsidePath := root
if root == "/" {
// If root is /, this isn't a bind mount from a subdirectory
outsidePath = source
return &MountPolicy{
Externalized: externalized,
Skipped: skipped,
}
return MountMapping{
InsidePath: mountPoint,
OutsidePath: outsidePath,
FSType: fsType,
Source: source,
Options: mountOptions + "," + superOptions,
}, false
}
// GetBindMounts returns only bind mounts (type "bind" or with bind option)
func GetBindMounts(pid int, hostProc string) ([]MountMapping, error) {
mounts, err := ParseMountInfo(pid, hostProc)
if err != nil {
return nil, err
// collectOCIManagedDestinations returns the canonical set of OCI-owned mount
// targets. This includes regular OCI mounts plus Linux masked/readonly paths.
// Those masked/readonly paths may not appear as direct mountinfo entries, but
// still need to be treated as runtime-owned and externalized.
func collectOCIManagedDestinations(ociSpec *specs.Spec, rootFS string) map[string]struct{} {
set := map[string]struct{}{}
if ociSpec == nil {
return set
}
var bindMounts []MountMapping
for _, m := range mounts {
// Bind mounts typically show the underlying filesystem type
// and have paths that look like kubelet volume paths
if strings.Contains(m.OutsidePath, "/var/lib/kubelet/pods/") ||
strings.Contains(m.OutsidePath, "/volumes/") ||
strings.Contains(m.Options, "bind") {
bindMounts = append(bindMounts, m)
paths := make([]string, 0, len(ociSpec.Mounts))
for _, mount := range ociSpec.Mounts {
paths = append(paths, mount.Destination)
}
if ociSpec.Linux != nil {
paths = append(paths, ociSpec.Linux.MaskedPaths...)
paths = append(paths, ociSpec.Linux.ReadonlyPaths...)
}
for _, raw := range paths {
if p := normalizeOCIDestinationPath(raw, rootFS); p != "" {
set[p] = struct{}{}
}
}
return bindMounts, nil
return set
}
// GetKubernetesVolumeMounts returns mounts that appear to be Kubernetes volumes
func GetKubernetesVolumeMounts(pid int, hostProc string) ([]MountMapping, error) {
mounts, err := ParseMountInfo(pid, hostProc)
if err != nil {
return nil, err
// normalizeMountPath applies lexical normalization only.
// Mountinfo paths are already kernel truth for the container namespace.
func normalizeMountPath(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
var k8sMounts []MountMapping
for _, m := range mounts {
// Kubernetes volumes are identified by:
// 1. Standard kubelet paths: /var/lib/kubelet/pods/
// 2. Minikube/Docker paths: /var/lib/docker/volumes/minikube/_data/lib/kubelet/pods/
// 3. Kubernetes volume markers: kubernetes.io~empty-dir, kubernetes.io~configmap, etc.
if strings.Contains(m.OutsidePath, "/kubelet/pods/") ||
strings.Contains(m.OutsidePath, "/kubernetes.io~") ||
strings.Contains(m.OutsidePath, "/containerd/io.containerd") {
k8sMounts = append(k8sMounts, m)
}
p := path.Clean(raw)
if !strings.HasPrefix(p, "/") {
p = "/" + p
}
return k8sMounts, nil
return path.Clean(p)
}
// AllMountInfo represents a mount entry from /proc/<pid>/mountinfo
// This includes ALL mounts without filtering, which CRIU captures during checkpoint.
//
// parseAllMountInfoLine fills the fields from the whitespace-separated columns
// of one mountinfo line; the IDs are kept as the raw column text. SuperOptions
// is empty when the line has no column after the mount source.
type AllMountInfo struct {
MountID string // Mount ID
ParentID string // Parent mount ID
MountPoint string // Mount point inside container (container-side path)
Root string // Root of mount within filesystem (host-side path for bind mounts)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
SuperOptions string // Super block options
}
// GetAllMountsFromMountinfo parses /proc/<pid>/mountinfo and returns ALL mounts.
// This is used for CRIU checkpoint to mark ALL mounts as external, since CRIU
// captures everything from mountinfo, not just the filtered subset.
// Without marking ALL mounts as external, CRIU restore fails with
// "No mapping for <mount_id>:(null) mountpoint" errors.
func GetAllMountsFromMountinfo(pid int, hostProc string) ([]AllMountInfo, error) {
if hostProc == "" {
hostProc = "/proc"
// normalizeOCIDestinationPath canonicalizes OCI destinations against container
// rootfs symlinks (for example /var/run -> /run) with lexical fallback.
func normalizeOCIDestinationPath(raw, rootFS string) string {
p := normalizeMountPath(raw)
if p == "" || rootFS == "" {
return p
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
hostPath := filepath.Join(rootFS, strings.TrimPrefix(p, "/"))
resolved, err := filepath.EvalSymlinks(hostPath)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
return p
}
defer file.Close()
var mounts []AllMountInfo
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
mount, err := parseAllMountInfoLine(line)
if err != nil {
continue // Skip malformed lines
}
mounts = append(mounts, mount)
rel, err := filepath.Rel(rootFS, resolved)
if err != nil {
return p
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
rel = filepath.ToSlash(rel)
if rel == "." {
return "/"
}
return mounts, nil
}
// parseAllMountInfoLine parses a single line from mountinfo without filtering.
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
func parseAllMountInfoLine(line string) (AllMountInfo, error) {
fields := strings.Fields(line)
if len(fields) < 10 {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line: %s", line)
if strings.HasPrefix(rel, "../") || rel == ".." {
return p
}
mountID := fields[0]
parentID := fields[1]
root := fields[3] // Host-side path within the filesystem
mountPoint := fields[4] // Container-side mount point
mountOptions := fields[5]
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
return normalizeMountPath("/" + rel)
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line (no separator): %s", line)
func ReadMountInfoFromHostProcPath(pid int) ([]MountInfo, error) {
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", HostProcPath, pid)
parsedMounts, err := common.ParseMountInfoFile(mountinfoPath)
if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo at %s: %w", mountinfoPath, err)
}
fsType := fields[sepIdx+1]
source := fields[sepIdx+2]
superOptions := ""
if sepIdx+3 < len(fields) {
superOptions = fields[sepIdx+3]
mounts := make([]MountInfo, 0, len(parsedMounts))
for _, parsed := range parsedMounts {
mounts = append(mounts, MountInfo{
MountID: parsed.MountID,
ParentID: parsed.ParentID,
MountPoint: parsed.Path,
Root: parsed.Root,
FSType: parsed.FSType,
Source: parsed.Source,
Options: parsed.Options,
SuperOptions: parsed.SuperOpts,
})
}
return AllMountInfo{
MountID: mountID,
ParentID: parentID,
MountPoint: mountPoint,
Root: root,
FSType: fsType,
Source: source,
Options: mountOptions,
SuperOptions: superOptions,
}, nil
return mounts, nil
}
......@@ -3,12 +3,17 @@ package checkpoint
import (
"fmt"
"os"
"strings"
"golang.org/x/sys/unix"
)
// NamespaceManifestEntry stores namespace information saved in checkpoint manifests.
// NewNamespaceManifestEntries creates one entry per introspected namespace,
// copying the type, inode and IsExternal flag from the NamespaceInfo value.
// The struct tags give the YAML key used for each field.
type NamespaceManifestEntry struct {
Type string `yaml:"type"` // net, pid, mnt, etc.
Inode uint64 `yaml:"inode"` // Namespace inode
IsExternal bool `yaml:"isExternal"` // Whether namespace is external (shared)
}
// NamespaceType represents a Linux namespace type
type NamespaceType string
......@@ -26,17 +31,29 @@ const (
type NamespaceInfo struct {
Type NamespaceType
Inode uint64
Path string
IsExternal bool // Whether NS is external (shared with pause container)
}
// GetNamespaceInode returns the inode number for a namespace
func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64, error) {
if hostProc == "" {
hostProc = "/proc"
// NewNamespaceManifestEntries constructs namespace manifest entries from introspected namespaces.
func NewNamespaceManifestEntries(namespaces map[NamespaceType]*NamespaceInfo) []NamespaceManifestEntry {
if len(namespaces) == 0 {
return nil
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
result := make([]NamespaceManifestEntry, 0, len(namespaces))
for nsType, nsInfo := range namespaces {
result = append(result, NamespaceManifestEntry{
Type: string(nsType),
Inode: nsInfo.Inode,
IsExternal: nsInfo.IsExternal,
})
}
return result
}
// GetNamespaceInode returns the inode number for a namespace
func GetNamespaceInode(pid int, nsType NamespaceType) (uint64, error) {
nsPath := fmt.Sprintf("%s/%d/ns/%s", HostProcPath, pid, nsType)
var stat unix.Stat_t
if err := unix.Stat(nsPath, &stat); err != nil {
return 0, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
......@@ -46,12 +63,8 @@ func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64,
}
// GetNamespaceInfo returns detailed namespace information
func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*NamespaceInfo, error) {
if hostProc == "" {
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
func GetNamespaceInfo(pid int, nsType NamespaceType) (*NamespaceInfo, error) {
nsPath := fmt.Sprintf("%s/%d/ns/%s", HostProcPath, pid, nsType)
// Get inode
var stat unix.Stat_t
......@@ -59,14 +72,8 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac
return nil, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
}
// Read the symlink to get the namespace identifier
link, err := os.Readlink(nsPath)
if err != nil {
return nil, fmt.Errorf("failed to readlink %s: %w", nsPath, err)
}
// Check if this is different from init's namespace (PID 1)
initNsPath := fmt.Sprintf("%s/1/ns/%s", hostProc, nsType)
initNsPath := fmt.Sprintf("%s/1/ns/%s", HostProcPath, nsType)
var initStat unix.Stat_t
isExternal := false
if err := unix.Stat(initNsPath, &initStat); err == nil {
......@@ -77,13 +84,12 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac
return &NamespaceInfo{
Type: nsType,
Inode: stat.Ino,
Path: link,
IsExternal: isExternal,
}, nil
}
// GetAllNamespaces returns information about all namespaces for a process
func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInfo, error) {
func GetAllNamespaces(pid int) (map[NamespaceType]*NamespaceInfo, error) {
nsTypes := []NamespaceType{
NamespaceNet,
NamespacePID,
......@@ -96,66 +102,10 @@ func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInf
namespaces := make(map[NamespaceType]*NamespaceInfo)
for _, nsType := range nsTypes {
info, err := GetNamespaceInfo(pid, nsType, hostProc)
if err != nil {
// Some namespaces might not exist, skip them
continue
if info, err := GetNamespaceInfo(pid, nsType); err == nil {
namespaces[nsType] = info
}
namespaces[nsType] = info
}
return namespaces, nil
}
// IsNetNamespaceExternal checks if the network namespace is external
// (i.e., shared with the pause container in Kubernetes)
func IsNetNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespaceNet, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// IsPIDNamespaceExternal checks if the PID namespace is external
func IsPIDNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespacePID, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// OpenNamespaceFD opens a file descriptor to a namespace
// The caller is responsible for closing the returned file
func OpenNamespaceFD(pid int, nsType NamespaceType, hostProc string) (*os.File, error) {
if hostProc == "" {
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
return os.Open(nsPath)
}
// FormatExternalNamespace formats namespace info for CRIU's External option
// Format: <type>[<inode>]:<key>
func FormatExternalNamespace(nsType NamespaceType, inode uint64) string {
key := formatNamespaceKey(nsType)
return fmt.Sprintf("%s[%d]:%s", nsType, inode, key)
}
// formatNamespaceKey creates the CRIU key for external namespaces
// Format: extRoot<Type>NS (e.g., extRootNetNS, extRootPidNS)
func formatNamespaceKey(nsType NamespaceType) string {
// Capitalize first letter of namespace type
nsName := string(nsType)
if len(nsName) > 0 {
nsName = strings.ToUpper(nsName[:1]) + nsName[1:]
}
return "extRoot" + nsName + "NS"
}
// GetNamespaceKey returns the CRIU key for a namespace type
func GetNamespaceKey(nsType NamespaceType) string {
return formatNamespaceKey(nsType)
}
// Package checkpoint provides CRIU checkpoint (dump) operations.
package checkpoint
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
// CaptureDevShm captures files from /dev/shm to the checkpoint directory.
// This is needed because /dev/shm is a tmpfs mount that is not part of the
// container's overlay filesystem, so rootfs diff doesn't capture it.
//
// Semaphores (sem.* files) are included so that sem_unlink() calls succeed
// after restore. The semaphore kernel state won't be perfectly restored,
// but the files will exist for cleanup operations.
//
// Only regular files are captured. Directories, symlinks, FIFOs, sockets and
// device nodes are skipped: opening them through /proc/<pid>/root is unsafe
// for a copy (a symlink is followed by os.Open, and an absolute link target
// would be resolved in this process's root rather than the container's;
// opening a FIFO can block indefinitely).
//
// The files are saved to <checkpointDir>/dev-shm/ and can be restored
// using RestoreDevShm before CRIU restore.
//
// Capture is best-effort per file: a file that cannot be inspected or copied
// is logged and skipped. An error is returned only when the container's
// /dev/shm cannot be listed (other than not existing) or the destination
// directory cannot be created.
func CaptureDevShm(pid int, checkpointDir string, log *logrus.Entry) error {
	// Access container's /dev/shm via /proc/<pid>/root
	shmPath := filepath.Join(HostProcPath, fmt.Sprintf("%d/root/dev/shm", pid))

	entries, err := os.ReadDir(shmPath)
	if err != nil {
		if os.IsNotExist(err) {
			log.Debug("Container /dev/shm does not exist, skipping capture")
			return nil
		}
		return fmt.Errorf("failed to read container /dev/shm: %w", err)
	}

	// Keep regular files only.
	var filesToCapture []os.DirEntry
	for _, entry := range entries {
		switch {
		case entry.IsDir():
			// Unlikely in /dev/shm but be safe.
			log.WithField("dir", entry.Name()).Debug("Skipping directory in /dev/shm")
		case !entry.Type().IsRegular():
			log.WithFields(logrus.Fields{
				"file": entry.Name(),
				"type": entry.Type().String(),
			}).Debug("Skipping non-regular file in /dev/shm")
		default:
			filesToCapture = append(filesToCapture, entry)
		}
	}

	if len(filesToCapture) == 0 {
		log.Debug("No files to capture from /dev/shm")
		return nil
	}

	// Create destination directory
	destDir := filepath.Join(checkpointDir, DevShmDirName)
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return fmt.Errorf("failed to create dev-shm directory: %w", err)
	}

	var captured []string
	var totalSize int64

	for _, entry := range filesToCapture {
		name := entry.Name()
		srcPath := filepath.Join(shmPath, name)
		destPath := filepath.Join(destDir, name)

		info, err := entry.Info()
		if err != nil {
			log.WithError(err).WithField("file", name).Warn("Failed to get file info, skipping")
			continue
		}
		size := info.Size()

		// Copy the file
		if err := copyFile(srcPath, destPath, info.Mode()); err != nil {
			log.WithError(err).WithField("file", name).Warn("Failed to copy file, skipping")
			continue
		}

		captured = append(captured, name)
		totalSize += size
		log.WithFields(logrus.Fields{
			"file": name,
			"size": size,
		}).Debug("Captured /dev/shm file")
	}

	if len(captured) > 0 {
		log.WithFields(logrus.Fields{
			"count":      len(captured),
			"total_size": totalSize,
			"files":      captured,
		}).Info("Captured /dev/shm files")
	}

	return nil
}
// copyFile copies a file from src to dest with the given permissions.
func copyFile(src, dest string, mode os.FileMode) error {
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source: %w", err)
}
defer srcFile.Close()
destFile, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create destination: %w", err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return fmt.Errorf("failed to copy contents: %w", err)
}
// Sync to ensure durability for checkpoint data
if err := destFile.Sync(); err != nil {
return fmt.Errorf("failed to sync destination: %w", err)
}
return nil
}
// storage.go provides checkpoint storage I/O: write/read manifests, listing, deletion.
package checkpoint
import (
"fmt"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
// WriteCheckpointManifest serializes data as YAML and stores it under
// CheckpointManifestFilename inside checkpointDir, with file mode 0600.
// It returns a wrapped error when marshalling or writing fails.
func WriteCheckpointManifest(checkpointDir string, data *CheckpointManifest) error {
	encoded, marshalErr := yaml.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal checkpoint manifest: %w", marshalErr)
	}

	target := filepath.Join(checkpointDir, CheckpointManifestFilename)
	writeErr := os.WriteFile(target, encoded, 0600)
	if writeErr != nil {
		return fmt.Errorf("failed to write checkpoint manifest: %w", writeErr)
	}
	return nil
}
// ReadCheckpointManifest loads and decodes the YAML manifest stored under
// CheckpointManifestFilename inside checkpointDir.
// It returns a wrapped error when the file cannot be read or decoded.
func ReadCheckpointManifest(checkpointDir string) (*CheckpointManifest, error) {
	raw, readErr := os.ReadFile(filepath.Join(checkpointDir, CheckpointManifestFilename))
	if readErr != nil {
		return nil, fmt.Errorf("failed to read checkpoint manifest: %w", readErr)
	}

	manifest := new(CheckpointManifest)
	if decodeErr := yaml.Unmarshal(raw, manifest); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal checkpoint manifest: %w", decodeErr)
	}
	return manifest, nil
}
// SaveDescriptors stores the descriptor list as YAML under DescriptorsFilename
// inside checkpointDir, with file mode 0600.
// It returns a wrapped error when marshalling or writing fails.
func SaveDescriptors(checkpointDir string, descriptors []string) error {
	encoded, marshalErr := yaml.Marshal(descriptors)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal descriptors: %w", marshalErr)
	}

	target := filepath.Join(checkpointDir, DescriptorsFilename)
	writeErr := os.WriteFile(target, encoded, 0600)
	if writeErr != nil {
		return fmt.Errorf("failed to write descriptors file: %w", writeErr)
	}
	return nil
}
// LoadDescriptors reads the YAML descriptor list stored under
// DescriptorsFilename inside checkpointDir.
// It returns a wrapped error when the file cannot be read or decoded.
func LoadDescriptors(checkpointDir string) ([]string, error) {
	raw, readErr := os.ReadFile(filepath.Join(checkpointDir, DescriptorsFilename))
	if readErr != nil {
		return nil, fmt.Errorf("failed to read descriptors file: %w", readErr)
	}

	var loaded []string
	decodeErr := yaml.Unmarshal(raw, &loaded)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal descriptors: %w", decodeErr)
	}
	return loaded, nil
}
// ListCheckpoints returns the names of the sub-directories of baseDir that
// contain a manifest file (CheckpointManifestFilename). Entries that are not
// directories, and directories whose manifest cannot be stat'ed, are ignored.
// It returns a wrapped error when baseDir itself cannot be listed.
func ListCheckpoints(baseDir string) ([]string, error) {
	dirEntries, readErr := os.ReadDir(baseDir)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read checkpoint directory: %w", readErr)
	}

	var ids []string
	for _, dirEntry := range dirEntries {
		if dirEntry.IsDir() {
			// A directory counts as a checkpoint only if its manifest exists.
			manifest := filepath.Join(baseDir, dirEntry.Name(), CheckpointManifestFilename)
			_, statErr := os.Stat(manifest)
			if statErr == nil {
				ids = append(ids, dirEntry.Name())
			}
		}
	}
	return ids, nil
}
// DeleteCheckpoint removes a checkpoint directory.
func DeleteCheckpoint(baseDir, checkpointID string) error {
checkpointDir := filepath.Join(baseDir, checkpointID)
// Ensure resolved path is within baseDir to prevent path traversal
absBase, _ := filepath.Abs(baseDir)
absDir, _ := filepath.Abs(checkpointDir)
if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) && absDir != absBase {
return fmt.Errorf("invalid checkpoint ID: resolved path %s is outside base directory %s", absDir, absBase)
}
return os.RemoveAll(checkpointDir)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment