Unverified Commit d381e6ff authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat(chrek): config refactor, /dev/shm support, and mount-policy rewrite (#5946)

parent b6824ae0
...@@ -90,9 +90,6 @@ class Config: ...@@ -90,9 +90,6 @@ class Config:
# Use vLLM's tokenizer for pre/post processing # Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False use_vllm_tokenizer: bool = False
# sleep mode support (enable_sleep_mode comes from vLLM's engine_args)
sleep_mode_level: int = 1
# Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args) # Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args)
use_kv_events: bool = False use_kv_events: bool = False
...@@ -301,13 +298,6 @@ def parse_args() -> Config: ...@@ -301,13 +298,6 @@ def parse_args() -> Config:
default=False, default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.", help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
) )
parser.add_argument(
"--sleep-mode-level",
type=int,
default=1,
choices=[1, 2, 3],
help="Sleep mode level (1=offload to CPU, 2=discard weights, 3=discard all). Default: 1",
)
add_config_dump_args(parser) add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -454,7 +444,6 @@ def parse_args() -> Config: ...@@ -454,7 +444,6 @@ def parse_args() -> Config:
config.enable_local_indexer = not args.durable_kv_events config.enable_local_indexer = not args.durable_kv_events
# For omni mode, use vLLM (AsyncOmni) tokenizer on backend # For omni mode, use vLLM (AsyncOmni) tokenizer on backend
config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni config.use_vllm_tokenizer = args.use_vllm_tokenizer or args.omni
config.sleep_mode_level = args.sleep_mode_level
# use_kv_events is set later in overwrite_args() based on kv_events_config # use_kv_events is set later in overwrite_args() based on kv_events_config
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Checkpoint/restore (chrek) integration for vLLM workers.
Handles the checkpoint job pod lifecycle:
1. Early exit if a checkpoint already exists (idempotency)
2. Sleep model for CRIU-friendly GPU state
3. Signal readiness for DaemonSet to begin checkpoint
4. Poll for checkpoint completion or CRIU restore detection
5. Wake model after restore
Environment variables (all required in checkpoint mode, no fallbacks):
- DYN_CHECKPOINT_SIGNAL_FILE: Path where DaemonSet writes completion signal
- DYN_READY_FOR_CHECKPOINT_FILE: Path where this worker writes readiness marker
- DYN_CHECKPOINT_STORAGE_TYPE: Storage backend (pvc, s3, oci)
- DYN_CHECKPOINT_LOCATION: Full checkpoint path (for idempotency check)
- DYN_RESTORE_MARKER_FILE: Path written by restore-entrypoint before CRIU restore
"""
import asyncio
import json
import logging
import os
from typing import Optional
logger = logging.getLogger(__name__)

# Environment variables that must all be present when checkpoint mode is
# active. get_checkpoint_config() reports any that are missing, and
# CheckpointConfig.__init__ reads each of them with os.environ[...].
_REQUIRED_ENV_VARS = [
    "DYN_CHECKPOINT_SIGNAL_FILE",
    "DYN_READY_FOR_CHECKPOINT_FILE",
    "DYN_CHECKPOINT_STORAGE_TYPE",
    "DYN_CHECKPOINT_LOCATION",
    "DYN_RESTORE_MARKER_FILE",
]
class CheckpointConfig:
    """Checkpoint configuration read from environment variables.

    Every variable is read with ``os.environ[...]``, so constructing this
    class raises ``KeyError`` if any of them is unset.
    """

    def __init__(self):
        self.signal_file = os.environ["DYN_CHECKPOINT_SIGNAL_FILE"]
        self.ready_file = os.environ["DYN_READY_FOR_CHECKPOINT_FILE"]
        self.storage_type = os.environ["DYN_CHECKPOINT_STORAGE_TYPE"]
        self.location = os.environ["DYN_CHECKPOINT_LOCATION"]
        self.restore_marker = os.environ["DYN_RESTORE_MARKER_FILE"]

    def _read_status_file(self, path: str) -> dict:
        """Load a JSON status file and validate its ``success`` field.

        Args:
            path: Path of the JSON file to read.

        Returns:
            The parsed JSON object; ``status["success"]`` is a ``bool``.

        Raises:
            OSError: If the file cannot be opened or read.
            ValueError: If the content is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError`` subclass), is
                not a JSON object, or has a missing/non-boolean ``success``.
        """
        with open(path, encoding="utf-8") as f:
            status = json.load(f)
        # Valid JSON is not necessarily an object: "[]", "1" or '"ok"' parse
        # fine but have no .get(). Report them as ValueError so callers'
        # existing handlers treat them like any other malformed status file.
        if not isinstance(status, dict):
            raise ValueError(
                f"expected a JSON object in {path}, got {type(status).__name__}"
            )
        if not isinstance(status.get("success"), bool):
            raise ValueError(f"missing or invalid success field in {path}")
        return status

    def checkpoint_exists(self) -> bool:
        """Check if a completed checkpoint already exists (idempotency).

        For PVC storage, looks for a ``checkpoint.done`` marker under the
        configured location and reads its ``success`` field.

        Returns:
            True only when the marker exists, parses, and reports success
            (the job should exit without loading the model). False when the
            marker is absent, unreadable, malformed, or reports failure.

        Raises:
            AssertionError: If the storage type is not ``"pvc"`` (only while
                assertions are enabled).
        """
        # NOTE(review): assert is stripped under `python -O`; the guard below
        # keeps non-PVC storage on the "no checkpoint" path in that case.
        # Consider an explicit raise if -O deployments must fail loudly.
        assert (
            self.storage_type == "pvc"
        ), "Checkpoint existence check is only implemented for PVC storage"

        if self.storage_type != "pvc" or not self.location:
            logger.info(f"No checkpoint at {self.location}, creating new one")
            return False

        done_marker = f"{self.location}/checkpoint.done"
        if not os.path.exists(done_marker):
            logger.info(f"No checkpoint at {self.location}, creating new one")
            return False

        try:
            status = self._read_status_file(done_marker)
        except (OSError, ValueError) as exc:
            # An unreadable or malformed marker is treated as a stale
            # checkpoint, so the job goes on to create a fresh one.
            logger.warning(
                f"Invalid checkpoint.done marker at {done_marker}, ignoring stale checkpoint: {exc}"
            )
            return False

        if status["success"]:
            logger.info(
                f"Existing successful checkpoint found at {self.location}, skipping"
            )
            return True

        logger.warning(
            f"Existing checkpoint marker reports failure at {self.location}: "
            f"{status.get('error', 'unknown error')}"
        )
        return False

    async def run_lifecycle(self, engine_client, sleep_level: int) -> bool:
        """Run the full checkpoint lifecycle after the engine is loaded.

        1. Put model to sleep (CRIU-friendly GPU state)
        2. Write ready file (triggers DaemonSet checkpoint via readiness probe)
        3. Poll for signal file (checkpoint done) or restore marker (CRIU restored us)
        4. If restored: wake model and return True (caller proceeds with registration)
        5. If checkpoint done: return False (caller should exit)

        Args:
            engine_client: Object exposing awaitable ``sleep(level=...)`` and
                ``wake_up()`` methods.
            sleep_level: Level passed through to ``engine_client.sleep``.

        Returns:
            True if the restore marker was seen (the model has been woken up),
            False if the signal file reports a successful checkpoint.

        Raises:
            RuntimeError: If the signal file is unreadable, malformed, or
                reports a failed checkpoint.
        """
        # Sleep model for checkpoint
        logger.info(f"Putting model to sleep (level={sleep_level})")
        await engine_client.sleep(level=sleep_level)

        # Signal readiness
        with open(self.ready_file, "w", encoding="utf-8") as f:
            f.write("ready")
        logger.info(
            f"Ready for checkpoint. Waiting for signal: {self.signal_file} "
            f"or restore marker: {self.restore_marker}"
        )

        # Poll once per second. The restore marker is checked first, so a
        # restored process takes the restore path even if a signal file is
        # also present.
        while True:
            if os.path.exists(self.restore_marker):
                logger.info(f"Restore detected (marker: {self.restore_marker})")
                logger.info("Waking up model after restore")
                await engine_client.wake_up()
                return True

            if os.path.exists(self.signal_file):
                try:
                    signal = self._read_status_file(self.signal_file)
                except (OSError, ValueError) as exc:
                    raise RuntimeError(
                        f"Invalid checkpoint signal file {self.signal_file}: {exc}"
                    ) from exc
                if signal["success"]:
                    logger.info(f"Checkpoint complete (signal: {self.signal_file})")
                    return False
                raise RuntimeError(
                    f"Checkpoint failed (signal: {self.signal_file}): "
                    f"{signal.get('error', 'unknown error')}"
                )

            await asyncio.sleep(1)
def get_checkpoint_config() -> Optional[CheckpointConfig]:
    """Build the checkpoint configuration when checkpoint mode is active.

    Checkpoint mode is switched on by the presence of
    DYN_CHECKPOINT_SIGNAL_FILE in the environment. Once it is on, every name
    in _REQUIRED_ENV_VARS has to be set as well.

    Returns:
        A CheckpointConfig in checkpoint mode, otherwise None.

    Raises:
        EnvironmentError: If checkpoint mode is on and required variables
            are missing; the message lists the missing names.
    """
    env = os.environ
    if "DYN_CHECKPOINT_SIGNAL_FILE" not in env:
        return None

    missing = [name for name in _REQUIRED_ENV_VARS if name not in env]
    if not missing:
        return CheckpointConfig()

    raise EnvironmentError(
        f"Checkpoint mode requires these environment variables: {', '.join(missing)}"
    )
...@@ -56,6 +56,7 @@ from dynamo.vllm.multimodal_handlers import ( ...@@ -56,6 +56,7 @@ from dynamo.vllm.multimodal_handlers import (
from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config from dynamo.vllm.multimodal_utils.encode_utils import create_ec_transfer_config
from .args import Config, overwrite_args, parse_args from .args import Config, overwrite_args, parse_args
from .chrek import get_checkpoint_config
from .handlers import DecodeWorkerHandler, PrefillWorkerHandler from .handlers import DecodeWorkerHandler, PrefillWorkerHandler
from .health_check import ( from .health_check import (
VllmHealthCheckPayload, VllmHealthCheckPayload,
...@@ -66,6 +67,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory ...@@ -66,6 +67,7 @@ from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
CHECKPOINT_SLEEP_MODE_LEVEL = 1
async def _handle_non_leader_node(dp_rank: int) -> None: async def _handle_non_leader_node(dp_rank: int) -> None:
...@@ -81,46 +83,17 @@ async def _handle_non_leader_node(dp_rank: int) -> None: ...@@ -81,46 +83,17 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait() await asyncio.Event().wait()
async def await_checkpoint_and_was_restored(signal_file: str) -> bool: async def graceful_shutdown(runtime, shutdown_event):
""" """
Wait for checkpoint signal file OR restore marker file. Shutdown dynamo distributed runtime.
The endpoints will be immediately invalidated so no new requests will be accepted.
In checkpoint creation mode, poll until either: For endpoints served with graceful_shutdown=True, the serving function will wait until all in-flight requests are finished.
1. The signal file exists (checkpoint complete, should exit) For endpoints served with graceful_shutdown=False, the serving function will return immediately.
2. The restore marker file exists (restored by CRIU, should proceed)
The restore marker file is created by the restore-entrypoint before CRIU restore,
so the restored process can detect it was restored even though os.environ is
restored from the checkpoint and doesn't contain new container env vars.
Args:
signal_file: Path to the checkpoint signal file
Returns:
True if restored (should proceed with registration)
False if signal file detected (should exit)
""" """
# Get restore marker file path (created by restore entrypoint before CRIU restore) logging.info("Received shutdown signal, shutting down DistributedRuntime")
restore_marker = os.environ.get("DYN_RESTORE_MARKER_FILE", "/tmp/dynamo-restored") shutdown_event.set()
runtime.shutdown()
logger.info( logging.info("DistributedRuntime shutdown complete")
f"CHECKPOINT_READY: Model loaded, ready for container checkpoint. Waiting for signal file: {signal_file} or restore marker file: {restore_marker}"
)
while True:
# Check if we've been restored (marker file created by restore entrypoint)
if os.path.exists(restore_marker):
logger.info(
f"Detected restore from checkpoint (marker file exists: {restore_marker})"
)
return True # Restored - proceed with registration
# Check if checkpoint is complete (signal file exists)
if os.path.exists(signal_file):
logger.info(f"Checkpoint signal file detected: {signal_file}")
return False # Checkpoint done - exit
await asyncio.sleep(1)
async def worker(): async def worker():
...@@ -134,29 +107,10 @@ async def worker(): ...@@ -134,29 +107,10 @@ async def worker():
if not config.served_model_name: if not config.served_model_name:
config.served_model_name = config.engine_args.served_model_name = config.model config.served_model_name = config.engine_args.served_model_name = config.model
# Check checkpoint-related environment variables EARLY # Check checkpoint mode and validate env vars EARLY (fail fast if misconfigured)
signal_file = os.environ.get("DYN_CHECKPOINT_SIGNAL_FILE") checkpoint_cfg = get_checkpoint_config()
ready_file = os.environ.get("DYN_CHECKPOINT_READY_FILE") if checkpoint_cfg and checkpoint_cfg.checkpoint_exists():
is_checkpoint_mode = signal_file is not None
# EARLY EXIT: Check if checkpoint already exists (before downloading model!)
if is_checkpoint_mode:
storage_type = os.environ.get("DYN_CHECKPOINT_STORAGE_TYPE")
checkpoint_location = os.environ.get("DYN_CHECKPOINT_LOCATION")
if storage_type == "pvc" and checkpoint_location:
done_marker = f"{checkpoint_location}/checkpoint.done"
if os.path.exists(done_marker):
logger.info(
f"Found existing checkpoint at {checkpoint_location}. Storage type: {storage_type}"
)
return return
else:
logger.info(
f"Checkpoint not found at: {checkpoint_location}. creating new checkpoint"
)
# Download the model if necessary using modelexpress. # Download the model if necessary using modelexpress.
# We want it on disk before we start vllm to avoid downloading from HuggingFace. # We want it on disk before we start vllm to avoid downloading from HuggingFace.
...@@ -173,39 +127,20 @@ async def worker(): ...@@ -173,39 +127,20 @@ async def worker():
# CHECKPOINT MODE: Load engine BEFORE runtime creation # CHECKPOINT MODE: Load engine BEFORE runtime creation
# This allows checkpointing GPU state before runtime connections are established # This allows checkpointing GPU state before runtime connections are established
pre_created_engine = None pre_created_engine = None
is_restored = False if checkpoint_cfg is not None:
if is_checkpoint_mode:
logger.info( logger.info(
f"Checkpoint mode enabled (DYN_CHECKPOINT_SIGNAL_FILE={signal_file})" f"Checkpoint mode enabled (signal_file={checkpoint_cfg.signal_file})"
) )
# CHECKPOINT MODE: Load model, sleep, wait for signal file or restore # Checkpoint mode requires sleep mode — enable before engine init
config.engine_args.enable_sleep_mode = True
pre_created_engine = setup_vllm_engine(config) pre_created_engine = setup_vllm_engine(config)
engine_client = pre_created_engine[0] engine_client = pre_created_engine[0]
# Put model to sleep before checkpoint (if sleep mode enabled) if not await checkpoint_cfg.run_lifecycle(
if config.engine_args.enable_sleep_mode: engine_client, CHECKPOINT_SLEEP_MODE_LEVEL
logger.info(f"Putting model to sleep (level={config.sleep_mode_level})") ):
await engine_client.sleep(level=config.sleep_mode_level)
# Write ready file to signal that we're ready for checkpointing
if ready_file:
with open(ready_file, "w") as f:
f.write("ready")
logger.info(f"Wrote checkpoint ready file: {ready_file}")
# Wait for checkpoint signal file OR restore detection
is_restored = await await_checkpoint_and_was_restored(signal_file)
if is_restored:
# Wake up model and proceed with registration
if config.engine_args.enable_sleep_mode:
logger.info("Waking up model after checkpoint restore")
await engine_client.wake_up()
logger.info("Proceeding with endpoint registration after restore")
else:
# Checkpoint complete, exit
logger.info("Exiting after checkpoint completion")
return return
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
......
...@@ -142,14 +142,6 @@ COPY --from=builder /restore-entrypoint /restore-entrypoint ...@@ -142,14 +142,6 @@ COPY --from=builder /restore-entrypoint /restore-entrypoint
# Create checkpoint directory # Create checkpoint directory
RUN mkdir -p /checkpoints RUN mkdir -p /checkpoints
# Set environment variables
ENV HOST_PROC=/host/proc \
CONTAINERD_SOCKET=/run/containerd/containerd.sock \
CHECKPOINT_DIR=/checkpoints \
LISTEN_ADDR=:8080
EXPOSE 8080
USER root USER root
ENTRYPOINT ["/usr/local/bin/chrek-agent"] ENTRYPOINT ["/usr/local/bin/chrek-agent"]
...@@ -172,7 +164,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ ...@@ -172,7 +164,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
libnl-3-200 \ libnl-3-200 \
libnl-route-3-200 \ libnl-route-3-200 \
libprotobuf-c1 \ libprotobuf-c1 \
libgnutls30 \ libgnutls30t64 \
libnftables1 \ libnftables1 \
iproute2 \ iproute2 \
iptables \ iptables \
...@@ -191,23 +183,11 @@ COPY --from=criu-builder /tmp/cuda-checkpoint/bin/x86_64_Linux/cuda-checkpoint / ...@@ -191,23 +183,11 @@ COPY --from=criu-builder /tmp/cuda-checkpoint/bin/x86_64_Linux/cuda-checkpoint /
RUN chmod +x /usr/local/sbin/cuda-checkpoint RUN chmod +x /usr/local/sbin/cuda-checkpoint
# Create directories # Create directories
RUN mkdir -p /checkpoints /var/run/criu /tmp /var/criu-work RUN mkdir -p /checkpoints /var/run/criu /var/criu-work
# Copy restore binaries # Copy restore binaries
COPY --from=builder /restore-entrypoint /restore-entrypoint COPY --from=builder /restore-entrypoint /restore-entrypoint
RUN chmod +x /restore-entrypoint RUN chmod +x /restore-entrypoint
COPY scripts/smart-entrypoint.sh /smart-entrypoint.sh ENTRYPOINT ["/restore-entrypoint"]
RUN chmod +x /smart-entrypoint.sh
# Set environment variables
ENV DYN_CHECKPOINT_PATH=/checkpoints \
RESTORE_TRIGGER=/tmp/restore-trigger \
RESTORE_WAIT_TIMEOUT=300 \
CRIU_LOG_LEVEL=4 \
WAIT_FOR_CHECKPOINT=0 \
CUDA_PLUGIN_DIR=/usr/local/lib/criu \
DEBUG=0
ENTRYPOINT ["/smart-entrypoint.sh"]
CMD [] CMD []
// config.go provides configuration loading for the checkpoint agent.
package main
import (
	"errors"
	"fmt"
	"os"

	"gopkg.in/yaml.v3"

	"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
)
// ConfigMapPath is the default path where the ConfigMap is mounted.
// It names the YAML configuration file itself, not a directory.
const ConfigMapPath = "/etc/chrek/config.yaml"
// CheckpointSignalSource determines how checkpoint operations are triggered.
// It is a string type; the two values declared below are the only ones
// accepted by AgentConfig.Validate.
type CheckpointSignalSource string

const (
	// SignalFromHTTP triggers checkpoints via HTTP API requests.
	SignalFromHTTP CheckpointSignalSource = "http"
	// SignalFromWatcher triggers checkpoints automatically when pods become Ready.
	SignalFromWatcher CheckpointSignalSource = "watcher"
)
// FullConfig is the root configuration structure loaded from the ConfigMap.
// The YAML document has two top-level keys: "agent" and "checkpoint".
type FullConfig struct {
	// Agent holds the settings of the agent daemon itself.
	Agent AgentConfig `yaml:"agent"`
	// Checkpoint is the checkpoint specification, defined in pkg/checkpoint.
	Checkpoint checkpoint.CheckpointSpec `yaml:"checkpoint"`
}
// AgentConfig holds the runtime configuration for the checkpoint agent daemon.
// SignalSource and ListenAddr come from YAML. NodeName and RestrictedNamespace
// are tagged yaml:"-", so they are never read from the file and are only
// populated from the environment by loadEnvOverrides.
type AgentConfig struct {
	// SignalSource determines how checkpoints are triggered: "http" or "watcher"
	SignalSource string `yaml:"signalSource"`
	// ListenAddr is the HTTP server address for health checks and API
	ListenAddr string `yaml:"listenAddr"`
	// NodeName is the Kubernetes node name (from NODE_NAME env, downward API)
	NodeName string `yaml:"-"`
	// RestrictedNamespace restricts pod watching to this namespace (optional;
	// from RESTRICTED_NAMESPACE env)
	RestrictedNamespace string `yaml:"-"`
}
// ConfigError describes a single configuration validation failure.
type ConfigError struct {
	// Field is the name of the offending configuration key.
	Field string
	// Message explains what is wrong with the field.
	Message string
}

// Error implements the error interface, rendering the failure as
// "config error: <field>: <message>".
func (e *ConfigError) Error() string {
	return "config error: " + e.Field + ": " + e.Message
}
// LoadConfig reads the YAML file at path, decodes it into a FullConfig and
// then applies the environment overrides to the agent section.
// Read and parse failures are returned wrapped (%w) with the file path.
func LoadConfig(path string) (*FullConfig, error) {
	raw, readErr := os.ReadFile(path)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read config file %s: %w", path, readErr)
	}

	var loaded FullConfig
	if parseErr := yaml.Unmarshal(raw, &loaded); parseErr != nil {
		return nil, fmt.Errorf("failed to parse config file %s: %w", path, parseErr)
	}

	// Environment values are applied after the file is parsed.
	loaded.Agent.loadEnvOverrides()
	return &loaded, nil
}
// LoadConfigOrDefault loads configuration from a file, falling back to zero
// values (plus environment overrides) if the file doesn't exist.
//
// Any other failure (unreadable file, invalid YAML) is returned unchanged.
func LoadConfigOrDefault(path string) (*FullConfig, error) {
	cfg, err := LoadConfig(path)
	if err == nil {
		return cfg, nil
	}

	// LoadConfig wraps the read error with %w. os.IsNotExist does not look
	// through such wrapping, so errors.Is is required to detect a missing file.
	if !errors.Is(err, os.ErrNotExist) {
		return nil, err
	}

	defaults := &FullConfig{}
	defaults.Agent.loadEnvOverrides()
	return defaults, nil
}
// loadEnvOverrides copies NODE_NAME and RESTRICTED_NAMESPACE from the
// environment into the AgentConfig. An unset or empty variable leaves the
// corresponding field untouched.
func (c *AgentConfig) loadEnvOverrides() {
	nodeName := os.Getenv("NODE_NAME")
	if nodeName != "" {
		c.NodeName = nodeName
	}

	namespace := os.Getenv("RESTRICTED_NAMESPACE")
	if namespace != "" {
		c.RestrictedNamespace = namespace
	}
}
// GetSignalSource returns the signal source as a CheckpointSignalSource type.
// It is a plain type conversion of the SignalSource string: the value is not
// normalized or validated here (see Validate for the allowed values).
func (c *AgentConfig) GetSignalSource() CheckpointSignalSource {
	return CheckpointSignalSource(c.SignalSource)
}
// Validate reports whether the AgentConfig is usable.
// It returns a *ConfigError when SignalSource is neither "http" nor
// "watcher", or when HTTP mode is selected without a ListenAddr; otherwise nil.
func (c *AgentConfig) Validate() error {
	switch c.SignalSource {
	case string(SignalFromHTTP):
		// HTTP mode needs an address to listen on.
		if c.ListenAddr == "" {
			return &ConfigError{
				Field:   "listenAddr",
				Message: "cannot be empty when signalSource is 'http'",
			}
		}
	case string(SignalFromWatcher):
		// Watcher mode has no additional requirements here.
	default:
		return &ConfigError{
			Field:   "signalSource",
			Message: "must be 'http' or 'watcher'",
		}
	}
	return nil
}
// Validate checks the agent section first and the checkpoint section second,
// returning the first error encountered, or nil when both pass.
func (c *FullConfig) Validate() error {
	agentErr := c.Agent.Validate()
	if agentErr != nil {
		return agentErr
	}

	checkpointErr := c.Checkpoint.Validate()
	if checkpointErr != nil {
		return checkpointErr
	}

	return nil
}
...@@ -6,174 +6,39 @@ package main ...@@ -6,174 +6,39 @@ package main
import ( import (
"context" "context"
"encoding/json"
"fmt"
"log" "log"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
"strconv"
"strings"
"syscall" "syscall"
"time" "time"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint" "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s" httpApiServer "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/http_api_server"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/watcher" "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/watcher"
) )
// CheckpointSignalSource determines how checkpoint operations are triggered
type CheckpointSignalSource string
const (
// SignalFromHTTP triggers checkpoints via HTTP API requests
SignalFromHTTP CheckpointSignalSource = "http"
// SignalFromWatcher triggers checkpoints automatically when pods become Ready
SignalFromWatcher CheckpointSignalSource = "watcher"
)
// Config holds the agent configuration
type Config struct {
// Common settings
ContainerdSocket string
CheckpointDir string
HostProc string
NodeName string
RestrictedNamespace string // Optional: restrict pod watching to this namespace
// Mode selection
SignalSource CheckpointSignalSource // "http" or "watcher"
// HTTP API mode settings (used when SignalSource = "http")
ListenAddr string
// CRIU settings (configurable options only; LeaveRunning, ShellJob, etc. are hardcoded in pkg/checkpoint/criu.go)
CUDAPluginDir string // Path to CRIU CUDA plugin directory
GhostLimit uint32 // CRIU ghost limit in bytes
Timeout uint32 // CRIU timeout in seconds
ExternalMounts []string // External mount mappings
}
// Server is the HTTP API server
type Server struct {
config Config
discoveryClient *checkpointk8s.DiscoveryClient
checkpointer *checkpoint.Checkpointer
}
// CheckpointRequest is the request body for checkpoint operations
type CheckpointRequest struct {
ContainerID string `json:"container_id"`
CheckpointID string `json:"checkpoint_id"`
PodName string `json:"pod_name,omitempty"`
PodNamespace string `json:"pod_namespace,omitempty"`
DisableCUDA bool `json:"disable_cuda,omitempty"` // Disable CUDA plugin for non-GPU workloads
}
// TriggerRestoreRequest is the request body for Option A self-restoring trigger
type TriggerRestoreRequest struct {
CheckpointID string `json:"checkpoint_id"`
PlaceholderContainerID string `json:"placeholder_container_id"`
SkipImageValidation bool `json:"skip_image_validation,omitempty"` // Skip image matching check
}
// TriggerRestoreResponse is the response for trigger restore operations
type TriggerRestoreResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
TriggerPath string `json:"trigger_path,omitempty"`
CheckpointImage string `json:"checkpoint_image,omitempty"`
PlaceholderImage string `json:"placeholder_image,omitempty"`
}
// CheckpointResponse is the response for checkpoint operations
type CheckpointResponse struct {
Success bool `json:"success"`
CheckpointID string `json:"checkpoint_id,omitempty"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
// CheckpointInfo represents information about a checkpoint
type CheckpointInfo struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
SourceNode string `json:"source_node"`
ContainerID string `json:"container_id"`
PodName string `json:"pod_name"`
PodNamespace string `json:"pod_namespace"`
Image string `json:"image"`
}
// ListCheckpointsResponse is the response for list checkpoints
type ListCheckpointsResponse struct {
Checkpoints []CheckpointInfo `json:"checkpoints"`
}
// HealthResponse is the response for health check
type HealthResponse struct {
Status string `json:"status"`
NodeName string `json:"node_name"`
}
func main() { func main() {
// Parse signal source - default to HTTP for backward compatibility // Load configuration from ConfigMap (or use defaults if not found)
signalSource := CheckpointSignalSource(strings.ToLower(getEnv("CHECKPOINT_SIGNAL_FROM", "http"))) cfg, err := LoadConfigOrDefault(ConfigMapPath)
if signalSource != SignalFromHTTP && signalSource != SignalFromWatcher { if err != nil {
log.Fatalf("Invalid CHECKPOINT_SIGNAL_FROM value: %q (must be 'http' or 'watcher')", signalSource) log.Fatalf("Failed to load configuration: %v", err)
}
// Parse CRIU settings
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
}
// Parse external mounts (comma-separated)
var externalMounts []string
if v := os.Getenv("EXTERNAL_MOUNTS"); v != "" {
externalMounts = strings.Split(v, ",")
} }
config := Config{ // Validate configuration
// Common settings if err := cfg.Agent.Validate(); err != nil {
ContainerdSocket: getEnv("CONTAINERD_SOCKET", "/run/containerd/containerd.sock"), log.Fatalf("Invalid configuration: %v", err)
CheckpointDir: getEnv("CHECKPOINT_DIR", "/checkpoints"),
HostProc: getEnv("HOST_PROC", "/host/proc"),
NodeName: getEnv("NODE_NAME", "unknown"),
RestrictedNamespace: os.Getenv("RESTRICTED_NAMESPACE"), // Optional: empty = cluster-wide watching
// Mode selection
SignalSource: signalSource,
// HTTP API settings
ListenAddr: getEnv("LISTEN_ADDR", ":8080"),
// CRIU settings
CUDAPluginDir: getEnv("CUDA_PLUGIN_DIR", ""),
GhostLimit: ghostLimit,
Timeout: timeout,
ExternalMounts: externalMounts,
} }
// Create discovery client // Create discovery client
discoveryClient, err := checkpointk8s.NewDiscoveryClient(config.ContainerdSocket) discoveryClient, err := checkpoint.NewDiscoveryClient()
if err != nil { if err != nil {
log.Fatalf("Failed to create discovery client: %v", err) log.Fatalf("Failed to create discovery client: %v", err)
} }
defer discoveryClient.Close() defer discoveryClient.Close()
// Create checkpointer // Create checkpointer
checkpointer := checkpoint.NewCheckpointer(discoveryClient, config.HostProc) checkpointer := checkpoint.NewCheckpointer(discoveryClient)
// Context for graceful shutdown // Context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
...@@ -183,60 +48,39 @@ func main() { ...@@ -183,60 +48,39 @@ func main() {
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Printf("CRIU Node Agent starting (node: %s)", config.NodeName) log.Printf("CRIU Node Agent starting (node: %s)", cfg.Agent.NodeName)
log.Printf("Checkpoint directory: %s", config.CheckpointDir) log.Printf("Checkpoint directory: %s", cfg.Checkpoint.BasePath)
log.Printf("Signal source: %s", config.SignalSource) log.Printf("Signal source: %s", cfg.Agent.SignalSource)
switch config.SignalSource { switch cfg.Agent.GetSignalSource() {
case SignalFromHTTP: case SignalFromHTTP:
server := &Server{ serverCfg := httpApiServer.ServerConfig{
config: config, ListenAddr: cfg.Agent.ListenAddr,
discoveryClient: discoveryClient, NodeName: cfg.Agent.NodeName,
checkpointer: checkpointer, CheckpointSpec: &cfg.Checkpoint,
}
// Setup routes
mux := http.NewServeMux()
mux.HandleFunc("/health", server.handleHealth)
mux.HandleFunc("/checkpoint", server.handleCheckpoint)
mux.HandleFunc("/restore/trigger", server.handleTriggerRestore)
mux.HandleFunc("/checkpoints", server.handleListCheckpoints)
httpServer := &http.Server{
Addr: config.ListenAddr,
Handler: loggingMiddleware(mux),
ReadTimeout: 30 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
} }
srv := httpApiServer.NewServer(serverCfg, checkpointer)
// Handle graceful shutdown // Handle graceful shutdown
go func() { go func() {
<-sigChan <-sigChan
log.Println("Shutting down HTTP server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second) shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel() defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil { if err := srv.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err) log.Printf("HTTP server shutdown error: %v", err)
} }
}() }()
log.Printf("HTTP API server listening on %s", config.ListenAddr) if err := srv.Start(); err != http.ErrServerClosed {
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("HTTP server error: %v", err) log.Fatalf("HTTP server error: %v", err)
} }
case SignalFromWatcher: case SignalFromWatcher:
watcherConfig := watcher.Config{ watcherConfig := watcher.WatcherConfig{
NodeName: config.NodeName, NodeName: cfg.Agent.NodeName,
CheckpointDir: config.CheckpointDir, ListenAddr: cfg.Agent.ListenAddr,
HostProc: config.HostProc, RestrictedNamespace: cfg.Agent.RestrictedNamespace,
ListenAddr: config.ListenAddr, // For health check endpoint CheckpointSpec: &cfg.Checkpoint,
RestrictedNamespace: config.RestrictedNamespace,
CUDAPluginDir: config.CUDAPluginDir,
GhostLimit: config.GhostLimit,
Timeout: config.Timeout,
ExternalMounts: config.ExternalMounts,
} }
podWatcher, err := watcher.NewWatcher(watcherConfig, discoveryClient, checkpointer) podWatcher, err := watcher.NewWatcher(watcherConfig, discoveryClient, checkpointer)
...@@ -251,304 +95,15 @@ func main() { ...@@ -251,304 +95,15 @@ func main() {
cancel() cancel()
}() }()
log.Printf("Pod watcher started (watching for label: nvidia.com/checkpoint-source=true)") log.Printf("Pod watcher started (watching for label: %s=true)", checkpoint.KubeLabelCheckpointSource)
log.Printf("Health check endpoint: http://0.0.0.0%s/health", config.ListenAddr) log.Printf("Health check endpoint: http://0.0.0.0%s/health", cfg.Agent.ListenAddr)
if err := podWatcher.Start(ctx); err != nil { if err := podWatcher.Start(ctx); err != nil {
log.Printf("Pod watcher error: %v", err) log.Printf("Pod watcher error: %v", err)
} }
}
log.Println("Agent stopped")
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
resp := HealthResponse{
Status: "healthy",
NodeName: s.config.NodeName,
}
writeJSON(w, http.StatusOK, resp) default:
} log.Fatalf("Unknown signal source: %s", cfg.Agent.SignalSource)
func (s *Server) handleCheckpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req CheckpointRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: fmt.Sprintf("Invalid request body: %v", err),
})
return
}
if req.ContainerID == "" {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: "container_id is required",
})
return
}
if req.CheckpointID == "" {
req.CheckpointID = fmt.Sprintf("ckpt-%d", time.Now().UnixNano())
}
// Determine CUDA plugin directory - only use if not explicitly disabled
cudaPluginDir := s.config.CUDAPluginDir
if req.DisableCUDA {
cudaPluginDir = ""
}
// Parse optional CRIU settings from environment
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
} }
opts := checkpoint.Options{ log.Println("Agent stopped")
ContainerID: req.ContainerID,
CheckpointID: req.CheckpointID,
CheckpointDir: s.config.CheckpointDir,
NodeName: s.config.NodeName,
PodName: req.PodName,
PodNamespace: req.PodNamespace,
GhostLimit: ghostLimit,
Timeout: timeout,
CUDAPluginDir: cudaPluginDir,
}
ctx := r.Context()
result, err := s.checkpointer.Checkpoint(ctx, opts)
if err != nil {
log.Printf("Checkpoint failed: %v", err)
writeJSON(w, http.StatusInternalServerError, CheckpointResponse{
Success: false,
Error: err.Error(),
})
return
}
log.Printf("Checkpoint successful: %s", result.CheckpointID)
writeJSON(w, http.StatusOK, CheckpointResponse{
Success: true,
CheckpointID: result.CheckpointID,
Message: fmt.Sprintf("Checkpoint created successfully at %s", result.CheckpointDir),
})
}
// handleTriggerRestore implements Option A from RESTORE_ANALYSIS.md: it asks a
// self-restoring placeholder container to begin a CRIU restore.
//
// The handler accepts POST only. It validates the request, loads the metadata
// of the requested checkpoint, resolves the placeholder container, optionally
// verifies that the placeholder image is compatible with the checkpointed
// image, and finally writes the checkpoint path into a trigger file inside the
// placeholder's filesystem. The placeholder's entrypoint script is expected to
// detect that file and start the restore.
//
// Responses are JSON-encoded TriggerRestoreResponse values:
//   - 400: malformed body, missing checkpoint_id / placeholder_container_id,
//     or image mismatch
//   - 404: checkpoint metadata could not be loaded
//   - 500: placeholder could not be resolved or the trigger file write failed
//   - 200: trigger file written
func (s *Server) handleTriggerRestore(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// reject sends a failure response that carries nothing but an error message.
	reject := func(status int, message string) {
		writeJSON(w, status, TriggerRestoreResponse{
			Success: false,
			Error:   message,
		})
	}

	var req TriggerRestoreRequest
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		reject(http.StatusBadRequest, fmt.Sprintf("Invalid request body: %v", decodeErr))
		return
	}

	// Required fields, checked in the same order as they are reported.
	switch {
	case req.CheckpointID == "":
		reject(http.StatusBadRequest, "checkpoint_id is required")
		return
	case req.PlaceholderContainerID == "":
		reject(http.StatusBadRequest, "placeholder_container_id is required")
		return
	}

	// The checkpoint must exist; its metadata supplies the original image name.
	checkpointPath := common.GetCheckpointDir(s.config.CheckpointDir, req.CheckpointID)
	checkpointMeta, metaErr := common.LoadMetadata(checkpointPath)
	if metaErr != nil {
		reject(http.StatusNotFound, fmt.Sprintf("Checkpoint not found: %v", metaErr))
		return
	}

	// Resolving the placeholder yields the PID and image used below.
	ctx := r.Context()
	containerInfo, resolveErr := s.discoveryClient.ResolveContainer(ctx, req.PlaceholderContainerID)
	if resolveErr != nil {
		reject(http.StatusInternalServerError, fmt.Sprintf("Failed to resolve placeholder container: %v", resolveErr))
		return
	}

	// rootfs-diff.tar only contains upperdir modifications, not the base image
	// layers (lowerdir), so restoring onto a different base image will fail.
	// Validation is skipped on request or when the checkpoint recorded no image.
	validateImage := !req.SkipImageValidation && checkpointMeta.Image != ""
	if validateImage {
		if !imagesCompatible(checkpointMeta.Image, containerInfo.Image) {
			writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
				Success:          false,
				Error:            fmt.Sprintf("Image mismatch: checkpoint was from '%s' but placeholder uses '%s'. The placeholder must use the same base image. Use skip_image_validation=true to override.", checkpointMeta.Image, containerInfo.Image),
				CheckpointImage:  checkpointMeta.Image,
				PlaceholderImage: containerInfo.Image,
			})
			return
		}
		log.Printf("Image validation passed: checkpoint=%s, placeholder=%s", checkpointMeta.Image, containerInfo.Image)
	}

	// The trigger file lives in the placeholder's own filesystem, reached
	// through <HostProc>/<pid>/root; its content is the checkpoint path.
	triggerPath := fmt.Sprintf("%s/%d/root/tmp/restore-trigger", s.config.HostProc, containerInfo.PID)
	if writeErr := os.WriteFile(triggerPath, []byte(checkpointPath), 0644); writeErr != nil {
		reject(http.StatusInternalServerError, fmt.Sprintf("Failed to write trigger file: %v", writeErr))
		return
	}

	log.Printf("Triggered restore for placeholder %s (PID %d) from checkpoint %s",
		req.PlaceholderContainerID, containerInfo.PID, req.CheckpointID)

	writeJSON(w, http.StatusOK, TriggerRestoreResponse{
		Success:          true,
		Message:          fmt.Sprintf("Restore triggered for checkpoint %s", req.CheckpointID),
		TriggerPath:      triggerPath,
		CheckpointImage:  checkpointMeta.Image,
		PlaceholderImage: containerInfo.Image,
	})
}
// handleListCheckpoints handles GET requests by returning a summary of every
// checkpoint under s.config.CheckpointDir whose metadata can be loaded.
//
// Responses:
//   - 405 (plain text) for any method other than GET
//   - 500 with a JSON {"error": ...} object when the checkpoint list cannot
//     be obtained
//   - 200 with a JSON ListCheckpointsResponse otherwise
//
// A checkpoint whose metadata fails to load is omitted from the result; the
// failure is logged so that a corrupt or half-written checkpoint does not
// disappear from the listing without a trace.
func (s *Server) handleListCheckpoints(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	checkpointIDs, err := common.ListCheckpoints(s.config.CheckpointDir)
	if err != nil {
		writeJSON(w, http.StatusInternalServerError, map[string]string{
			"error": err.Error(),
		})
		return
	}

	// Start from an empty, non-nil slice: encoding/json renders a nil slice as
	// null, which forces clients to special-case "no checkpoints". With a
	// non-nil slice the field is an empty array instead.
	// NOTE(review): if the response field is tagged omitempty the field is
	// dropped either way — confirm against ListCheckpointsResponse.
	checkpoints := make([]CheckpointInfo, 0, len(checkpointIDs))
	for _, id := range checkpointIDs {
		meta, err := common.GetCheckpointInfo(s.config.CheckpointDir, id)
		if err != nil {
			// Best-effort listing: skip the unreadable entry but record why.
			log.Printf("Skipping checkpoint %v: failed to load metadata: %v", id, err)
			continue
		}
		checkpoints = append(checkpoints, CheckpointInfo{
			ID:           meta.CheckpointID,
			CreatedAt:    meta.CreatedAt,
			SourceNode:   meta.SourceNode,
			ContainerID:  meta.ContainerID,
			PodName:      meta.PodName,
			PodNamespace: meta.PodNamespace,
			Image:        meta.Image,
		})
	}

	writeJSON(w, http.StatusOK, ListCheckpointsResponse{
		Checkpoints: checkpoints,
	})
}
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
log.Printf("Started %s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
log.Printf("Completed %s %s in %v", r.Method, r.URL.Path, time.Since(start))
})
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// imagesCompatible reports whether a placeholder container image can host a
// CRIU restore of a checkpoint taken from checkpointImage.
// The placeholder image must be based on the same image as the checkpoint.
//
// Two references are considered compatible when any of the following holds:
//   - they are identical strings (nginx:alpine vs nginx:alpine)
//   - they are equal once the docker.io/library/ or docker.io/ registry
//     prefix is removed (docker.io/library/nginx:alpine vs nginx:alpine)
//   - they are equal once a reference with no tag and no digest is given the
//     implicit "latest" tag (nginx vs nginx:latest)
//   - the placeholder follows the criu-placeholder-<image> naming convention,
//     where <image> is the checkpoint image with ":" and "/" replaced by "-",
//     optionally followed by a tag
//     (criu-placeholder-nginx-alpine[:tag] vs nginx:alpine)
//
// The comparison is purely textual; no registry lookup is performed.
func imagesCompatible(checkpointImage, placeholderImage string) bool {
	// Exact match.
	if checkpointImage == placeholderImage {
		return true
	}

	// stripRegistry removes the default Docker Hub prefixes. The longer prefix
	// must be tried first, otherwise "library/" would be left behind.
	stripRegistry := func(img string) string {
		img = strings.TrimPrefix(img, "docker.io/library/")
		return strings.TrimPrefix(img, "docker.io/")
	}

	// withDefaultTag appends ":latest" to a reference that names neither a tag
	// nor a digest. Only the last path segment is inspected so that a registry
	// port (localhost:5000/app) is not mistaken for a tag.
	withDefaultTag := func(img string) string {
		if img == "" || strings.Contains(img, "@") {
			return img
		}
		lastSegment := img[strings.LastIndex(img, "/")+1:]
		if strings.Contains(lastSegment, ":") {
			return img
		}
		return img + ":latest"
	}

	normalizedCheckpoint := stripRegistry(checkpointImage)
	normalizedPlaceholder := stripRegistry(placeholderImage)
	if normalizedCheckpoint == normalizedPlaceholder {
		return true
	}
	if withDefaultTag(normalizedCheckpoint) == withDefaultTag(normalizedPlaceholder) {
		return true
	}

	// Placeholder naming convention, e.g. criu-placeholder-nginx-alpine should
	// match nginx:alpine. The implicit tag is deliberately NOT applied here so
	// that placeholder names accepted before remain accepted.
	const placeholderPrefix = "criu-placeholder-"
	if !strings.HasPrefix(normalizedPlaceholder, placeholderPrefix) {
		return false
	}
	checkpointSanitized := strings.ReplaceAll(normalizedCheckpoint, ":", "-")
	checkpointSanitized = strings.ReplaceAll(checkpointSanitized, "/", "-")
	expectedPlaceholder := placeholderPrefix + checkpointSanitized

	return normalizedPlaceholder == expectedPlaceholder ||
		strings.HasPrefix(normalizedPlaceholder, expectedPlaceholder+":")
}
...@@ -5,8 +5,12 @@ package main ...@@ -5,8 +5,12 @@ package main
import ( import (
"context" "context"
"fmt"
"os" "os"
"os/exec"
"os/signal" "os/signal"
"path/filepath"
"strings"
"syscall" "syscall"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
...@@ -14,7 +18,50 @@ import ( ...@@ -14,7 +18,50 @@ import (
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/restore" "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/restore"
) )
// logGPUDiagnostics prints a GPU-visibility report to stdout, bracketed by
// banner lines that carry label. The report contains, in order:
//   - the output of `nvidia-smi -L`
//   - per-GPU memory usage from an nvidia-smi CSV query
//   - the paths matching /dev/nvidia*
//   - the NVIDIA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES environment variables
//   - the namespace links of PID 1 under /proc/1/ns
//
// Every step is best-effort: a failing command or unreadable link is reported
// inline and the remaining steps still run.
func logGPUDiagnostics(label string) {
	fmt.Printf("=== GPU DIAGNOSTICS [%s] ===\n", label)

	// nvidia-smi invocations, each with the format used on failure and on
	// success. Combined stdout+stderr is printed on success.
	probes := []struct {
		args    []string
		failFmt string
		okFmt   string
	}{
		{
			args:    []string{"-L"},
			failFmt: "nvidia-smi -L: error: %v\n",
			okFmt:   "nvidia-smi -L:\n%s",
		},
		{
			args:    []string{"--query-gpu=index,uuid,memory.used,memory.total,memory.free", "--format=csv,noheader"},
			failFmt: "nvidia-smi memory query: error: %v\n",
			okFmt:   "nvidia-smi memory:\n%s",
		},
	}
	for _, probe := range probes {
		output, runErr := exec.Command("nvidia-smi", probe.args...).CombinedOutput()
		if runErr != nil {
			fmt.Printf(probe.failFmt, runErr)
			continue
		}
		fmt.Printf(probe.okFmt, output)
	}

	// Device nodes. The glob error is ignored; the pattern is a constant.
	devices, _ := filepath.Glob("/dev/nvidia*")
	deviceList := strings.Join(devices, ", ")
	fmt.Printf("/dev/nvidia* devices: %s\n", deviceList)

	// Device-selection environment variables.
	fmt.Printf("NVIDIA_VISIBLE_DEVICES=%s\n", os.Getenv("NVIDIA_VISIBLE_DEVICES"))
	fmt.Printf("CUDA_VISIBLE_DEVICES=%s\n", os.Getenv("CUDA_VISIBLE_DEVICES"))

	// Namespace links of PID 1; a read failure is printed in place of the link.
	namespaceKinds := []string{"mnt", "pid", "ipc", "net", "uts", "cgroup"}
	for _, kind := range namespaceKinds {
		target, readErr := os.Readlink(fmt.Sprintf("/proc/1/ns/%s", kind))
		if readErr != nil {
			target = readErr.Error()
		}
		fmt.Printf("ns/%s: %s\n", kind, target)
	}

	fmt.Printf("=== END GPU DIAGNOSTICS [%s] ===\n", label)
}
func main() { func main() {
// Log GPU diagnostics BEFORE anything else (gated on DEBUG for production quietness)
if os.Getenv("DEBUG") == "1" {
logGPUDiagnostics("PRE-RESTORE")
}
// Set up logging // Set up logging
log := logrus.New() log := logrus.New()
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
...@@ -23,8 +70,12 @@ func main() { ...@@ -23,8 +70,12 @@ func main() {
TimestampFormat: "2006-01-02 15:04:05", TimestampFormat: "2006-01-02 15:04:05",
}) })
// Load configuration from environment // Load configuration from hardcoded defaults + operator-injected env vars.
cfg := restore.ConfigFromEnv() // os.Args[1:] are the cold start command args (passed by the operator via pod spec).
cfg, err := restore.NewRestoreRequest(os.Args[1:])
if err != nil {
log.WithError(err).Fatal("Failed to load restore configuration")
}
// Set log level based on DEBUG flag // Set log level based on DEBUG flag
if cfg.Debug { if cfg.Debug {
......
...@@ -6,19 +6,57 @@ import ( ...@@ -6,19 +6,57 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
criu "github.com/checkpoint-restore/go-criu/v7" criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common" "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
) )
// Options configures the checkpoint operation // ContainerInfoSnapshot holds runtime/container info needed for checkpointing.
type Options struct { type ContainerInfoSnapshot struct {
PID int
RootFS string
UpperDir string
OCISpec *specs.Spec
MountInfo []MountInfo
Namespaces map[NamespaceType]*NamespaceInfo
}
// CheckpointManifest is saved as manifest.yaml at checkpoint time and loaded at restore.
// Each field is serialized under the YAML key given in its struct tag.
type CheckpointManifest struct {
// CheckpointID identifies the checkpoint this manifest belongs to.
CheckpointID string `yaml:"checkpointId"`
// CreatedAt is the manifest creation time; NewCheckpointManifest sets it to the current UTC time.
CreatedAt time.Time `yaml:"createdAt"`
// CRIUDump is the CRIU dump section of the manifest.
CRIUDump CRIUDumpManifest `yaml:"criuDump"`
// K8s describes the source pod the checkpoint was taken from.
K8s SourcePodManifest `yaml:"k8s"`
// Filesystem is the filesystem section of the manifest.
Filesystem FilesystemManifest `yaml:"filesystem"`
// Namespaces lists one entry per recorded namespace.
Namespaces []NamespaceManifestEntry `yaml:"namespaces"`
}
// NewCheckpointManifest builds a CheckpointManifest from sections that the
// individual modules have already prepared. The sections are stored as given;
// the only value produced here is CreatedAt, which is stamped with the current
// time in UTC.
func NewCheckpointManifest(
	checkpointID string,
	criuDump CRIUDumpManifest,
	k8s SourcePodManifest,
	filesystem FilesystemManifest,
	namespaces []NamespaceManifestEntry,
) *CheckpointManifest {
	manifest := new(CheckpointManifest)
	manifest.CheckpointID = checkpointID
	manifest.CreatedAt = time.Now().UTC()
	manifest.CRIUDump = criuDump
	manifest.K8s = k8s
	manifest.Filesystem = filesystem
	manifest.Namespaces = namespaces
	return manifest
}
// CheckpointRequest holds per-checkpoint identifiers for a checkpoint operation.
type CheckpointRequest struct {
ContainerID string ContainerID string
ContainerName string // K8s container name (for K8s API volume type lookup) ContainerName string // K8s container name (for K8s API volume type lookup)
CheckpointID string CheckpointID string
...@@ -26,237 +64,180 @@ type Options struct { ...@@ -26,237 +64,180 @@ type Options struct {
NodeName string NodeName string
PodName string PodName string
PodNamespace string PodNamespace string
// CRIU options (from environment variables)
GhostLimit uint32 // From CRIU_GHOST_LIMIT: ghost file size limit in bytes (0 = CRIU default)
Timeout uint32 // From CRIU_TIMEOUT: timeout in seconds (0 = no timeout)
// GPU/CUDA checkpoint options
CUDAPluginDir string // Path to CRIU CUDA plugin (e.g., /home/mmshin/work/criu/plugins/cuda)
ExternalMounts []string // Additional external mount mappings (e.g., "mnt[path]:path")
} }
// Result contains the result of a checkpoint operation // CheckpointOutcome contains the result of a checkpoint operation.
type Result struct { type CheckpointOutcome struct {
CheckpointID string CheckpointID string
CheckpointDir string CheckpointDir string
Metadata *common.CheckpointMetadata Data *CheckpointManifest
} }
// Checkpointer performs CRIU checkpoint operations // Checkpointer performs CRIU checkpoint operations
type Checkpointer struct { type Checkpointer struct {
discoveryClient *checkpointk8s.DiscoveryClient discoveryClient *DiscoveryClient
k8sClient *checkpointk8s.K8sClient // Optional: for accurate volume type discovery from K8s API
hostProc string
log *logrus.Entry log *logrus.Entry
} }
// NewCheckpointer creates a new checkpointer // NewCheckpointer creates a new checkpointer
func NewCheckpointer(discoveryClient *checkpointk8s.DiscoveryClient, hostProc string) *Checkpointer { func NewCheckpointer(discoveryClient *DiscoveryClient) *Checkpointer {
if hostProc == "" {
hostProc = os.Getenv("HOST_PROC")
if hostProc == "" {
hostProc = "/proc"
}
}
return &Checkpointer{ return &Checkpointer{
discoveryClient: discoveryClient, discoveryClient: discoveryClient,
hostProc: hostProc,
log: logrus.WithField("component", "checkpointer"), log: logrus.WithField("component", "checkpointer"),
} }
} }
// WithK8sClient sets an optional Kubernetes client for accurate volume type discovery. // Checkpoint performs a CRIU dump of a container.
// When set, volume types are fetched from the K8s API instead of being inferred from paths. // The operation has three phases: introspect, configure, capture.
func (c *Checkpointer) WithK8sClient(client *checkpointk8s.K8sClient) *Checkpointer { func (c *Checkpointer) Checkpoint(ctx context.Context, req CheckpointRequest, spec *CheckpointSpec) (*CheckpointOutcome, error) {
c.k8sClient = client if spec == nil {
return c return nil, fmt.Errorf("checkpoint spec is required")
} }
// Checkpoint performs a CRIU dump of a container
func (c *Checkpointer) Checkpoint(ctx context.Context, opts Options) (*Result, error) {
checkpointStart := time.Now() checkpointStart := time.Now()
c.log.Info("=== Starting checkpoint operation ===") c.log.Info("=== Starting checkpoint operation ===")
// 1. Resolve container to get PID checkpointDir := filepath.Join(req.CheckpointDir, req.CheckpointID)
resolveStart := time.Now()
containerInfo, err := c.discoveryClient.ResolveContainer(ctx, opts.ContainerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
}
pid := int(containerInfo.PID)
c.log.WithField("duration", time.Since(resolveStart)).Info("Container resolution completed")
// 2. Create checkpoint directory
checkpointDir := common.GetCheckpointDir(opts.CheckpointDir, opts.CheckpointID)
if err := os.MkdirAll(checkpointDir, 0700); err != nil { if err := os.MkdirAll(checkpointDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create checkpoint directory: %w", err) return nil, fmt.Errorf("failed to create checkpoint directory: %w", err)
} }
// 3. Introspect container state // Open image directory FD for CRIU — must stay open through both configure and capture
introspectStart := time.Now() // phases since CRIU's swrk child process inherits this FD.
rootFS, err := GetRootFS(pid, c.hostProc) imageDir, imageDirFD, err := common.OpenPathForCRIU(checkpointDir)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get rootfs: %w", err) return nil, fmt.Errorf("failed to open image directory: %w", err)
} }
mounts, err := GetKubernetesVolumeMounts(pid, c.hostProc) defer imageDir.Close()
// Phase 1: Introspect container state
state, err := c.introspect(ctx, req.ContainerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get mounts: %w", err) return nil, err
} }
namespaces, err := GetAllNamespaces(pid, c.hostProc)
// Phase 2: Configure CRIU options and build checkpoint manifest.
criuOpts, data, err := c.configure(state, req, spec, checkpointDir, imageDirFD)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get namespaces: %w", err) return nil, err
} }
c.log.WithField("duration", time.Since(introspectStart)).Info("Container introspection completed")
// 4. Open image directory FD // Phase 3: Capture — CRIU dump, /dev/shm, rootfs diff
imageDir, imageDirFD, err := OpenImageDir(checkpointDir) criuDumpDuration, err := c.capture(criuOpts, data, state, checkpointDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer imageDir.Close()
// 5. Build CRIU options totalDuration := time.Since(checkpointStart)
criuOpts := BuildCRIUOptsFromCheckpointOpts(opts, pid, imageDirFD, rootFS)
// 6. Create CRIU config file for CUDA plugin (libdir is not available via RPC)
if opts.CUDAPluginDir != "" {
if opts.Timeout == 0 {
return nil, fmt.Errorf("CRIU_TIMEOUT environment variable must be set for CUDA checkpoints")
}
configPath := filepath.Join(checkpointDir, "criu.conf")
configContent := fmt.Sprintf(`enable-external-masters
libdir %s
tcp-close
link-remap
timeout %d
allow-uprobes
skip-in-flight
`, opts.CUDAPluginDir, opts.Timeout)
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
return nil, fmt.Errorf("failed to write CRIU config file: %w", err)
}
criuOpts.ConfigFile = proto.String(configPath)
c.log.WithFields(logrus.Fields{ c.log.WithFields(logrus.Fields{
"config_path": configPath, "total_duration": totalDuration,
"plugin_dir": opts.CUDAPluginDir, "criu_dump_duration": criuDumpDuration,
}).Info("Created CRIU config file for CUDA plugin") }).Info("=== Checkpoint operation completed ===")
return &CheckpointOutcome{
CheckpointID: req.CheckpointID,
CheckpointDir: checkpointDir,
Data: data,
}, nil
}
// introspect resolves the container and gathers all runtime state from containerd and /proc.
func (c *Checkpointer) introspect(ctx context.Context, containerID string) (*ContainerInfoSnapshot, error) {
pid, ociSpec, err := c.discoveryClient.ResolveContainer(ctx, containerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
} }
// 7. Configure external mounts and namespaces rootFS, err := GetRootFS(pid)
if err := ConfigureExternalMounts(criuOpts, pid, c.hostProc, containerInfo); err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("failed to get rootfs: %w", err)
} }
netNsInode := ConfigureExternalNamespaces(criuOpts, namespaces, opts.ExternalMounts) upperDir, err := GetOverlayUpperDir(pid)
if netNsInode > 0 { if err != nil {
c.log.WithField("inode", netNsInode).Debug("Marked network namespace as external") return nil, fmt.Errorf("failed to get overlay upperdir: %w", err)
} }
for _, extMount := range opts.ExternalMounts { mountInfo, err := ReadMountInfoFromHostProcPath(pid)
c.log.WithField("external", extMount).Debug("Added external mount mapping") if err != nil {
return nil, fmt.Errorf("failed to parse mountinfo: %w", err)
} }
namespaces, err := GetAllNamespaces(pid)
// 8. Get overlay upperdir for rootfs diff capture if err != nil {
upperDir, upperDirErr := GetOverlayUpperDir(pid, c.hostProc) return nil, fmt.Errorf("failed to get namespaces: %w", err)
if upperDirErr != nil {
c.log.WithError(upperDirErr).Warn("Could not get overlay upperdir - rootfs diff will not be captured")
} else {
c.log.WithField("upperdir", upperDir).Debug("Found overlay upperdir")
} }
// 9. Build and save initial metadata before dump return &ContainerInfoSnapshot{
metaCfg := MetadataBuilderConfig{
CheckpointID: opts.CheckpointID,
NodeName: opts.NodeName,
ContainerID: opts.ContainerID,
ContainerName: opts.ContainerName,
PodName: opts.PodName,
PodNamespace: opts.PodNamespace,
PID: pid, PID: pid,
CUDAPluginDir: opts.CUDAPluginDir, RootFS: rootFS,
} UpperDir: upperDir,
meta := BuildCheckpointMetadata(ctx, metaCfg, containerInfo, mounts, namespaces, c.k8sClient, c.log) OCISpec: ociSpec,
if upperDir != "" { MountInfo: mountInfo,
meta.UpperDir = upperDir Namespaces: namespaces,
} }, nil
if err := common.SaveMetadata(checkpointDir, meta); err != nil { }
return nil, fmt.Errorf("failed to save metadata: %w", err)
}
// 10. Remove semaphores from /dev/shm before checkpoint // configure builds CRIU options and checkpoint manifest from runtime snapshot and spec.
// Semaphores cause CRIU restore to fail with "Can't link dev/shm/link_remap.X -> dev/shm/sem.Y" func (c *Checkpointer) configure(
if err := c.removeSemaphores(pid); err != nil { state *ContainerInfoSnapshot,
return nil, fmt.Errorf("failed to remove semaphores: %w", err) req CheckpointRequest,
spec *CheckpointSpec,
checkpointDir string,
imageDirFD int32,
) (*criurpc.CriuOpts, *CheckpointManifest, error) {
criuOpts, err := BuildCRIUDumpOptions(
&spec.CRIU,
state.PID,
imageDirFD,
state.RootFS,
state.MountInfo,
state.OCISpec,
state.Namespaces,
)
if err != nil {
return nil, nil, err
} }
// 11. Execute CRIU dump via go-criu // Write CRIU config file (for options unavailable via RPC)
criuDumpStart := time.Now() configPath := filepath.Join(checkpointDir, CheckpointCRIUConfFilename)
criuClient := criu.MakeCriu() if err := os.WriteFile(configPath, []byte(spec.CRIU.GenerateCRIUConfContent()), 0644); err != nil {
if err := criuClient.Dump(criuOpts, nil); err != nil { return nil, nil, fmt.Errorf("failed to write CRIU config file: %w", err)
c.log.WithField("duration", time.Since(criuDumpStart)).Error("CRIU dump failed")
return nil, fmt.Errorf("CRIU dump failed: %w", err)
} }
criuDumpDuration := time.Since(criuDumpStart) criuOpts.ConfigFile = proto.String(configPath)
c.log.WithField("duration", criuDumpDuration).Info("CRIU dump completed successfully")
// 12. Capture rootfs diff and deleted files // Build and save the checkpoint manifest.
rootfsCaptureStart := time.Now() manifest := NewCheckpointManifest(
CaptureRootfsState(upperDir, checkpointDir, meta, c.log) req.CheckpointID,
c.log.WithField("duration", time.Since(rootfsCaptureStart)).Info("Rootfs capture completed") NewCRIUDumpManifest(criuOpts, spec.CRIU),
NewSourcePodManifest(req, state.PID),
NewFilesystemManifest(spec.RootfsExclusions, state.UpperDir, state.OCISpec),
NewNamespaceManifestEntries(state.Namespaces),
)
totalDuration := time.Since(checkpointStart) if err := WriteCheckpointManifest(checkpointDir, manifest); err != nil {
c.log.WithFields(logrus.Fields{ return nil, nil, fmt.Errorf("failed to write checkpoint manifest: %w", err)
"total_duration": totalDuration, }
"criu_dump_duration": criuDumpDuration,
}).Info("=== Checkpoint operation completed ===")
return &Result{ return criuOpts, manifest, nil
CheckpointID: opts.CheckpointID,
CheckpointDir: checkpointDir,
Metadata: meta,
}, nil
} }
// removeSemaphores removes POSIX semaphores from the container's /dev/shm. // capture executes the CRIU dump and post-dump captures (/dev/shm, rootfs diff).
// Semaphores can cause issues during CRIU checkpoint/restore because they // Returns the CRIU dump duration for timing reporting.
// maintain kernel state that may not transfer correctly between processes. func (c *Checkpointer) capture(
// This accesses the container's filesystem via /proc/<pid>/root/dev/shm/. criuOpts *criurpc.CriuOpts,
func (c *Checkpointer) removeSemaphores(pid int) error { data *CheckpointManifest,
shmPath := filepath.Join(c.hostProc, fmt.Sprintf("%d/root/dev/shm", pid)) state *ContainerInfoSnapshot,
checkpointDir string,
entries, err := os.ReadDir(shmPath) ) (time.Duration, error) {
criuDumpDuration, err := ExecuteCRIUDump(criuOpts, checkpointDir, c.log)
if err != nil { if err != nil {
// It's okay if /dev/shm doesn't exist (container may not have it) return 0, err
c.log.WithError(err).Debug("Could not read container /dev/shm (may not exist)")
return nil
} }
var removed []string // Capture /dev/shm contents (must happen after dump for final process state)
var errors []error if err := CaptureDevShm(state.PID, checkpointDir, c.log); err != nil {
for _, entry := range entries { c.log.WithError(err).Warn("Failed to capture /dev/shm contents")
name := entry.Name()
if strings.HasPrefix(name, "sem.") {
semPath := filepath.Join(shmPath, name)
if err := os.Remove(semPath); err != nil {
c.log.WithError(err).WithField("semaphore", name).Error("Failed to remove semaphore")
errors = append(errors, fmt.Errorf("failed to remove semaphore %s: %w", name, err))
} else {
removed = append(removed, name)
}
}
} }
if len(errors) > 0 { // Capture rootfs diff and deleted files
return fmt.Errorf("failed to remove %d semaphore(s): %v", len(errors), errors) CaptureRootfsState(state.UpperDir, checkpointDir, data, c.log)
}
if len(removed) > 0 {
c.log.WithFields(logrus.Fields{
"count": len(removed),
"semaphores": removed,
}).Info("Removed semaphores from container /dev/shm before checkpoint")
} else {
c.log.Debug("No semaphores found in container /dev/shm")
}
return nil return criuDumpDuration, nil
} }
// config.go defines the static checkpoint spec loaded from ConfigMap YAML.
package checkpoint
import "fmt"
// CheckpointSpec is the static checkpoint spec loaded from ConfigMap YAML.
// It holds settings shared by all checkpoint operations, as opposed to
// per-checkpoint identifiers. Each field maps to the YAML key in its struct tag.
type CheckpointSpec struct {
// BasePath is the base directory for checkpoint storage (PVC mount point).
BasePath string `yaml:"basePath"`
// CRIU options for dump operations
CRIU CRIUSettings `yaml:"criu"`
// RootfsExclusions defines paths to exclude from rootfs diff capture.
// This is the only field checked by Validate.
RootfsExclusions FilesystemConfig `yaml:"rootfsExclusions"`
}
// Validate checks the CheckpointSpec and returns the first problem found.
// Validation is currently delegated entirely to RootfsExclusions; BasePath
// and CRIU are not inspected here.
func (c *CheckpointSpec) Validate() error {
	exclusionsResult := c.RootfsExclusions.Validate()
	return exclusionsResult
}
// ConfigError describes one invalid configuration value.
type ConfigError struct {
	// Field names the configuration field that failed validation.
	Field string
	// Message explains what is wrong with the field.
	Message string
}

// Error implements the error interface, rendering the failure as
// "config error: <Field>: <Message>".
func (e *ConfigError) Error() string {
	const layout = "config error: %s: %s"
	return fmt.Sprintf(layout, e.Field, e.Message)
}
// constants.go defines shared constants used across checkpoint and restore packages.
package checkpoint
// Shared names: the host /proc mount point, the pod labels that drive
// checkpointing, and the file and directory names used inside a checkpoint
// directory.
const (
// HostProcPath is the mount point for the host's /proc in DaemonSet pods.
HostProcPath = "/host/proc"
// DevShmDirName is the directory name for captured /dev/shm contents.
DevShmDirName = "dev-shm"
// KubeLabelCheckpointSource is the pod label that triggers automatic checkpointing.
// Set by the operator on checkpoint-eligible pods.
KubeLabelCheckpointSource = "nvidia.com/checkpoint-source"
// KubeLabelCheckpointHash is the pod label specifying the checkpoint identity hash.
// Set by the operator on checkpoint-eligible pods.
KubeLabelCheckpointHash = "nvidia.com/checkpoint-hash"
// DumpLogFilename is the CRIU dump (checkpoint) log filename.
DumpLogFilename = "dump.log"
// CheckpointCRIUConfFilename is the CRIU config file written at checkpoint time.
CheckpointCRIUConfFilename = "criu.conf"
// CheckpointDoneFilename is the marker file written to the checkpoint directory
// after all checkpoint artifacts are complete. Used to detect checkpoint readiness.
// Also hard-coded in vLLM for early-exit when checkpoint already exists, so the
// two values must be changed together.
CheckpointDoneFilename = "checkpoint.done"
// CheckpointManifestFilename is the name of the manifest file in checkpoint directories.
CheckpointManifestFilename = "manifest.yaml"
// DescriptorsFilename is the name of the file descriptors file.
DescriptorsFilename = "descriptors.yaml"
// RootfsDiffFilename is the name of the rootfs diff tar in checkpoint directories.
RootfsDiffFilename = "rootfs-diff.tar"
// DeletedFilesFilename is the name of the deleted files JSON in checkpoint directories.
DeletedFilesFilename = "deleted-files.json"
)
...@@ -3,176 +3,256 @@ package checkpoint ...@@ -3,176 +3,256 @@ package checkpoint
import ( import (
"fmt" "fmt"
"os" "time"
criu "github.com/checkpoint-restore/go-criu/v7"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc" criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
) )
// CRIUConfig holds configuration for CRIU dump operations. // CRIUSettings holds CRIU-specific configuration options.
// Most options are always-on with safe defaults for K8s environments. // Options are categorized by how they are passed to CRIU:
type CRIUConfig struct { // - RPC options: Passed via go-criu CriuOpts protobuf
PID int // - CRIU conf file options: Written to criu.conf (NOT available via RPC)
ImageDirFD int32 type CRIUSettings struct {
RootFS string // === RPC Options (passed via go-criu CriuOpts) ===
GhostLimit uint32 // From env CRIU_GHOST_LIMIT: max ghost file size (0 = CRIU default)
Timeout uint32 // From env CRIU_TIMEOUT: checkpoint timeout in seconds (0 = no timeout)
}
// OpenImageDir opens a checkpoint directory and prepares it for CRIU. // GhostLimit is the maximum ghost file size in bytes.
// Returns the opened file and its FD. The caller must close the file when done. // Ghost files are deleted-but-open files that CRIU needs to checkpoint.
// The file descriptor has CLOEXEC cleared so it can be inherited by CRIU. // 512MB is recommended for GPU workloads with large memory allocations.
func OpenImageDir(checkpointDir string) (*os.File, int32, error) { GhostLimit uint32 `yaml:"ghostLimit"`
return common.OpenDirForCRIU(checkpointDir)
}
// BuildCRIUOpts creates CRIU options from a config struct. // Timeout is the CRIU operation timeout in seconds.
// This sets up the base options; external mounts and namespaces are added separately. // 6 hours (21600s) is recommended for large GPU model checkpoints.
// Timeout uint32 `yaml:"timeout"`
// Always-on options for K8s:
// - LeaveRunning: always keep process running after checkpoint // LogLevel is the CRIU logging verbosity (0-4).
// - ShellJob: containers are often session leaders LogLevel int32 `yaml:"logLevel"`
// - TcpClose: pod IPs change on restore/migration
// - FileLocks: applications use file locks // WorkDir is the CRIU work directory for temporary files.
// - OrphanPtsMaster: containers with TTYs WorkDir string `yaml:"workDir"`
// - ExtUnixSk: containers have external Unix sockets
// - ManageCgroups (IGNORE): let K8s manage cgroups // AutoDedup enables auto-deduplication of memory pages.
// - LinkRemap: handle deleted-but-open files (safe for all workloads) AutoDedup bool `yaml:"autoDedup"`
// - ExtMasters: external bind mount masters (safe for all workloads)
func BuildCRIUOpts(cfg CRIUConfig) *criurpc.CriuOpts { // LazyPages enables lazy page migration (experimental).
cgMode := criurpc.CriuCgMode_IGNORE LazyPages bool `yaml:"lazyPages"`
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(cfg.PID)), // LeaveRunning keeps the process running after checkpoint (dump only).
ImagesDirFd: proto.Int32(cfg.ImageDirFD), LeaveRunning bool `yaml:"leaveRunning"`
LogLevel: proto.Int32(4),
LogFile: proto.String("dump.log"), // ShellJob allows checkpointing session leaders (containers are often session leaders).
Root: proto.String(cfg.RootFS), ShellJob bool `yaml:"shellJob"`
ManageCgroups: proto.Bool(true),
ManageCgroupsMode: &cgMode, // TcpClose closes TCP connections instead of preserving them (pod IPs change on restore).
// Always-on for K8s environments TcpClose bool `yaml:"tcpClose"`
LeaveRunning: proto.Bool(true),
ShellJob: proto.Bool(true), // FileLocks allows checkpointing processes with file locks.
TcpClose: proto.Bool(true), FileLocks bool `yaml:"fileLocks"`
FileLocks: proto.Bool(true),
OrphanPtsMaster: proto.Bool(true), // OrphanPtsMaster allows checkpointing containers with TTYs.
ExtUnixSk: proto.Bool(true), OrphanPtsMaster bool `yaml:"orphanPtsMaster"`
LinkRemap: proto.Bool(true),
ExtMasters: proto.Bool(true), // ExtUnixSk allows external Unix sockets.
} ExtUnixSk bool `yaml:"extUnixSk"`
// Optional: ghost limit from env (0 = use CRIU default) // LinkRemap handles deleted-but-open files.
if cfg.GhostLimit > 0 { LinkRemap bool `yaml:"linkRemap"`
criuOpts.GhostLimit = proto.Uint32(cfg.GhostLimit)
} // ExtMasters allows external bind mount masters.
ExtMasters bool `yaml:"extMasters"`
// Optional: timeout from env (0 = no timeout)
if cfg.Timeout > 0 { // ManageCgroupsMode controls cgroup handling: "ignore" lets K8s manage cgroups.
criuOpts.Timeout = proto.Uint32(cfg.Timeout) ManageCgroupsMode string `yaml:"manageCgroupsMode"`
}
// === CRIU Conf File Options (NOT available via RPC - written to criu.conf) ===
return criuOpts
// LibDir is the path to CRIU plugin directory (e.g., /usr/local/lib/criu).
// Required for CUDA checkpoint/restore.
LibDir string `yaml:"libDir"`
// AllowUprobes allows user-space probes (required for CUDA checkpoints).
AllowUprobes bool `yaml:"allowUprobes"`
// SkipInFlight skips in-flight TCP connections during checkpoint/restore.
SkipInFlight bool `yaml:"skipInFlight"`
} }
// AddExternalMounts adds mount points as external mounts to CRIU options. // GenerateCRIUConfContent generates the criu.conf file content for options
// CRIU requires all mounts to be marked as external for successful restore. // that cannot be passed via RPC.
func AddExternalMounts(criuOpts *criurpc.CriuOpts, mounts []AllMountInfo) { func (c *CRIUSettings) GenerateCRIUConfContent() string {
addedMounts := make(map[string]bool) var content string
for _, m := range mounts { if c.LibDir != "" {
if addedMounts[m.MountPoint] { content += "libdir " + c.LibDir + "\n"
continue
} }
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{ if c.AllowUprobes {
Key: proto.String(m.MountPoint), content += "allow-uprobes\n"
Val: proto.String(m.MountPoint), }
}) if c.SkipInFlight {
addedMounts[m.MountPoint] = true content += "skip-in-flight\n"
} }
return content
}
// ExternalMountManifestEntry is a serializable CRIU ext-mount entry in checkpoint manifests.
type ExternalMountManifestEntry struct {
Key string `yaml:"key"`
Val string `yaml:"val"`
} }
// AddExternalPaths adds additional paths (masked/readonly) as external mounts. // CRIUDumpManifest stores the resolved dump-time CRIU mount plan used for restore.
// These may not appear in mountinfo but CRIU still needs them marked as external. type CRIUDumpManifest struct {
func AddExternalPaths(criuOpts *criurpc.CriuOpts, paths []string) { CRIU CRIUSettings `yaml:"criu"`
// Build set of existing mount points ExtMnt []ExternalMountManifestEntry `yaml:"extMnt,omitempty"`
existing := make(map[string]bool) External []string `yaml:"external,omitempty"`
for _, m := range criuOpts.ExtMnt { SkipMnt []string `yaml:"skipMnt,omitempty"`
existing[m.GetKey()] = true }
// NewCRIUDumpManifest serializes resolved dump options for restore.
func NewCRIUDumpManifest(criuOpts *criurpc.CriuOpts, settings CRIUSettings) CRIUDumpManifest {
manifest := CRIUDumpManifest{CRIU: settings}
if criuOpts == nil {
return manifest
} }
for _, path := range paths { for _, mount := range criuOpts.ExtMnt {
if existing[path] { if mount == nil || mount.GetKey() == "" {
continue continue
} }
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{ manifest.ExtMnt = append(manifest.ExtMnt, ExternalMountManifestEntry{
Key: proto.String(path), Key: mount.GetKey(),
Val: proto.String(path), Val: mount.GetVal(),
}) })
existing[path] = true
} }
manifest.External = append([]string(nil), criuOpts.External...)
manifest.SkipMnt = append([]string(nil), criuOpts.SkipMnt...)
return manifest
} }
// AddExternalNamespace adds a namespace as external to CRIU options. // BuildCRIUDumpOptions creates CRIU options directly from spec settings and runtime state.
// Format: "<type>[<inode>]:<key>" func BuildCRIUDumpOptions(
func AddExternalNamespace(criuOpts *criurpc.CriuOpts, nsType NamespaceType, inode uint64, key string) { settings *CRIUSettings,
extNs := fmt.Sprintf("%s[%d]:%s", nsType, inode, key) pid int,
criuOpts.External = append(criuOpts.External, extNs) imageDirFD int32,
} rootFS string,
mountInfo []MountInfo,
ociSpec *specs.Spec,
namespaces map[NamespaceType]*NamespaceInfo,
) (*criurpc.CriuOpts, error) {
mountPolicy := BuildMountPolicy(mountInfo, ociSpec, rootFS)
// AddExternalStrings adds raw external strings to CRIU options. extMnt := buildExternalMountMaps(mountPolicy.Externalized)
// Used for additional external mount mappings (e.g., NVIDIA firmware files). skipMnt := mountPolicy.Skipped
func AddExternalStrings(criuOpts *criurpc.CriuOpts, externals []string) { external := buildExternalNamespaces(namespaces)
criuOpts.External = append(criuOpts.External, externals...) logrus.WithFields(logrus.Fields{
} "externalized_count": len(mountPolicy.Externalized),
"skipped_count": len(mountPolicy.Skipped),
}).Debug("Resolved mount policy for CRIU dump")
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(pid)),
ImagesDirFd: proto.Int32(imageDirFD),
Root: proto.String(rootFS),
LogFile: proto.String(DumpLogFilename),
}
criuOpts.ExtMnt = extMnt
criuOpts.External = external
criuOpts.SkipMnt = skipMnt
if settings == nil {
return criuOpts, nil
}
// RPC options from spec.
criuOpts.LogLevel = proto.Int32(settings.LogLevel)
criuOpts.LeaveRunning = proto.Bool(settings.LeaveRunning)
criuOpts.ShellJob = proto.Bool(settings.ShellJob)
criuOpts.TcpClose = proto.Bool(settings.TcpClose)
criuOpts.FileLocks = proto.Bool(settings.FileLocks)
criuOpts.OrphanPtsMaster = proto.Bool(settings.OrphanPtsMaster)
criuOpts.ExtUnixSk = proto.Bool(settings.ExtUnixSk)
criuOpts.LinkRemap = proto.Bool(settings.LinkRemap)
criuOpts.ExtMasters = proto.Bool(settings.ExtMasters)
criuOpts.AutoDedup = proto.Bool(settings.AutoDedup)
criuOpts.LazyPages = proto.Bool(settings.LazyPages)
// Cgroup management mode
criuOpts.ManageCgroups = proto.Bool(true)
cgMode := criurpc.CriuCgMode_IGNORE
switch settings.ManageCgroupsMode {
case "soft":
cgMode = criurpc.CriuCgMode_SOFT
case "full":
cgMode = criurpc.CriuCgMode_FULL
case "strict":
cgMode = criurpc.CriuCgMode_STRICT
}
criuOpts.ManageCgroupsMode = &cgMode
// ConfigureExternalMounts adds all required external mounts to CRIU options. // Optional numeric options
// This includes mounts from /proc/pid/mountinfo plus masked/readonly paths from OCI spec. if settings.GhostLimit > 0 {
func ConfigureExternalMounts(criuOpts *criurpc.CriuOpts, pid int, hostProc string, containerInfo *checkpointk8s.ContainerInfo) error { criuOpts.GhostLimit = proto.Uint32(settings.GhostLimit)
// Get all mounts from mountinfo - CRIU needs every mount marked as external }
allMounts, err := GetAllMountsFromMountinfo(pid, hostProc) if settings.Timeout > 0 {
if err != nil { criuOpts.Timeout = proto.Uint32(settings.Timeout)
return fmt.Errorf("failed to get all mounts from mountinfo: %w", err)
} }
// Add mounts from mountinfo return criuOpts, nil
AddExternalMounts(criuOpts, allMounts) }
// Add masked and readonly paths from OCI spec // buildExternalMountMaps serializes externalized mount paths into CRIU map entries.
AddExternalPaths(criuOpts, containerInfo.GetMaskedPaths()) func buildExternalMountMaps(paths []string) []*criurpc.ExtMountMap {
AddExternalPaths(criuOpts, containerInfo.GetReadonlyPaths()) extMnt := make([]*criurpc.ExtMountMap, 0, len(paths))
existing := make(map[string]struct{}, len(paths))
for _, path := range paths {
if path == "" {
continue
}
if _, ok := existing[path]; ok {
continue
}
extMnt = append(extMnt, &criurpc.ExtMountMap{
Key: proto.String(path),
Val: proto.String(path),
})
existing[path] = struct{}{}
}
return nil return extMnt
} }
// ConfigureExternalNamespaces adds external namespaces to CRIU options. // buildExternalNamespaces builds external namespace/mount references.
// Returns the network namespace inode if found, for logging purposes. func buildExternalNamespaces(namespaces map[NamespaceType]*NamespaceInfo) []string {
func ConfigureExternalNamespaces(criuOpts *criurpc.CriuOpts, namespaces map[NamespaceType]*NamespaceInfo, externalMounts []string) uint64 { external := make([]string, 0, 1)
var netNsInode uint64
// Mark network namespace as external for socket binding preservation // Mark network namespace as external for socket binding preservation
if netNs, ok := namespaces[NamespaceNet]; ok { if netNs, ok := namespaces[NamespaceNet]; ok {
AddExternalNamespace(criuOpts, NamespaceNet, netNs.Inode, "extNetNs") external = append(external, fmt.Sprintf("%s[%d]:%s", NamespaceNet, netNs.Inode, "extNetNs"))
netNsInode = netNs.Inode logrus.WithField("inode", netNs.Inode).Debug("Marked network namespace as external")
} }
// Add additional external mounts (e.g., for NVIDIA firmware files) return external
AddExternalStrings(criuOpts, externalMounts)
return netNsInode
} }
// BuildCRIUOptsFromCheckpointOpts constructs CRIU options from checkpoint Options. // ExecuteCRIUDump runs the CRIU dump and logs timing plus dump-log location on failure.
// Returns the configured CriuOpts ready for external mount/namespace configuration. func ExecuteCRIUDump(criuOpts *criurpc.CriuOpts, checkpointDir string, log *logrus.Entry) (time.Duration, error) {
func BuildCRIUOptsFromCheckpointOpts(opts Options, pid int, imageDirFD int32, rootFS string) *criurpc.CriuOpts { criuDumpStart := time.Now()
cfg := CRIUConfig{ criuClient := criu.MakeCriu()
PID: pid, if err := criuClient.Dump(criuOpts, nil); err != nil {
ImageDirFD: imageDirFD, dumpDuration := time.Since(criuDumpStart)
RootFS: rootFS, log.WithFields(logrus.Fields{
GhostLimit: opts.GhostLimit, "duration": dumpDuration,
Timeout: opts.Timeout, "checkpoint_dir": checkpointDir,
"dump_log_path": fmt.Sprintf("%s/%s", checkpointDir, DumpLogFilename),
}).Error("CRIU dump failed")
return 0, fmt.Errorf("CRIU dump failed: %w", err)
} }
return BuildCRIUOpts(cfg) criuDumpDuration := time.Since(criuDumpStart)
log.WithField("duration", criuDumpDuration).Info("CRIU dump completed")
return criuDumpDuration, nil
} }
// rootfs provides container rootfs introspection via /proc for CRIU checkpoint. // filesystem.go provides container rootfs introspection, filesystem config/metadata types,
// and rootfs diff capture for CRIU checkpoint.
package checkpoint package checkpoint
import ( import (
"bufio"
"encoding/json" "encoding/json"
"fmt" "fmt"
"os" "os"
...@@ -10,131 +10,152 @@ import ( ...@@ -10,131 +10,152 @@ import (
"path/filepath" "path/filepath"
"strings" "strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
) )
// GetRootFS returns the container's root filesystem path // FilesystemConfig is the static config for rootfs exclusions (from values.yaml).
// For containers using overlayfs, this extracts the upperdir type FilesystemConfig struct {
func GetRootFS(pid int, hostProc string) (string, error) { // SystemDirs are system directories that should be excluded from rootfs diff.
if hostProc == "" { // These directories are typically injected/bind-mounted by NVIDIA GPU Operator
hostProc = "/proc" // at container start time, so they already exist in the restore target.
} // Excluding them prevents conflicts (especially socket files which cannot be overwritten).
// Default: ["./usr", "./etc", "./opt", "./var", "./run"]
// The rootfs is accessible via /proc/<pid>/root SystemDirs []string `yaml:"systemDirs"`
// But for CRIU, we need the actual filesystem path
rootPath := fmt.Sprintf("%s/%d/root", hostProc, pid) // CacheDirs are cache directories that can safely be excluded to reduce checkpoint size.
// Model weights and other cached data are typically re-downloaded if needed.
// Default: ["./.cache/huggingface", "./.cache/torch"]
CacheDirs []string `yaml:"cacheDirs"`
// AdditionalExclusions are custom paths to exclude from the rootfs diff.
// Use this for application-specific exclusions.
// Paths should be relative with "./" prefix (e.g., "./data/temp").
AdditionalExclusions []string `yaml:"additionalExclusions"`
}
// Verify it exists // GetAllExclusions returns all exclusion paths combined.
if _, err := os.Stat(rootPath); err != nil { // This is used when building tar arguments for rootfs diff capture.
return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err) func (c *FilesystemConfig) GetAllExclusions() []string {
if c == nil {
return nil
} }
total := len(c.SystemDirs) + len(c.CacheDirs) + len(c.AdditionalExclusions)
return rootPath, nil exclusions := make([]string, 0, total)
exclusions = append(exclusions, c.SystemDirs...)
exclusions = append(exclusions, c.CacheDirs...)
exclusions = append(exclusions, c.AdditionalExclusions...)
return exclusions
} }
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo // Validate checks that the FilesystemConfig has valid values.
// This is the writable layer of the container's filesystem func (c *FilesystemConfig) Validate() error {
func GetOverlayUpperDir(pid int, hostProc string) (string, error) { if c == nil {
if hostProc == "" { return nil
hostProc = "/proc" }
// All paths should start with "./" for tar relative path handling
for _, path := range c.GetAllExclusions() {
if !strings.HasPrefix(path, "./") {
return &ConfigError{
Field: "rootfsExclusions",
Message: "all exclusion paths must start with './' (got: " + path + ")",
} }
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return "", fmt.Errorf("failed to open mountinfo: %w", err)
} }
defer file.Close() }
return nil
}
scanner := bufio.NewScanner(file) // FilesystemManifest holds runtime filesystem state captured at checkpoint time.
for scanner.Scan() { type FilesystemManifest struct {
line := scanner.Text() Exclusions FilesystemConfig `yaml:"exclusions"`
fields := strings.Fields(line) UpperDir string `yaml:"upperDir,omitempty"`
ExternalPaths []string `yaml:"externalPaths,omitempty"`
BindMountDests []string `yaml:"bindMountDests,omitempty"`
HasRootfsDiff bool `yaml:"hasRootfsDiff"`
HasDeletedFiles bool `yaml:"hasDeletedFiles"`
}
// Look for the root mount (mount point is /) // NewFilesystemManifest constructs FilesystemManifest from config, overlay state, and OCI spec.
// mountinfo format: id parent major:minor root mount-point options ... - fstype source super-options func NewFilesystemManifest(exclusions FilesystemConfig, upperDir string, ociSpec *specs.Spec) FilesystemManifest {
if len(fields) < 5 { meta := FilesystemManifest{
continue Exclusions: exclusions,
UpperDir: upperDir,
} }
if ociSpec == nil {
mountPoint := fields[4] return meta
if mountPoint != "/" {
continue
} }
// Find the separator (-) to get fstype and options if ociSpec.Linux != nil {
sepIdx := -1 meta.ExternalPaths = make([]string, 0, len(ociSpec.Linux.MaskedPaths)+len(ociSpec.Linux.ReadonlyPaths))
for i, f := range fields { meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.MaskedPaths...)
if f == "-" { meta.ExternalPaths = append(meta.ExternalPaths, ociSpec.Linux.ReadonlyPaths...)
sepIdx = i
break
} }
for _, m := range ociSpec.Mounts {
if m.Type == "bind" {
meta.BindMountDests = append(meta.BindMountDests, m.Destination)
} }
}
return meta
}
if sepIdx == -1 || sepIdx+2 >= len(fields) { // GetRootFS returns the container's root filesystem path.
continue func GetRootFS(pid int) (string, error) {
rootPath := fmt.Sprintf("%s/%d/root", HostProcPath, pid)
if _, err := os.Stat(rootPath); err != nil {
return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err)
} }
fsType := fields[sepIdx+1] return rootPath, nil
if fsType != "overlay" { }
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo.
// This is the writable layer of the container's filesystem.
func GetOverlayUpperDir(pid int) (string, error) {
mountInfo, err := ReadMountInfoFromHostProcPath(pid)
if err != nil {
return "", fmt.Errorf("failed to parse mountinfo: %w", err)
}
for _, mount := range mountInfo {
if mount.MountPoint != "/" || mount.FSType != "overlay" {
continue continue
} }
// Parse super options to find upperdir // Parse super options to find upperdir
superOptions := fields[sepIdx+3] for _, opt := range strings.Split(mount.SuperOptions, ",") {
for _, opt := range strings.Split(superOptions, ",") {
if strings.HasPrefix(opt, "upperdir=") { if strings.HasPrefix(opt, "upperdir=") {
return strings.TrimPrefix(opt, "upperdir="), nil return strings.TrimPrefix(opt, "upperdir="), nil
} }
} }
} }
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("error reading mountinfo: %w", err)
}
return "", fmt.Errorf("overlay upperdir not found for pid %d", pid) return "", fmt.Errorf("overlay upperdir not found for pid %d", pid)
} }
// DefaultRootfsDiffExclusions are paths excluded from the rootfs diff capture.
// These directories are injected/bind-mounted by NVIDIA GPU Operator at container
// start time, so they already exist in the restore target and cause conflicts
// (especially socket files which cannot be overwritten).
//
// Every entry is written relative to the archive root with a "./" prefix;
// CaptureRootfsDiff turns each one into a tar "--exclude=<entry>" argument.
var DefaultRootfsDiffExclusions = []string{
// NVIDIA GPU Operator injects drivers, libraries, and config here
"./usr",
"./etc",
"./opt",
"./var",
// NVIDIA GPU Operator creates runtime sockets and firmware mounts here
// Socket files cause fatal tar errors even with --keep-old-files
"./run",
}
// CaptureRootfsDiff captures the overlay upperdir to a tar file. // CaptureRootfsDiff captures the overlay upperdir to a tar file.
// The upperdir contains all filesystem modifications made by the container. // The upperdir contains all filesystem modifications made by the container.
// Excludes bind mount destinations and system directories to avoid conflicts during restore. // Excludes bind mount destinations and configured directories to avoid conflicts during restore.
// Returns the path to the tar file or empty string if capture failed. // Returns the path to the tar file or empty string if capture failed.
func CaptureRootfsDiff(upperDir, checkpointDir string, excludePaths []string) (string, error) { func CaptureRootfsDiff(upperDir, checkpointDir string, exclusions *FilesystemConfig, bindMountDests []string) (string, error) {
if upperDir == "" { if upperDir == "" {
return "", fmt.Errorf("upperdir is empty") return "", fmt.Errorf("upperdir is empty")
} }
rootfsDiffPath := filepath.Join(checkpointDir, "rootfs-diff.tar") rootfsDiffPath := filepath.Join(checkpointDir, RootfsDiffFilename)
// Build tar arguments with xattrs and exclusions // Build tar arguments with xattrs and exclusions
tarArgs := []string{"--xattrs"} tarArgs := []string{"--xattrs"}
// Add default exclusions for system directories and caches // Add configured exclusions (systemDirs, cacheDirs, additionalExclusions from values.yaml)
for _, excl := range DefaultRootfsDiffExclusions { if exclusions != nil {
for _, excl := range exclusions.GetAllExclusions() {
tarArgs = append(tarArgs, "--exclude="+excl) tarArgs = append(tarArgs, "--exclude="+excl)
} }
}
// Add bind mount exclusions passed from caller // Add bind mount exclusions passed from caller
for _, dest := range excludePaths { for _, dest := range bindMountDests {
// Convert absolute path to relative for tar (e.g., /etc/hosts -> ./etc/hosts) // Convert absolute path to relative for tar (e.g., /etc/hosts -> ./etc/hosts)
tarArgs = append(tarArgs, "--exclude=."+dest) tarArgs = append(tarArgs, "--exclude=."+dest)
} }
...@@ -165,7 +186,7 @@ func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) { ...@@ -165,7 +186,7 @@ func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
return false, nil return false, nil
} }
deletedFilesPath := filepath.Join(checkpointDir, "deleted-files.json") deletedFilesPath := filepath.Join(checkpointDir, DeletedFilesFilename)
data, err := json.Marshal(whiteouts) data, err := json.Marshal(whiteouts)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to marshal whiteouts: %w", err) return false, fmt.Errorf("failed to marshal whiteouts: %w", err)
...@@ -195,11 +216,11 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) { ...@@ -195,11 +216,11 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
relPath, _ := filepath.Rel(upperDir, path) relPath, _ := filepath.Rel(upperDir, path)
dir := filepath.Dir(relPath) dir := filepath.Dir(relPath)
deletedFile := strings.TrimPrefix(name, ".wh.") deletedFile := strings.TrimPrefix(name, ".wh.")
if dir == "." { deletedPath := deletedFile
whiteouts = append(whiteouts, deletedFile) if dir != "." {
} else { deletedPath = filepath.Join(dir, deletedFile)
whiteouts = append(whiteouts, filepath.Join(dir, deletedFile))
} }
whiteouts = append(whiteouts, deletedPath)
} }
return nil return nil
}) })
...@@ -208,23 +229,23 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) { ...@@ -208,23 +229,23 @@ func FindWhiteoutFiles(upperDir string) ([]string, error) {
} }
// CaptureRootfsState captures the overlay upperdir and deleted files after CRIU dump. // CaptureRootfsState captures the overlay upperdir and deleted files after CRIU dump.
// Updates the metadata with rootfs diff information and saves it. // Updates the checkpoint manifest with rootfs diff information and saves it.
func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointMetadata, log *logrus.Entry) { func CaptureRootfsState(upperDir, checkpointDir string, data *CheckpointManifest, log *logrus.Entry) {
if upperDir == "" { if upperDir == "" || data == nil {
return return
} }
// Capture rootfs diff // Capture rootfs diff using exclusions from the checkpoint manifest.
configuredExclusions := data.Filesystem.Exclusions.GetAllExclusions()
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"default_exclusions": DefaultRootfsDiffExclusions, "configured_exclusions": configuredExclusions,
"bind_mount_exclusions": meta.BindMountDests, "bind_mount_exclusions": data.Filesystem.BindMountDests,
}).Debug("Rootfs diff exclusions") }).Debug("Rootfs diff exclusions")
rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, meta.BindMountDests) rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, &data.Filesystem.Exclusions, data.Filesystem.BindMountDests)
if err != nil { if err != nil {
log.WithError(err).Warn("Failed to capture rootfs diff") log.WithError(err).Warn("Failed to capture rootfs diff")
} else { } else {
meta.RootfsDiffPath = rootfsDiffPath data.Filesystem.HasRootfsDiff = true
meta.HasRootfsDiff = true
log.WithFields(logrus.Fields{ log.WithFields(logrus.Fields{
"upperdir": upperDir, "upperdir": upperDir,
"tar_path": rootfsDiffPath, "tar_path": rootfsDiffPath,
...@@ -236,12 +257,12 @@ func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointM ...@@ -236,12 +257,12 @@ func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointM
if err != nil { if err != nil {
log.WithError(err).Warn("Failed to capture deleted files") log.WithError(err).Warn("Failed to capture deleted files")
} else if hasDeletedFiles { } else if hasDeletedFiles {
meta.HasDeletedFiles = true data.Filesystem.HasDeletedFiles = true
log.Info("Recorded deleted files (whiteouts)") log.Info("Recorded deleted files (whiteouts)")
} }
// Update metadata with rootfs diff info // Update checkpoint manifest with rootfs diff info.
if err := common.SaveMetadata(checkpointDir, meta); err != nil { if err := WriteCheckpointManifest(checkpointDir, data); err != nil {
log.WithError(err).Warn("Failed to update metadata with rootfs diff info") log.WithError(err).Warn("Failed to update checkpoint manifest with rootfs diff info")
} }
} }
// k8s contains containerd discovery and Kubernetes path classification helpers.
package checkpoint
import (
"context"
"fmt"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
const (
// K8sNamespace is the containerd namespace used by Kubernetes.
// ResolveContainer attaches it to the context before every containerd call.
K8sNamespace = "k8s.io"
// ContainerdSocket is the containerd socket path that NewDiscoveryClient
// connects to.
ContainerdSocket = "/run/containerd/containerd.sock"
)
// SourcePodManifest identifies the container, process, node and pod that a
// checkpoint was taken from. It is serialized to YAML under the keys given in
// the field tags.
type SourcePodManifest struct {
// ContainerID is the container ID taken from the checkpoint request.
ContainerID string `yaml:"containerId"`
// PID is the process ID supplied to NewSourcePodManifest.
PID int `yaml:"pid"`
// SourceNode is the node name taken from the checkpoint request.
SourceNode string `yaml:"sourceNode"`
// PodName is the pod name taken from the checkpoint request.
PodName string `yaml:"podName"`
// PodNamespace is the pod namespace taken from the checkpoint request.
PodNamespace string `yaml:"podNamespace"`
}
// NewSourcePodManifest builds a SourcePodManifest from the identifying fields of
// the checkpoint request (container ID, node name, pod name and namespace) and
// the given process ID.
func NewSourcePodManifest(params CheckpointRequest, pid int) SourcePodManifest {
	var manifest SourcePodManifest
	manifest.ContainerID = params.ContainerID
	manifest.PID = pid
	manifest.SourceNode = params.NodeName
	manifest.PodName = params.PodName
	manifest.PodNamespace = params.PodNamespace
	return manifest
}
// DiscoveryClient wraps a containerd client used to resolve containers.
// Create it with NewDiscoveryClient and release it with Close.
type DiscoveryClient struct {
// client is the underlying containerd connection; Close treats nil as
// "nothing to close".
client *containerd.Client
}
// NewDiscoveryClient connects to containerd at ContainerdSocket and wraps the
// connection in a DiscoveryClient. A connection failure is returned wrapped
// with the socket path.
func NewDiscoveryClient() (*DiscoveryClient, error) {
	cdClient, connErr := containerd.New(ContainerdSocket)
	if connErr == nil {
		return &DiscoveryClient{client: cdClient}, nil
	}
	return nil, fmt.Errorf("failed to connect to containerd at %s: %w", ContainerdSocket, connErr)
}
// Close closes the underlying containerd client and returns its error.
// It is a no-op returning nil when no client is set.
func (c *DiscoveryClient) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}
// ResolveContainer looks up containerID in the Kubernetes containerd namespace
// and returns the PID of its task together with its OCI spec.
//
// On any failure (loading the container, getting its task, or reading its
// spec) it returns 0, nil and an error wrapping the containerd error.
func (c *DiscoveryClient) ResolveContainer(ctx context.Context, containerID string) (int, *specs.Spec, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, loadErr := c.client.LoadContainer(nsCtx, containerID)
	if loadErr != nil {
		return 0, nil, fmt.Errorf("failed to load container %s: %w", containerID, loadErr)
	}

	runningTask, taskErr := ctr.Task(nsCtx, nil)
	if taskErr != nil {
		return 0, nil, fmt.Errorf("failed to get task for container %s: %w", containerID, taskErr)
	}
	taskPID := int(runningTask.Pid())

	ociSpec, specErr := ctr.Spec(nsCtx)
	if specErr != nil {
		return 0, nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, specErr)
	}

	return taskPID, ociSpec, nil
}
// discovery provides container information resolution via containerd.
// This prefers containerd RPCs for configuration over /proc inspection,
// following the principle that configuration should come from the container runtime
// while runtime state (like namespace inodes) requires /proc.
package k8s
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
const (
// K8sNamespace is the containerd namespace used by Kubernetes.
// It is applied to the context of every containerd call in this file and is
// also a path component of the container bundle directory.
K8sNamespace = "k8s.io"
// DefaultSocket is the default containerd socket path, used by
// NewDiscoveryClient when the caller passes an empty socket.
DefaultSocket = "/run/containerd/containerd.sock"
)
// ContainerInfo holds the container details assembled by ResolveContainer.
// The PID, image, spec and labels come from containerd; BundlePath and RootFS
// are host paths derived from the containerd run root and the OCI spec.
type ContainerInfo struct {
ContainerID string // Container ID as passed to ResolveContainer
PID uint32 // PID reported by the container's task
RootFS string // Rootfs path: spec.Root.Path if absolute, otherwise bundle path + spec.Root.Path
BundlePath string // Path to container bundle directory
Image string // Name of the container's image
Spec *specs.Spec // OCI spec from containerd (mounts, namespaces config)
Labels map[string]string // Container labels; an empty map when the label lookup failed
}
// MountInfo represents a mount from the OCI spec.
// GetMounts produces one MountInfo per entry in Spec.Mounts, copying the
// fields below unchanged.
type MountInfo struct {
Destination string // Mount point inside container
Source string // Source path on host
Type string // Filesystem type (bind, tmpfs, etc.)
Options []string // Mount options
}
// NamespaceConfig represents namespace configuration from OCI spec.
// GetNamespaces produces one NamespaceConfig per entry in Spec.Linux.Namespaces.
type NamespaceConfig struct {
Type string // Namespace type (network, pid, mount, etc.)
Path string // Path to namespace (empty for new namespace)
}
// DiscoveryClient wraps the containerd client for container discovery.
// Create it with NewDiscoveryClient and release it with Close.
type DiscoveryClient struct {
// client is the underlying containerd connection; Close treats nil as
// "nothing to close".
client *containerd.Client
// socket is the socket path the client was connected to.
socket string
}
// NewDiscoveryClient connects to containerd over the given socket and returns a
// DiscoveryClient wrapping the connection. An empty socket selects
// DefaultSocket. A connection failure is returned wrapped with the socket path.
func NewDiscoveryClient(socket string) (*DiscoveryClient, error) {
	sockPath := socket
	if sockPath == "" {
		sockPath = DefaultSocket
	}

	cdClient, connErr := containerd.New(sockPath)
	if connErr != nil {
		return nil, fmt.Errorf("failed to connect to containerd at %s: %w", sockPath, connErr)
	}

	dc := &DiscoveryClient{}
	dc.client = cdClient
	dc.socket = sockPath
	return dc, nil
}
// Close shuts down the containerd client connection and returns its error.
// It does nothing and returns nil when no client is set.
func (c *DiscoveryClient) Close() error {
	if c.client == nil {
		return nil
	}
	return c.client.Close()
}
// ResolveContainer resolves a container ID to a ContainerInfo.
//
// The task PID, image, OCI spec and labels are fetched from containerd in the
// Kubernetes namespace. The bundle and rootfs paths are not read from the
// runtime; they are computed from the containerd run root ($CONTAINERD_RUN_ROOT,
// defaulting to /run/containerd) and the spec's Root.Path.
//
// Failing to load the container, its task, its image or its spec is an error.
// A failed label lookup is tolerated and yields an empty label map.
func (c *DiscoveryClient) ResolveContainer(ctx context.Context, containerID string) (*ContainerInfo, error) {
	// All containerd calls below run in the Kubernetes namespace.
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, err := c.client.LoadContainer(nsCtx, containerID)
	if err != nil {
		return nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
	}

	// The task is the container's running process.
	runningTask, err := ctr.Task(nsCtx, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get task for container %s: %w", containerID, err)
	}
	taskPID := runningTask.Pid()

	img, err := ctr.Image(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to get image for container %s: %w", containerID, err)
	}

	// The OCI spec carries the mount and namespace configuration.
	ociSpec, err := ctr.Spec(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, err)
	}

	// Labels are optional: fall back to an empty map instead of failing.
	ctrLabels, err := ctr.Labels(nsCtx)
	if err != nil {
		ctrLabels = make(map[string]string)
	}

	// Bundle layout: <run root>/io.containerd.runtime.v2.task/<namespace>/<container_id>/
	runRoot := os.Getenv("CONTAINERD_RUN_ROOT")
	if runRoot == "" {
		runRoot = "/run/containerd"
	}
	bundleDir := filepath.Join(runRoot, "io.containerd.runtime.v2.task", K8sNamespace, containerID)

	// The spec's root path is usually "rootfs", relative to the bundle.
	specRoot := "rootfs"
	if ociSpec.Root != nil && ociSpec.Root.Path != "" {
		specRoot = ociSpec.Root.Path
	}

	// An absolute root path is used as-is; a relative one hangs off the bundle.
	rootDir := specRoot
	if !filepath.IsAbs(specRoot) {
		rootDir = filepath.Join(bundleDir, specRoot)
	}

	return &ContainerInfo{
		ContainerID: containerID,
		PID:         taskPID,
		RootFS:      rootDir,
		BundlePath:  bundleDir,
		Image:       img.Name(),
		Spec:        ociSpec,
		Labels:      ctrLabels,
	}, nil
}
// GetMounts returns the mounts configured in the OCI spec, one MountInfo per
// spec entry and in the same order. It returns nil when the spec or its mount
// list is absent.
//
// These are the configured mounts only; they do not reflect runtime mount
// state.
func (info *ContainerInfo) GetMounts() []MountInfo {
	if info.Spec == nil || info.Spec.Mounts == nil {
		return nil
	}

	specMounts := info.Spec.Mounts
	converted := make([]MountInfo, 0, len(specMounts))
	for _, sm := range specMounts {
		converted = append(converted, MountInfo{
			Destination: sm.Destination,
			Source:      sm.Source,
			Type:        sm.Type,
			Options:     sm.Options,
		})
	}
	return converted
}
// GetNamespaces returns the namespaces configured in the OCI spec, one
// NamespaceConfig per spec entry and in the same order. It returns nil when the
// spec or its Linux section is absent.
func (info *ContainerInfo) GetNamespaces() []NamespaceConfig {
	if info.Spec == nil || info.Spec.Linux == nil {
		return nil
	}

	specNamespaces := info.Spec.Linux.Namespaces
	converted := make([]NamespaceConfig, 0, len(specNamespaces))
	for _, sn := range specNamespaces {
		converted = append(converted, NamespaceConfig{
			Type: string(sn.Type),
			Path: sn.Path,
		})
	}
	return converted
}
// GetMaskedPaths returns the masked paths from the OCI spec.
//
// It returns nil when the receiver is nil, or when the spec or its Linux
// section is absent. The nil-receiver case matters because callers invoke this
// on a *ContainerInfo they received as a parameter without checking it first.
func (info *ContainerInfo) GetMaskedPaths() []string {
	if info == nil || info.Spec == nil || info.Spec.Linux == nil {
		return nil
	}
	// The spec's own slice is returned, not a copy.
	return info.Spec.Linux.MaskedPaths
}
// GetReadonlyPaths returns the readonly paths from the OCI spec.
//
// It returns nil when the receiver is nil, or when the spec or its Linux
// section is absent. The nil-receiver case matters because callers invoke this
// on a *ContainerInfo they received as a parameter without checking it first.
func (info *ContainerInfo) GetReadonlyPaths() []string {
	if info == nil || info.Spec == nil || info.Spec.Linux == nil {
		return nil
	}
	// The spec's own slice is returned, not a copy.
	return info.Spec.Linux.ReadonlyPaths
}
// GetRootfsPath returns Root.Path exactly as written in the OCI spec, or ""
// when the spec or its Root section is absent.
//
// The value is not resolved here, so it may be relative. Callers that need a
// usable filesystem location should prefer info.RootFS, which is populated
// separately when the ContainerInfo is built.
func (info *ContainerInfo) GetRootfsPath() string {
	if spec := info.Spec; spec != nil && spec.Root != nil {
		return spec.Root.Path
	}
	return ""
}
// IsRootReadonly reports the Root.Readonly flag from the OCI spec. It returns
// false when the spec or its Root section is absent.
func (info *ContainerInfo) IsRootReadonly() bool {
	if spec := info.Spec; spec != nil && spec.Root != nil {
		return spec.Root.Readonly
	}
	return false
}
// GetHostname returns the Hostname field of the OCI spec, or "" when no spec
// is attached.
func (info *ContainerInfo) GetHostname() string {
	if spec := info.Spec; spec != nil {
		return spec.Hostname
	}
	return ""
}
// ListContainers returns the IDs of every container the client reports in the
// K8sNamespace containerd namespace.
//
// The namespace is attached to ctx before the query. A listing failure is
// wrapped and returned with a nil slice.
func (c *DiscoveryClient) ListContainers(ctx context.Context) ([]string, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	found, err := c.client.Containers(nsCtx)
	if err != nil {
		return nil, fmt.Errorf("failed to list containers: %w", err)
	}

	ids := make([]string, 0, len(found))
	for _, ctr := range found {
		ids = append(ids, ctr.ID())
	}
	return ids, nil
}
// GetContainerLabels loads the container identified by containerID from the
// K8sNamespace containerd namespace and returns its labels.
//
// Both the load step and the label fetch can fail; either error is wrapped
// with the container ID and returned with a nil map.
func (c *DiscoveryClient) GetContainerLabels(ctx context.Context, containerID string) (map[string]string, error) {
	nsCtx := namespaces.WithNamespace(ctx, K8sNamespace)

	ctr, loadErr := c.client.LoadContainer(nsCtx, containerID)
	if loadErr != nil {
		return nil, fmt.Errorf("failed to load container %s: %w", containerID, loadErr)
	}

	labels, labelErr := ctr.Labels(nsCtx)
	if labelErr != nil {
		return nil, fmt.Errorf("failed to get labels for container %s: %w", containerID, labelErr)
	}
	return labels, nil
}
// Package k8s provides Kubernetes-specific functionality for checkpoint operations.
// This includes volume type discovery via K8s API and containerd container discovery.
package k8s
import (
"context"
"fmt"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)
// VolumeInfo contains Kubernetes volume information for a mount.
//
// One value describes a single volumeMount of a container, joined with the
// pod-level volume it refers to. ExtractVolumeInfo produces these keyed by
// MountPath.
type VolumeInfo struct {
	VolumeName string // Volume name (volumeMounts[].name, matching pod.spec.volumes[].name)
	VolumeType string // Volume source kind: emptyDir, configMap, secret, persistentVolumeClaim, ..., or "unknown"
	MountPath string // Container path from volumeMounts[].mountPath
	SubPath string // volumeMounts[].subPath; empty if not specified
	ReadOnly bool // volumeMounts[].readOnly
	// Type-specific details; each is empty unless VolumeType matches.
	ConfigMapName string // ConfigMap name, for configMap volumes
	SecretName string // Secret name, for secret volumes
	PVCName string // Claim name, for persistentVolumeClaim volumes
}
// K8sClient wraps the Kubernetes clientset for volume discovery.
// Construct it with NewK8sClient or NewK8sClientWithConfig.
type K8sClient struct {
	clientset *kubernetes.Clientset // used by GetPodVolumes to fetch the Pod object
}
// NewK8sClient creates a new Kubernetes client.
//
// It attempts in-cluster config first, then falls back to the kubeconfig at
// clientcmd.RecommendedHomeFile (intended for local development).
//
// When both sources fail, the returned error wraps the kubeconfig error (so
// errors.Is/As keep working on it) and also includes the in-cluster error.
// Without that, a failure inside a cluster would surface only as a misleading
// "kubeconfig not found"-style message.
func NewK8sClient() (*K8sClient, error) {
	config, inClusterErr := rest.InClusterConfig()
	if inClusterErr != nil {
		// Fall back to kubeconfig for local development.
		var err error
		config, err = clientcmd.BuildConfigFromFlags("", clientcmd.RecommendedHomeFile)
		if err != nil {
			return nil, fmt.Errorf("failed to create k8s config: %w (in-cluster config error: %v)", err, inClusterErr)
		}
	}

	clientset, err := kubernetes.NewForConfig(config)
	if err != nil {
		return nil, fmt.Errorf("failed to create k8s clientset: %w", err)
	}
	return &K8sClient{clientset: clientset}, nil
}
// NewK8sClientWithConfig builds a K8sClient from a caller-supplied REST
// config. A clientset construction failure is wrapped and returned with a nil
// client.
func NewK8sClientWithConfig(config *rest.Config) (*K8sClient, error) {
	cs, buildErr := kubernetes.NewForConfig(config)
	if buildErr != nil {
		return nil, fmt.Errorf("failed to create k8s clientset: %w", buildErr)
	}

	client := &K8sClient{clientset: cs}
	return client, nil
}
// GetPodVolumes fetches the named pod from the API server and returns volume
// information for every volumeMount of the given container.
//
// The result maps mount path to VolumeInfo (see ExtractVolumeInfo). A failed
// pod fetch is wrapped with namespace/pod name and returned with a nil map.
func (c *K8sClient) GetPodVolumes(ctx context.Context, namespace, podName, containerName string) (map[string]*VolumeInfo, error) {
	podsAPI := c.clientset.CoreV1().Pods(namespace)

	pod, fetchErr := podsAPI.Get(ctx, podName, metav1.GetOptions{})
	if fetchErr != nil {
		return nil, fmt.Errorf("failed to get pod %s/%s: %w", namespace, podName, fetchErr)
	}
	return ExtractVolumeInfo(pod, containerName)
}
// ExtractVolumeInfo extracts volume information from a Pod spec.
//
// It maps each volumeMount of the named container to the pod-level volume it
// references and records that volume's type. Regular containers are searched
// first, then init containers.
//
// The result is keyed by mount path. Mounts that reference a volume name not
// present in pod.Spec.Volumes are skipped. An error is returned when pod is
// nil or when no container with the given name exists.
func ExtractVolumeInfo(pod *corev1.Pod, containerName string) (map[string]*VolumeInfo, error) {
	if pod == nil {
		return nil, fmt.Errorf("cannot extract volume info for container %s: pod is nil", containerName)
	}

	// Build volume name -> type mapping from pod.spec.volumes. Iterate by
	// index so we pass a pointer to the slice element rather than to a copy
	// held in the range variable.
	volumeTypes := make(map[string]*volumeDetails, len(pod.Spec.Volumes))
	for i := range pod.Spec.Volumes {
		vol := &pod.Spec.Volumes[i]
		volumeTypes[vol.Name] = getVolumeDetails(vol)
	}

	// Find the target container, falling back to init containers.
	container := findContainerByName(pod.Spec.Containers, containerName)
	if container == nil {
		container = findContainerByName(pod.Spec.InitContainers, containerName)
	}
	if container == nil {
		return nil, fmt.Errorf("container %s not found in pod", containerName)
	}

	// Build mount path -> volume info mapping.
	result := make(map[string]*VolumeInfo, len(container.VolumeMounts))
	for _, mount := range container.VolumeMounts {
		details, ok := volumeTypes[mount.Name]
		if !ok {
			continue // Mount references unknown volume
		}
		result[mount.MountPath] = &VolumeInfo{
			VolumeName:    mount.Name,
			VolumeType:    details.volumeType,
			MountPath:     mount.MountPath,
			SubPath:       mount.SubPath,
			ReadOnly:      mount.ReadOnly,
			ConfigMapName: details.configMapName,
			SecretName:    details.secretName,
			PVCName:       details.pvcName,
		}
	}
	return result, nil
}

// findContainerByName returns a pointer to the first element of containers
// whose Name equals name, or nil when there is none.
func findContainerByName(containers []corev1.Container, name string) *corev1.Container {
	for i := range containers {
		if containers[i].Name == name {
			return &containers[i]
		}
	}
	return nil
}
// volumeDetails holds extracted volume type information.
// It is the intermediate result of getVolumeDetails, later copied into
// VolumeInfo by ExtractVolumeInfo.
type volumeDetails struct {
	volumeType string // volume source kind, or "unknown" when no known source is set
	configMapName string // set only for configMap volumes
	secretName string // set only for secret volumes
	pvcName string // set only for persistentVolumeClaim volumes
}
// getVolumeDetails classifies a pod Volume by which volume source is set.
//
// Sources are checked in a fixed order and the first non-nil one determines
// volumeType; if none is set the type is "unknown". For configMap, secret and
// persistentVolumeClaim volumes the referenced object name is recorded too.
func getVolumeDetails(vol *corev1.Volume) *volumeDetails {
	// Order matters: the first populated source wins.
	candidates := []struct {
		present bool
		name    string
	}{
		{vol.EmptyDir != nil, "emptyDir"},
		{vol.ConfigMap != nil, "configMap"},
		{vol.Secret != nil, "secret"},
		{vol.PersistentVolumeClaim != nil, "persistentVolumeClaim"},
		{vol.HostPath != nil, "hostPath"},
		{vol.Projected != nil, "projected"},
		{vol.DownwardAPI != nil, "downwardAPI"},
		{vol.CSI != nil, "csi"},
		{vol.NFS != nil, "nfs"},
		{vol.ISCSI != nil, "iscsi"},
		{vol.GCEPersistentDisk != nil, "gcePersistentDisk"},
		{vol.AWSElasticBlockStore != nil, "awsElasticBlockStore"},
		{vol.AzureDisk != nil, "azureDisk"},
		{vol.AzureFile != nil, "azureFile"},
		{vol.CephFS != nil, "cephfs"},
		{vol.Cinder != nil, "cinder"},
		{vol.FC != nil, "fc"},
		{vol.FlexVolume != nil, "flexVolume"},
		{vol.Flocker != nil, "flocker"},
		{vol.GitRepo != nil, "gitRepo"},
		{vol.Glusterfs != nil, "glusterfs"},
		{vol.PhotonPersistentDisk != nil, "photonPersistentDisk"},
		{vol.PortworxVolume != nil, "portworxVolume"},
		{vol.Quobyte != nil, "quobyte"},
		{vol.RBD != nil, "rbd"},
		{vol.ScaleIO != nil, "scaleIO"},
		{vol.StorageOS != nil, "storageos"},
		{vol.VsphereVolume != nil, "vsphereVolume"},
		{vol.Ephemeral != nil, "ephemeral"},
	}

	details := &volumeDetails{volumeType: "unknown"}
	for _, candidate := range candidates {
		if candidate.present {
			details.volumeType = candidate.name
			break
		}
	}

	// Only these three kinds carry a referenced object name. The matching
	// source pointer is guaranteed non-nil because it selected the type above.
	switch details.volumeType {
	case "configMap":
		details.configMapName = vol.ConfigMap.Name
	case "secret":
		details.secretName = vol.Secret.SecretName
	case "persistentVolumeClaim":
		details.pvcName = vol.PersistentVolumeClaim.ClaimName
	}
	return details
}
// DetectVolumeTypeFromPath attempts to identify volume type from kubelet path patterns.
// This is a best-effort fallback; accurate volume types require K8s API access via GetPodVolumes.
//
// hostPath is searched for known "/kubernetes.io~<plugin>/" markers. The
// marker that occurs earliest in the path decides the type, and the path
// segment immediately following it is returned as the volume name. Choosing
// the earliest marker keeps the result deterministic when a later segment
// (for example a directory inside the volume) happens to contain another
// marker. When no marker is present the result is ("unknown", "").
func DetectVolumeTypeFromPath(hostPath string) (volumeType, volumeName string) {
	// An ordered table rather than a map: Go randomizes map iteration, which
	// would make the result depend on chance for paths with several markers.
	patterns := [...]struct {
		marker     string
		volumeType string
	}{
		{"/kubernetes.io~empty-dir/", "emptyDir"},
		{"/kubernetes.io~configmap/", "configMap"},
		{"/kubernetes.io~secret/", "secret"},
		{"/kubernetes.io~projected/", "projected"},
		{"/kubernetes.io~downward-api/", "downwardAPI"},
		{"/kubernetes.io~persistentvolumeclaim/", "persistentVolumeClaim"},
		{"/kubernetes.io~hostpath/", "hostPath"},
	}

	volumeType = "unknown"
	matchStart, matchEnd := -1, -1
	for _, p := range patterns {
		idx := strings.Index(hostPath, p.marker)
		if idx < 0 {
			continue
		}
		if matchStart == -1 || idx < matchStart {
			matchStart = idx
			matchEnd = idx + len(p.marker)
			volumeType = p.volumeType
		}
	}
	if matchStart == -1 {
		return volumeType, ""
	}

	// Volume name is the first path segment after the marker.
	volumeName = hostPath[matchEnd:]
	if slash := strings.IndexByte(volumeName, '/'); slash >= 0 {
		volumeName = volumeName[:slash]
	}
	return volumeType, volumeName
}
// metadata_builder provides checkpoint metadata construction.
package checkpoint
import (
"context"
"strings"
"github.com/sirupsen/logrus"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// MetadataBuilderConfig holds configuration for building checkpoint metadata.
// It is passed by value to BuildCheckpointMetadata.
type MetadataBuilderConfig struct {
	CheckpointID string // passed to common.NewCheckpointMetadata
	NodeName string // recorded as the metadata's SourceNode
	ContainerID string // recorded in the metadata
	ContainerName string // used only for the K8s API volume lookup; lookup is skipped when empty
	PodName string // recorded in the metadata and used for the K8s API volume lookup
	PodNamespace string // recorded in the metadata and used for the K8s API volume lookup
	PID int // recorded in the metadata
	CUDAPluginDir string // NOTE(review): not read by BuildCheckpointMetadata in this file - confirm who consumes it
}
// BuildCheckpointMetadata assembles the checkpoint metadata record for one
// container.
//
// It combines:
//   - identity fields from cfg (checkpoint ID, node, container, pod, PID),
//   - image, masked paths, readonly paths and bind-mount destinations taken
//     from containerInfo's OCI spec,
//   - one mount entry per element of mounts, enriched with K8s volume types
//     (when k8sClient and the pod/container names are available) and the
//     matching OCI mount,
//   - one namespace entry per element of namespaces,
//   - a fixed set of CRIU options.
//
// Namespace entries are appended in map iteration order, which Go does not
// define. containerInfo must be non-nil, and every value in namespaces must be
// non-nil.
func BuildCheckpointMetadata(
	ctx context.Context,
	cfg MetadataBuilderConfig,
	containerInfo *checkpointk8s.ContainerInfo,
	mounts []MountMapping,
	namespaces map[NamespaceType]*NamespaceInfo,
	k8sClient *checkpointk8s.K8sClient,
	log *logrus.Entry,
) *common.CheckpointMetadata {
	meta := common.NewCheckpointMetadata(cfg.CheckpointID)

	// Identity of the checkpointed workload.
	meta.SourceNode = cfg.NodeName
	meta.ContainerID = cfg.ContainerID
	meta.PodName = cfg.PodName
	meta.PodNamespace = cfg.PodNamespace
	meta.PID = cfg.PID
	meta.Image = containerInfo.Image

	// Paths taken straight from the OCI spec.
	meta.MaskedPaths = containerInfo.GetMaskedPaths()
	meta.ReadonlyPaths = containerInfo.GetReadonlyPaths()

	// OCI mounts indexed by destination; also fills meta.BindMountDests.
	ociByDest := buildOCIMountLookup(containerInfo, meta)

	// Volume types from the K8s API; nil when unavailable.
	volumesByPath := getK8sVolumes(ctx, k8sClient, cfg, log)

	for idx := range mounts {
		entry := buildMountMetadata(mounts[idx], volumesByPath, ociByDest)
		meta.Mounts = append(meta.Mounts, entry)
	}

	for kind, details := range namespaces {
		entry := common.NamespaceMetadata{
			Type:       string(kind),
			Inode:      details.Inode,
			IsExternal: details.IsExternal,
		}
		meta.Namespaces = append(meta.Namespaces, entry)
	}

	// CRIU options are fixed for K8s and stored only for compatibility.
	criuOptions := common.CRIUOptionsMetadata{
		TcpEstablished: false, // TCP connections are closed, never preserved
		TcpClose:       true,  // pod IPs change on restore
		ShellJob:       true,  // containers are session leaders
		FileLocks:      true,  // applications use file locks
		LeaveRunning:   true,  // keep the process running after checkpoint
		LinkRemap:      true,  // handle deleted-but-open files
		ExtMasters:     true,  // external bind mount masters
	}
	meta.CRIUOptions = criuOptions

	return meta
}
// buildOCIMountLookup indexes the container's OCI spec mounts by destination
// path.
//
// As a side effect, every mount whose Type is "bind" has its destination
// appended to meta.BindMountDests. If two mounts share a destination, the
// later one wins in the returned map.
func buildOCIMountLookup(containerInfo *checkpointk8s.ContainerInfo, meta *common.CheckpointMetadata) map[string]checkpointk8s.MountInfo {
	byDest := make(map[string]checkpointk8s.MountInfo)
	for _, ociMount := range containerInfo.GetMounts() {
		dest := ociMount.Destination
		byDest[dest] = ociMount
		if ociMount.Type != "bind" {
			continue
		}
		meta.BindMountDests = append(meta.BindMountDests, dest)
	}
	return byDest
}
// getK8sVolumes asks the K8s API for the volume types of the configured
// container, keyed by mount path.
//
// It returns nil without calling the API when there is no client or when the
// pod namespace, pod name or container name is empty. An API failure is
// logged as a warning and also yields nil, so callers can fall back to
// path-based detection.
func getK8sVolumes(ctx context.Context, k8sClient *checkpointk8s.K8sClient, cfg MetadataBuilderConfig, log *logrus.Entry) map[string]*checkpointk8s.VolumeInfo {
	lookupPossible := k8sClient != nil &&
		cfg.PodNamespace != "" &&
		cfg.PodName != "" &&
		cfg.ContainerName != ""
	if !lookupPossible {
		return nil
	}

	volumes, lookupErr := k8sClient.GetPodVolumes(ctx, cfg.PodNamespace, cfg.PodName, cfg.ContainerName)
	if lookupErr == nil {
		log.WithField("volume_count", len(volumes)).Debug("Got volume types from K8s API")
		return volumes
	}

	log.WithError(lookupErr).Warn("Failed to get volume types from K8s API, falling back to path-based detection")
	return nil
}
// buildMountMetadata constructs metadata for a single mount.
//
// Volume type and name come from the K8s API data (k8sVolumes, keyed by
// container mount path) when an entry exists for mount.InsidePath. Otherwise
// they are guessed from the host path via DetectVolumeTypeFromPath.
//
// ReadOnly is true only when the comma-separated mount.Options contains the
// exact option "ro". A substring test would misfire on options such as
// "errors=remount-ro" or "rootcontext=...", marking read-write mounts as
// read-only.
//
// If the OCI spec declares a mount at the same destination, its source, type
// and options are copied into the OCI* fields.
func buildMountMetadata(mount MountMapping, k8sVolumes map[string]*checkpointk8s.VolumeInfo, ociMountByDest map[string]checkpointk8s.MountInfo) common.MountMetadata {
	var volumeType, volumeName string

	// Try K8s API first for accurate volume types. Indexing a nil map is safe
	// in Go, so no separate nil check is needed.
	if volInfo, ok := k8sVolumes[mount.InsidePath]; ok && volInfo != nil {
		volumeType = volInfo.VolumeType
		volumeName = volInfo.VolumeName
	}

	// Fall back to path-based detection if K8s API didn't provide info.
	if volumeType == "" {
		volumeType, volumeName = checkpointk8s.DetectVolumeTypeFromPath(mount.OutsidePath)
	}

	mountMeta := common.MountMetadata{
		ContainerPath: mount.InsidePath,
		HostPath:      mount.OutsidePath,
		VolumeType:    volumeType,
		VolumeName:    volumeName,
		FSType:        mount.FSType,
		ReadOnly:      mountOptionListHas(mount.Options, "ro"),
	}

	// Cross-reference with OCI spec mount if available.
	if ociMount, ok := ociMountByDest[mount.InsidePath]; ok {
		mountMeta.OCISource = ociMount.Source
		mountMeta.OCIType = ociMount.Type
		mountMeta.OCIOptions = ociMount.Options
	}
	return mountMeta
}

// mountOptionListHas reports whether the comma-separated option list contains
// an option exactly equal to want (surrounding whitespace ignored).
func mountOptionListHas(options, want string) bool {
	for _, opt := range strings.Split(options, ",") {
		if strings.TrimSpace(opt) == want {
			return true
		}
	}
	return false
}
// mounts provides mount parsing from /proc for CRIU checkpoint. // mounts parses runtime mount state from /proc.
// This is used for runtime mount state that requires /proc inspection.
package checkpoint package checkpoint
import ( import (
"bufio"
"fmt" "fmt"
"os" "path"
"path/filepath"
"strings" "strings"
specs "github.com/opencontainers/runtime-spec/specs-go"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
) )
// MountMapping represents an external mount for CRIU type MountInfo struct {
type MountMapping struct { MountID string
InsidePath string // Path inside container (mount point) ParentID string
OutsidePath string // Path on host (source) MountPoint string
FSType string // Filesystem type Root string
Source string // Mount source FSType string
Options string // Mount options Source string
Options string
SuperOptions string
} }
// System mount types that should be filtered out // MountPolicy is the classified mount plan for CRIU dump options.
var systemMountTypes = map[string]bool{ type MountPolicy struct {
"proc": true, Externalized []string
"sysfs": true, Skipped []string
"devpts": true,
"mqueue": true,
"tmpfs": true, // Note: some tmpfs mounts may need special handling
"cgroup": true,
"cgroup2": true,
"securityfs": true,
"debugfs": true,
"tracefs": true,
"fusectl": true,
"configfs": true,
"devtmpfs": true,
"hugetlbfs": true,
"pstore": true,
"bpf": true,
} }
// System mount paths that should always be filtered // BuildMountPolicy classifies mounts into CRIU extMnt and skipMnt lists.
var systemMountPaths = map[string]bool{ //
"/proc": true, // Rule order and precedence (top to bottom):
"/sys": true, // 1. Skip non-OCI proc/sys submounts and non-OCI runtime /run submounts.
"/dev": true, // These mounts are typically node/kernel/runtime specific and are the
"/dev/pts": true, // highest-risk source of cross-node restore failures, so skip wins.
"/dev/shm": true, // 2. Externalize mounts owned by runtime/OCI:
"/dev/mqueue": true, // - "/" (rootfs is recreated by runtime in OCI restore path)
"/run": true, // - OCI mount destinations
"/run/secrets": true, // - OCI masked/readonly paths
} // 3. Externalize non-OCI bind-like mounts (mount root is not "/" or ".").
// This captures runtime-injected file mounts (for example driver files)
// so CRIU does not try to recreate them from checkpoint data.
// 4. Anything else is left unflagged and handled by CRIU default behavior.
//
// Precedence: skip > externalize. If a path is classified as skipped, it is
// removed from the externalized set.
func BuildMountPolicy(mountInfo []MountInfo, ociSpec *specs.Spec, rootFS string) *MountPolicy {
ociManagedSet := collectOCIManagedDestinations(ociSpec, rootFS)
// ParseMountInfo parses /proc/<pid>/mountinfo and returns bind mounts externalizedSet := make(map[string]struct{}, len(mountInfo)+len(ociManagedSet))
// that need to be handled by CRIU as external mounts skippedSet := make(map[string]struct{}, len(mountInfo))
func ParseMountInfo(pid int, hostProc string) ([]MountMapping, error) {
if hostProc == "" {
hostProc = "/proc"
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid) for _, mount := range mountInfo {
file, err := os.Open(mountinfoPath) mp := normalizeMountPath(mount.MountPoint)
if err != nil { if mp == "" {
return nil, fmt.Errorf("failed to open mountinfo: %w", err) continue
} }
defer file.Close()
var mounts []MountMapping
scanner := bufio.NewScanner(file)
for scanner.Scan() { source := path.Clean(strings.TrimSpace(mount.Source))
line := scanner.Text() root := path.Clean(strings.TrimSpace(mount.Root))
mount, skip := parseMountInfoLine(line) isOCIManaged := false
if skip { if _, ok := ociManagedSet[mp]; ok {
continue isOCIManaged = true
} }
mounts = append(mounts, mount) if !isOCIManaged && strings.HasPrefix(mp, "/run/") {
if _, ok := ociManagedSet["/var"+mp]; ok {
isOCIManaged = true
} }
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
} }
if !isOCIManaged && strings.HasPrefix(mp, "/var/run/") {
return mounts, nil if _, ok := ociManagedSet[strings.TrimPrefix(mp, "/var")]; ok {
} isOCIManaged = true
}
// parseMountInfoLine parses a single line from mountinfo
// Returns the mount mapping and whether to skip this mount
//
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
//
// (1) mount ID
// (2) parent ID
// (3) major:minor
// (4) root: root of the mount within the filesystem (host-side path for bind mounts)
// (5) mount point: mount point relative to process's root
// (6) mount options
// (7) optional fields (terminated by single hyphen)
// (8) separator (hyphen)
// (9) filesystem type
// (10) mount source (device)
// (11) super options
func parseMountInfoLine(line string) (MountMapping, bool) {
fields := strings.Fields(line)
if len(fields) < 10 {
return MountMapping{}, true
} }
root := fields[3] // Host-side path within the filesystem (important for bind mounts) // Runtime-owned /run mounts are usually ephemeral tmpfs/overlay mounts
mountPoint := fields[4] // Container-side mount point // or bind-like mounts sourced from host runtime directories.
mountOptions := fields[5] // We skip these unless OCI explicitly manages that destination.
isRunRuntimeMount := strings.HasPrefix(mp, "/run/") &&
(mount.FSType == "tmpfs" ||
mount.FSType == "overlay" ||
strings.HasPrefix(source, "/run/") ||
strings.HasPrefix(source, "/var/run/") ||
strings.HasPrefix(root, "/run/") ||
strings.HasPrefix(root, "/var/run/"))
// Find separator (-) to get fstype and source if !isOCIManaged && (strings.HasPrefix(mp, "/proc/") || strings.HasPrefix(mp, "/sys/") || isRunRuntimeMount) {
sepIdx := -1 skippedSet[mp] = struct{}{}
for i, f := range fields { delete(externalizedSet, mp)
if f == "-" { continue
sepIdx = i
break
}
} }
if sepIdx == -1 || sepIdx+2 >= len(fields) { if mp == "/" || isOCIManaged || (root != "." && root != "/") {
return MountMapping{}, true externalizedSet[mp] = struct{}{}
continue
} }
fsType := fields[sepIdx+1]
source := fields[sepIdx+2]
superOptions := ""
if sepIdx+3 < len(fields) {
superOptions = fields[sepIdx+3]
} }
// Skip system mount types // Ensure OCI-managed destinations are externalized, even when mountinfo does not
if systemMountTypes[fsType] { // include a direct entry (e.g., runtime-managed masked/readonly paths).
return MountMapping{}, true for mp := range ociManagedSet {
if _, skipped := skippedSet[mp]; skipped {
continue
} }
externalizedSet[mp] = struct{}{}
// Skip system mount paths
if systemMountPaths[mountPoint] {
return MountMapping{}, true
} }
// Skip /sys and /proc prefixed paths externalized := make([]string, 0, len(externalizedSet))
if strings.HasPrefix(mountPoint, "/sys/") || strings.HasPrefix(mountPoint, "/proc/") { for mp := range externalizedSet {
return MountMapping{}, true externalized = append(externalized, mp)
} }
skipped := make([]string, 0, len(skippedSet))
// Skip overlay (the root filesystem itself) for mp := range skippedSet {
if fsType == "overlay" && mountPoint == "/" { skipped = append(skipped, mp)
return MountMapping{}, true
} }
// For bind mounts, the root field contains the actual host path return &MountPolicy{
// Use root as OutsidePath since it gives us the host-side path for volume mounts Externalized: externalized,
outsidePath := root Skipped: skipped,
if root == "/" {
// If root is /, this isn't a bind mount from a subdirectory
outsidePath = source
} }
return MountMapping{
InsidePath: mountPoint,
OutsidePath: outsidePath,
FSType: fsType,
Source: source,
Options: mountOptions + "," + superOptions,
}, false
} }
// GetBindMounts returns only bind mounts (type "bind" or with bind option) // collectOCIManagedDestinations returns the canonical set of OCI-owned mount
func GetBindMounts(pid int, hostProc string) ([]MountMapping, error) { // targets. This includes regular OCI mounts plus Linux masked/readonly paths.
mounts, err := ParseMountInfo(pid, hostProc) // Those masked/readonly paths may not appear as direct mountinfo entries, but
if err != nil { // still need to be treated as runtime-owned and externalized.
return nil, err func collectOCIManagedDestinations(ociSpec *specs.Spec, rootFS string) map[string]struct{} {
set := map[string]struct{}{}
if ociSpec == nil {
return set
} }
var bindMounts []MountMapping paths := make([]string, 0, len(ociSpec.Mounts))
for _, m := range mounts { for _, mount := range ociSpec.Mounts {
// Bind mounts typically show the underlying filesystem type paths = append(paths, mount.Destination)
// and have paths that look like kubelet volume paths }
if strings.Contains(m.OutsidePath, "/var/lib/kubelet/pods/") || if ociSpec.Linux != nil {
strings.Contains(m.OutsidePath, "/volumes/") || paths = append(paths, ociSpec.Linux.MaskedPaths...)
strings.Contains(m.Options, "bind") { paths = append(paths, ociSpec.Linux.ReadonlyPaths...)
bindMounts = append(bindMounts, m) }
for _, raw := range paths {
if p := normalizeOCIDestinationPath(raw, rootFS); p != "" {
set[p] = struct{}{}
} }
} }
return bindMounts, nil return set
} }
// GetKubernetesVolumeMounts returns mounts that appear to be Kubernetes volumes // normalizeMountPath applies lexical normalization only.
func GetKubernetesVolumeMounts(pid int, hostProc string) ([]MountMapping, error) { // Mountinfo paths are already kernel truth for the container namespace.
mounts, err := ParseMountInfo(pid, hostProc) func normalizeMountPath(raw string) string {
if err != nil { raw = strings.TrimSpace(raw)
return nil, err if raw == "" {
return ""
} }
var k8sMounts []MountMapping p := path.Clean(raw)
for _, m := range mounts { if !strings.HasPrefix(p, "/") {
// Kubernetes volumes are identified by: p = "/" + p
// 1. Standard kubelet paths: /var/lib/kubelet/pods/
// 2. Minikube/Docker paths: /var/lib/docker/volumes/minikube/_data/lib/kubelet/pods/
// 3. Kubernetes volume markers: kubernetes.io~empty-dir, kubernetes.io~configmap, etc.
if strings.Contains(m.OutsidePath, "/kubelet/pods/") ||
strings.Contains(m.OutsidePath, "/kubernetes.io~") ||
strings.Contains(m.OutsidePath, "/containerd/io.containerd") {
k8sMounts = append(k8sMounts, m)
}
} }
return path.Clean(p)
return k8sMounts, nil
}
// AllMountInfo represents a mount entry from /proc/<pid>/mountinfo
// This includes ALL mounts without filtering, which CRIU captures during checkpoint.
type AllMountInfo struct {
MountID string // Mount ID
ParentID string // Parent mount ID
MountPoint string // Mount point inside container (container-side path)
Root string // Root of mount within filesystem (host-side path for bind mounts)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
SuperOptions string // Super block options
} }
// GetAllMountsFromMountinfo parses /proc/<pid>/mountinfo and returns ALL mounts. // normalizeOCIDestinationPath canonicalizes OCI destinations against container
// This is used for CRIU checkpoint to mark ALL mounts as external, since CRIU // rootfs symlinks (for example /var/run -> /run) with lexical fallback.
// captures everything from mountinfo, not just the filtered subset. func normalizeOCIDestinationPath(raw, rootFS string) string {
// Without marking ALL mounts as external, CRIU restore fails with p := normalizeMountPath(raw)
// "No mapping for <mount_id>:(null) mountpoint" errors. if p == "" || rootFS == "" {
func GetAllMountsFromMountinfo(pid int, hostProc string) ([]AllMountInfo, error) { return p
if hostProc == "" {
hostProc = "/proc"
} }
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid) hostPath := filepath.Join(rootFS, strings.TrimPrefix(p, "/"))
file, err := os.Open(mountinfoPath) resolved, err := filepath.EvalSymlinks(hostPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err) return p
} }
defer file.Close()
var mounts []AllMountInfo rel, err := filepath.Rel(rootFS, resolved)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
mount, err := parseAllMountInfoLine(line)
if err != nil { if err != nil {
continue // Skip malformed lines return p
} }
mounts = append(mounts, mount) rel = filepath.ToSlash(rel)
if rel == "." {
return "/"
} }
if strings.HasPrefix(rel, "../") || rel == ".." {
if err := scanner.Err(); err != nil { return p
return nil, fmt.Errorf("error reading mountinfo: %w", err)
} }
return mounts, nil return normalizeMountPath("/" + rel)
} }
// parseAllMountInfoLine parses a single line from mountinfo without filtering. func ReadMountInfoFromHostProcPath(pid int) ([]MountInfo, error) {
// mountinfo format: mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", HostProcPath, pid)
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue parsedMounts, err := common.ParseMountInfoFile(mountinfoPath)
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11) if err != nil {
func parseAllMountInfoLine(line string) (AllMountInfo, error) { return nil, fmt.Errorf("failed to parse mountinfo at %s: %w", mountinfoPath, err)
fields := strings.Fields(line)
if len(fields) < 10 {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line: %s", line)
}
mountID := fields[0]
parentID := fields[1]
root := fields[3] // Host-side path within the filesystem
mountPoint := fields[4] // Container-side mount point
mountOptions := fields[5]
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line (no separator): %s", line)
} }
fsType := fields[sepIdx+1] mounts := make([]MountInfo, 0, len(parsedMounts))
source := fields[sepIdx+2] for _, parsed := range parsedMounts {
superOptions := "" mounts = append(mounts, MountInfo{
if sepIdx+3 < len(fields) { MountID: parsed.MountID,
superOptions = fields[sepIdx+3] ParentID: parsed.ParentID,
MountPoint: parsed.Path,
Root: parsed.Root,
FSType: parsed.FSType,
Source: parsed.Source,
Options: parsed.Options,
SuperOptions: parsed.SuperOpts,
})
} }
return AllMountInfo{ return mounts, nil
MountID: mountID,
ParentID: parentID,
MountPoint: mountPoint,
Root: root,
FSType: fsType,
Source: source,
Options: mountOptions,
SuperOptions: superOptions,
}, nil
} }
...@@ -3,12 +3,17 @@ package checkpoint ...@@ -3,12 +3,17 @@ package checkpoint
import ( import (
"fmt" "fmt"
"os"
"strings"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// NamespaceManifestEntry stores namespace information saved in checkpoint manifests.
type NamespaceManifestEntry struct {
Type string `yaml:"type"` // net, pid, mnt, etc.
Inode uint64 `yaml:"inode"` // Namespace inode
IsExternal bool `yaml:"isExternal"` // Whether namespace is external (shared)
}
// NamespaceType represents a Linux namespace type // NamespaceType represents a Linux namespace type
type NamespaceType string type NamespaceType string
...@@ -26,17 +31,29 @@ const ( ...@@ -26,17 +31,29 @@ const (
type NamespaceInfo struct { type NamespaceInfo struct {
Type NamespaceType Type NamespaceType
Inode uint64 Inode uint64
Path string
IsExternal bool // Whether NS is external (shared with pause container) IsExternal bool // Whether NS is external (shared with pause container)
} }
// GetNamespaceInode returns the inode number for a namespace // NewNamespaceManifestEntries constructs namespace manifest entries from introspected namespaces.
func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64, error) { func NewNamespaceManifestEntries(namespaces map[NamespaceType]*NamespaceInfo) []NamespaceManifestEntry {
if hostProc == "" { if len(namespaces) == 0 {
hostProc = "/proc" return nil
}
result := make([]NamespaceManifestEntry, 0, len(namespaces))
for nsType, nsInfo := range namespaces {
result = append(result, NamespaceManifestEntry{
Type: string(nsType),
Inode: nsInfo.Inode,
IsExternal: nsInfo.IsExternal,
})
} }
return result
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType) // GetNamespaceInode returns the inode number for a namespace
func GetNamespaceInode(pid int, nsType NamespaceType) (uint64, error) {
nsPath := fmt.Sprintf("%s/%d/ns/%s", HostProcPath, pid, nsType)
var stat unix.Stat_t var stat unix.Stat_t
if err := unix.Stat(nsPath, &stat); err != nil { if err := unix.Stat(nsPath, &stat); err != nil {
return 0, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err) return 0, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
...@@ -46,12 +63,8 @@ func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64, ...@@ -46,12 +63,8 @@ func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64,
} }
// GetNamespaceInfo returns detailed namespace information // GetNamespaceInfo returns detailed namespace information
func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*NamespaceInfo, error) { func GetNamespaceInfo(pid int, nsType NamespaceType) (*NamespaceInfo, error) {
if hostProc == "" { nsPath := fmt.Sprintf("%s/%d/ns/%s", HostProcPath, pid, nsType)
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
// Get inode // Get inode
var stat unix.Stat_t var stat unix.Stat_t
...@@ -59,14 +72,8 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac ...@@ -59,14 +72,8 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac
return nil, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err) return nil, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
} }
// Read the symlink to get the namespace identifier
link, err := os.Readlink(nsPath)
if err != nil {
return nil, fmt.Errorf("failed to readlink %s: %w", nsPath, err)
}
// Check if this is different from init's namespace (PID 1) // Check if this is different from init's namespace (PID 1)
initNsPath := fmt.Sprintf("%s/1/ns/%s", hostProc, nsType) initNsPath := fmt.Sprintf("%s/1/ns/%s", HostProcPath, nsType)
var initStat unix.Stat_t var initStat unix.Stat_t
isExternal := false isExternal := false
if err := unix.Stat(initNsPath, &initStat); err == nil { if err := unix.Stat(initNsPath, &initStat); err == nil {
...@@ -77,13 +84,12 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac ...@@ -77,13 +84,12 @@ func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*Namespac
return &NamespaceInfo{ return &NamespaceInfo{
Type: nsType, Type: nsType,
Inode: stat.Ino, Inode: stat.Ino,
Path: link,
IsExternal: isExternal, IsExternal: isExternal,
}, nil }, nil
} }
// GetAllNamespaces returns information about all namespaces for a process // GetAllNamespaces returns information about all namespaces for a process
func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInfo, error) { func GetAllNamespaces(pid int) (map[NamespaceType]*NamespaceInfo, error) {
nsTypes := []NamespaceType{ nsTypes := []NamespaceType{
NamespaceNet, NamespaceNet,
NamespacePID, NamespacePID,
...@@ -96,66 +102,10 @@ func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInf ...@@ -96,66 +102,10 @@ func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInf
namespaces := make(map[NamespaceType]*NamespaceInfo) namespaces := make(map[NamespaceType]*NamespaceInfo)
for _, nsType := range nsTypes { for _, nsType := range nsTypes {
info, err := GetNamespaceInfo(pid, nsType, hostProc) if info, err := GetNamespaceInfo(pid, nsType); err == nil {
if err != nil {
// Some namespaces might not exist, skip them
continue
}
namespaces[nsType] = info namespaces[nsType] = info
} }
return namespaces, nil
}
// IsNetNamespaceExternal checks if the network namespace is external
// (i.e., shared with the pause container in Kubernetes)
func IsNetNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespaceNet, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// IsPIDNamespaceExternal checks if the PID namespace is external
func IsPIDNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespacePID, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// OpenNamespaceFD opens a file descriptor to a namespace
// The caller is responsible for closing the returned file
func OpenNamespaceFD(pid int, nsType NamespaceType, hostProc string) (*os.File, error) {
if hostProc == "" {
hostProc = "/proc"
} }
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType) return namespaces, nil
return os.Open(nsPath)
}
// FormatExternalNamespace formats namespace info for CRIU's External option
// Format: <type>[<inode>]:<key>
func FormatExternalNamespace(nsType NamespaceType, inode uint64) string {
key := formatNamespaceKey(nsType)
return fmt.Sprintf("%s[%d]:%s", nsType, inode, key)
}
// formatNamespaceKey creates the CRIU key for external namespaces
// Format: extRoot<Type>NS (e.g., extRootNetNS, extRootPidNS)
func formatNamespaceKey(nsType NamespaceType) string {
// Capitalize first letter of namespace type
nsName := string(nsType)
if len(nsName) > 0 {
nsName = strings.ToUpper(nsName[:1]) + nsName[1:]
}
return "extRoot" + nsName + "NS"
}
// GetNamespaceKey returns the CRIU key for a namespace type
func GetNamespaceKey(nsType NamespaceType) string {
return formatNamespaceKey(nsType)
} }
// Package checkpoint provides CRIU checkpoint (dump) operations.
package checkpoint
import (
"fmt"
"io"
"os"
"path/filepath"
"github.com/sirupsen/logrus"
)
// CaptureDevShm captures files from /dev/shm to the checkpoint directory.
// This is needed because /dev/shm is a tmpfs mount that is not part of the
// container's overlay filesystem, so rootfs diff doesn't capture it.
//
// Semaphores (sem.* files) are included so that sem_unlink() calls succeed
// after restore. The semaphore kernel state won't be perfectly restored,
// but the files will exist for cleanup operations.
//
// Only regular files are captured. Directories are skipped, and so is every
// other non-regular entry (symlinks, FIFOs, sockets, device nodes): copyFile
// opens its source with os.Open, which follows symlinks and blocks on a FIFO
// that has no writer, so copying such entries is neither safe nor meaningful.
//
// The files are saved to <checkpointDir>/dev-shm/ and can be restored
// using RestoreDevShm before CRIU restore.
//
// Capture is best-effort per file: a file whose metadata cannot be read or
// whose copy fails is logged at Warn level and skipped. An error is returned
// only when the container's /dev/shm cannot be listed (a missing /dev/shm is
// not an error) or the destination directory cannot be created.
func CaptureDevShm(pid int, checkpointDir string, log *logrus.Entry) error {
	// Access container's /dev/shm via /proc/<pid>/root
	shmPath := filepath.Join(HostProcPath, fmt.Sprintf("%d/root/dev/shm", pid))

	entries, err := os.ReadDir(shmPath)
	if err != nil {
		if os.IsNotExist(err) {
			log.Debug("Container /dev/shm does not exist, skipping capture")
			return nil
		}
		return fmt.Errorf("failed to read container /dev/shm: %w", err)
	}

	// Keep regular files only.
	var filesToCapture []os.DirEntry
	for _, entry := range entries {
		// Skip directories (unlikely in /dev/shm but be safe)
		if entry.IsDir() {
			log.WithField("dir", entry.Name()).Debug("Skipping directory in /dev/shm")
			continue
		}
		// Skip symlinks, FIFOs, sockets and devices: the container controls
		// these entries, and opening them could follow a link to an
		// unrelated file or block indefinitely.
		if !entry.Type().IsRegular() {
			log.WithFields(logrus.Fields{
				"file": entry.Name(),
				"type": entry.Type().String(),
			}).Debug("Skipping non-regular file in /dev/shm")
			continue
		}
		filesToCapture = append(filesToCapture, entry)
	}

	if len(filesToCapture) == 0 {
		log.Debug("No files to capture from /dev/shm")
		return nil
	}

	// Create destination directory
	destDir := filepath.Join(checkpointDir, DevShmDirName)
	if err := os.MkdirAll(destDir, 0755); err != nil {
		return fmt.Errorf("failed to create dev-shm directory: %w", err)
	}

	var captured []string
	var totalSize int64

	for _, entry := range filesToCapture {
		name := entry.Name()
		srcPath := filepath.Join(shmPath, name)
		destPath := filepath.Join(destDir, name)

		info, err := entry.Info()
		if err != nil {
			log.WithError(err).WithField("file", name).Warn("Failed to get file info, skipping")
			continue
		}
		size := info.Size()

		// Copy the file
		if err := copyFile(srcPath, destPath, info.Mode()); err != nil {
			log.WithError(err).WithField("file", name).Warn("Failed to copy file, skipping")
			continue
		}

		captured = append(captured, name)
		totalSize += size

		log.WithFields(logrus.Fields{
			"file": name,
			"size": size,
		}).Debug("Captured /dev/shm file")
	}

	if len(captured) > 0 {
		log.WithFields(logrus.Fields{
			"count":      len(captured),
			"total_size": totalSize,
			"files":      captured,
		}).Info("Captured /dev/shm files")
	}

	return nil
}
// copyFile copies a file from src to dest with the given permissions.
func copyFile(src, dest string, mode os.FileMode) error {
srcFile, err := os.Open(src)
if err != nil {
return fmt.Errorf("failed to open source: %w", err)
}
defer srcFile.Close()
destFile, err := os.OpenFile(dest, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode)
if err != nil {
return fmt.Errorf("failed to create destination: %w", err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return fmt.Errorf("failed to copy contents: %w", err)
}
// Sync to ensure durability for checkpoint data
if err := destFile.Sync(); err != nil {
return fmt.Errorf("failed to sync destination: %w", err)
}
return nil
}
// storage.go provides checkpoint storage I/O: write/read manifests, listing, deletion.
package checkpoint
import (
"fmt"
"os"
"path/filepath"
"strings"
"gopkg.in/yaml.v3"
)
// WriteCheckpointManifest serializes data to YAML and stores it in
// checkpointDir under CheckpointManifestFilename with mode 0600.
// It returns a wrapped error if marshalling or writing fails.
func WriteCheckpointManifest(checkpointDir string, data *CheckpointManifest) error {
	encoded, marshalErr := yaml.Marshal(data)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal checkpoint manifest: %w", marshalErr)
	}

	target := filepath.Join(checkpointDir, CheckpointManifestFilename)
	writeErr := os.WriteFile(target, encoded, 0600)
	if writeErr != nil {
		return fmt.Errorf("failed to write checkpoint manifest: %w", writeErr)
	}
	return nil
}
// ReadCheckpointManifest loads the file named CheckpointManifestFilename from
// checkpointDir and decodes its YAML content into a CheckpointManifest.
// It returns a wrapped error if the file cannot be read or decoded.
func ReadCheckpointManifest(checkpointDir string) (*CheckpointManifest, error) {
	raw, readErr := os.ReadFile(filepath.Join(checkpointDir, CheckpointManifestFilename))
	if readErr != nil {
		return nil, fmt.Errorf("failed to read checkpoint manifest: %w", readErr)
	}

	manifest := new(CheckpointManifest)
	if decodeErr := yaml.Unmarshal(raw, manifest); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal checkpoint manifest: %w", decodeErr)
	}
	return manifest, nil
}
// SaveDescriptors serializes the descriptor list to YAML and stores it in
// checkpointDir under DescriptorsFilename with mode 0600.
// It returns a wrapped error if marshalling or writing fails.
func SaveDescriptors(checkpointDir string, descriptors []string) error {
	encoded, marshalErr := yaml.Marshal(descriptors)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal descriptors: %w", marshalErr)
	}

	target := filepath.Join(checkpointDir, DescriptorsFilename)
	writeErr := os.WriteFile(target, encoded, 0600)
	if writeErr != nil {
		return fmt.Errorf("failed to write descriptors file: %w", writeErr)
	}
	return nil
}
// LoadDescriptors loads the file named DescriptorsFilename from checkpointDir
// and decodes its YAML content into a list of descriptor strings.
// It returns a wrapped error if the file cannot be read or decoded.
func LoadDescriptors(checkpointDir string) ([]string, error) {
	raw, readErr := os.ReadFile(filepath.Join(checkpointDir, DescriptorsFilename))
	if readErr != nil {
		return nil, fmt.Errorf("failed to read descriptors file: %w", readErr)
	}

	// Declared with var so the result stays nil when the document is empty.
	var loaded []string
	decodeErr := yaml.Unmarshal(raw, &loaded)
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal descriptors: %w", decodeErr)
	}
	return loaded, nil
}
// ListCheckpoints scans baseDir and returns the names of its immediate
// subdirectories that contain a file named CheckpointManifestFilename.
// Non-directory entries and directories whose manifest cannot be stat'ed are
// ignored. The result is nil when no checkpoint is found. An error is returned
// only if baseDir itself cannot be read.
func ListCheckpoints(baseDir string) ([]string, error) {
	dirEntries, readErr := os.ReadDir(baseDir)
	if readErr != nil {
		return nil, fmt.Errorf("failed to read checkpoint directory: %w", readErr)
	}

	var ids []string
	for _, dirEntry := range dirEntries {
		if dirEntry.IsDir() {
			id := dirEntry.Name()
			// A directory counts as a checkpoint only if its manifest exists.
			_, statErr := os.Stat(filepath.Join(baseDir, id, CheckpointManifestFilename))
			if statErr == nil {
				ids = append(ids, id)
			}
		}
	}
	return ids, nil
}
// DeleteCheckpoint removes a checkpoint directory.
func DeleteCheckpoint(baseDir, checkpointID string) error {
checkpointDir := filepath.Join(baseDir, checkpointID)
// Ensure resolved path is within baseDir to prevent path traversal
absBase, _ := filepath.Abs(baseDir)
absDir, _ := filepath.Abs(checkpointDir)
if !strings.HasPrefix(absDir, absBase+string(filepath.Separator)) && absDir != absBase {
return fmt.Errorf("invalid checkpoint ID: resolved path %s is outside base directory %s", absDir, absBase)
}
return os.RemoveAll(checkpointDir)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment