Unverified Commit f3aa1e01 authored by Julien Mancuso's avatar Julien Mancuso Committed by GitHub
Browse files

feat: introducing ChReK (Checkpoint Restore in K8s) (#4978)


Signed-off-by: default avatarJulien Mancuso <jmancuso@nvidia.com>
parent 44986bf5
...@@ -98,6 +98,7 @@ operator: ...@@ -98,6 +98,7 @@ operator:
deploy: deploy:
- 'deploy/helm/**' - 'deploy/helm/**'
- 'deploy/utils/**' - 'deploy/utils/**'
- 'deploy/chrek/**'
planner: planner:
- 'components/src/dynamo/planner/**' - 'components/src/dynamo/planner/**'
......
...@@ -86,6 +86,9 @@ class Config: ...@@ -86,6 +86,9 @@ class Config:
# Use vLLM's tokenizer for pre/post processing # Use vLLM's tokenizer for pre/post processing
use_vllm_tokenizer: bool = False use_vllm_tokenizer: bool = False
# sleep mode support (enable_sleep_mode comes from vLLM's engine_args)
sleep_mode_level: int = 1
# Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args) # Whether to enable NATS for KV events (derived from kv_events_config in overwrite_args)
use_kv_events: bool = False use_kv_events: bool = False
...@@ -289,6 +292,13 @@ def parse_args() -> Config: ...@@ -289,6 +292,13 @@ def parse_args() -> Config:
default=False, default=False,
help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.", help="Use vLLM's tokenizer for pre and post processing. This bypasses Dynamo's preprocessor and only v1/chat/completions will be available through the Dynamo frontend.",
) )
parser.add_argument(
"--sleep-mode-level",
type=int,
default=1,
choices=[1, 2, 3],
help="Sleep mode level (1=offload to CPU, 2=discard weights, 3=discard all). Default: 1",
)
add_config_dump_args(parser) add_config_dump_args(parser)
parser = AsyncEngineArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser)
...@@ -423,6 +433,7 @@ def parse_args() -> Config: ...@@ -423,6 +433,7 @@ def parse_args() -> Config:
config.event_plane = args.event_plane config.event_plane = args.event_plane
config.enable_local_indexer = args.enable_local_indexer config.enable_local_indexer = args.enable_local_indexer
config.use_vllm_tokenizer = args.use_vllm_tokenizer config.use_vllm_tokenizer = args.use_vllm_tokenizer
config.sleep_mode_level = args.sleep_mode_level
# use_kv_events is set later in overwrite_args() based on kv_events_config # use_kv_events is set later in overwrite_args() based on kv_events_config
# Validate custom Jinja template file exists if provided # Validate custom Jinja template file exists if provided
......
...@@ -72,6 +72,48 @@ async def _handle_non_leader_node(dp_rank: int) -> None: ...@@ -72,6 +72,48 @@ async def _handle_non_leader_node(dp_rank: int) -> None:
await asyncio.Event().wait() await asyncio.Event().wait()
async def await_checkpoint_and_was_restored(signal_file: str) -> bool:
"""
Wait for checkpoint signal file OR restore marker file.
In checkpoint creation mode, poll until either:
1. The signal file exists (checkpoint complete, should exit)
2. The restore marker file exists (restored by CRIU, should proceed)
The restore marker file is created by the restore-entrypoint before CRIU restore,
so the restored process can detect it was restored even though os.environ is
restored from the checkpoint and doesn't contain new container env vars.
Args:
signal_file: Path to the checkpoint signal file
Returns:
True if restored (should proceed with registration)
False if signal file detected (should exit)
"""
# Get restore marker file path (created by restore entrypoint before CRIU restore)
restore_marker = os.environ.get("DYN_RESTORE_MARKER_FILE", "/tmp/dynamo-restored")
logger.info(
f"CHECKPOINT_READY: Model loaded, ready for container checkpoint. Waiting for signal file: {signal_file} or restore marker file: {restore_marker}"
)
while True:
# Check if we've been restored (marker file created by restore entrypoint)
if os.path.exists(restore_marker):
logger.info(
f"Detected restore from checkpoint (marker file exists: {restore_marker})"
)
return True # Restored - proceed with registration
# Check if checkpoint is complete (signal file exists)
if os.path.exists(signal_file):
logger.info(f"Checkpoint signal file detected: {signal_file}")
return False # Checkpoint done - exit
await asyncio.sleep(1)
async def graceful_shutdown(runtime, shutdown_event): async def graceful_shutdown(runtime, shutdown_event):
""" """
Shutdown dynamo distributed runtime. Shutdown dynamo distributed runtime.
...@@ -90,6 +132,86 @@ async def worker(): ...@@ -90,6 +132,86 @@ async def worker():
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
overwrite_args(config) overwrite_args(config)
dump_config(config.dump_config_to, config)
# Name the model. Use either the full path (vllm and sglang do the same),
# or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
if not config.served_model_name:
config.served_model_name = config.engine_args.served_model_name = config.model
# Check checkpoint-related environment variables EARLY
signal_file = os.environ.get("DYN_CHECKPOINT_SIGNAL_FILE")
ready_file = os.environ.get("DYN_CHECKPOINT_READY_FILE")
is_checkpoint_mode = signal_file is not None
# EARLY EXIT: Check if checkpoint already exists (before downloading model!)
if is_checkpoint_mode:
storage_type = os.environ.get("DYN_CHECKPOINT_STORAGE_TYPE")
checkpoint_location = os.environ.get("DYN_CHECKPOINT_LOCATION")
if storage_type == "pvc" and checkpoint_location:
done_marker = f"{checkpoint_location}/checkpoint.done"
if os.path.exists(done_marker):
logger.info(
f"Found existing checkpoint at {checkpoint_location}. Storage type: {storage_type}"
)
return
else:
logger.info(
f"Checkpoint not found at: {checkpoint_location}. creating new checkpoint"
)
# Download the model if necessary using modelexpress.
# We want it on disk before we start vllm to avoid downloading from HuggingFace.
#
# We don't set `config.engine_args.model` to the local path fetch_llm returns
# because vllm will send that name to its Ray pipeline-parallel workers, which
# may not have the local path.
# vllm will attempt to download the model again, but find it in the HF cache.
# For non-HF models use a path instead of an HF name, and ensure all workers have
# that path (ideally via a shared folder).
if not os.path.exists(config.model):
await fetch_llm(config.model)
# CHECKPOINT MODE: Load engine BEFORE runtime creation
# This allows checkpointing GPU state before runtime connections are established
pre_created_engine = None
is_restored = False
if is_checkpoint_mode:
logger.info(
f"Checkpoint mode enabled (DYN_CHECKPOINT_SIGNAL_FILE={signal_file})"
)
# CHECKPOINT MODE: Load model, sleep, wait for signal file or restore
pre_created_engine = setup_vllm_engine(config)
engine_client = pre_created_engine[0]
# Put model to sleep before checkpoint (if sleep mode enabled)
if config.engine_args.enable_sleep_mode:
logger.info(f"Putting model to sleep (level={config.sleep_mode_level})")
await engine_client.sleep(level=config.sleep_mode_level)
# Write ready file to signal that we're ready for checkpointing
if ready_file:
with open(ready_file, "w") as f:
f.write("ready")
logger.info(f"Wrote checkpoint ready file: {ready_file}")
# Wait for checkpoint signal file OR restore detection
is_restored = await await_checkpoint_and_was_restored(signal_file)
if is_restored:
# Wake up model and proceed with registration
if config.engine_args.enable_sleep_mode:
logger.info("Waking up model after checkpoint restore")
await engine_client.wake_up()
logger.info("Proceeding with endpoint registration after restore")
else:
# Checkpoint complete, exit
logger.info("Exiting after checkpoint completion")
return
# Create shutdown event # Create shutdown event
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
...@@ -117,25 +239,6 @@ async def worker(): ...@@ -117,25 +239,6 @@ async def worker():
logging.debug("Signal handlers set up for graceful shutdown") logging.debug("Signal handlers set up for graceful shutdown")
dump_config(config.dump_config_to, config)
# Name the model. Use either the full path (vllm and sglang do the same),
# or the HF name (e.g. "Qwen/Qwen3-0.6B"), depending on cmd line params.
if not config.served_model_name:
config.served_model_name = config.engine_args.served_model_name = config.model
# Download the model if necessary using modelexpress.
# We want it on disk before we start vllm to avoid downloading from HuggingFace.
#
# We don't set `config.engine_args.model` to the local path fetch_llm returns
# because vllm will send that name to its Ray pipeline-parallel workers, which
# may not have the local path.
# vllm will attempt to download the model again, but find it in the HF cache.
# For non-HF models use a path instead of an HF name, and ensure all workers have
# that path (ideally via a shared folder).
if not os.path.exists(config.model):
await fetch_llm(config.model)
# Route to appropriate initialization based on config flags # Route to appropriate initialization based on config flags
if config.vllm_native_encoder_worker: if config.vllm_native_encoder_worker:
await init_vllm_native_encoder(runtime, config, shutdown_event) await init_vllm_native_encoder(runtime, config, shutdown_event)
...@@ -154,13 +257,19 @@ async def worker(): ...@@ -154,13 +257,19 @@ async def worker():
or config.multimodal_decode_worker or config.multimodal_decode_worker
or config.multimodal_encode_prefill_worker or config.multimodal_encode_prefill_worker
): ):
await init_multimodal_worker(runtime, config, shutdown_event) await init_multimodal_worker(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
)
logger.debug("init_multimodal_worker completed") logger.debug("init_multimodal_worker completed")
elif config.is_prefill_worker: elif config.is_prefill_worker:
await init_prefill(runtime, config, shutdown_event) await init_prefill(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
)
logger.debug("init_prefill completed") logger.debug("init_prefill completed")
else: else:
await init(runtime, config, shutdown_event) await init(
runtime, config, shutdown_event, pre_created_engine=pre_created_engine
)
logger.debug("init completed") logger.debug("init completed")
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
...@@ -451,7 +560,10 @@ async def register_vllm_model( ...@@ -451,7 +560,10 @@ async def register_vllm_model(
async def init_prefill( async def init_prefill(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
pre_created_engine=None,
): ):
""" """
Instantiate and serve Instantiate and serve
...@@ -461,6 +573,15 @@ async def init_prefill( ...@@ -461,6 +573,15 @@ async def init_prefill(
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
) = pre_created_engine
else:
( (
engine_client, engine_client,
vllm_config, vllm_config,
...@@ -567,7 +688,10 @@ async def init_prefill( ...@@ -567,7 +688,10 @@ async def init_prefill(
async def init( async def init(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
pre_created_engine=None,
): ):
""" """
Instantiate and serve Instantiate and serve
...@@ -586,6 +710,16 @@ async def init( ...@@ -586,6 +710,16 @@ async def init(
config.engine_args.data_parallel_rank or 0, config.engine_args.data_parallel_rank or 0,
metrics_labels=[("model", config.served_model_name or config.model)], metrics_labels=[("model", config.served_model_name or config.model)],
) )
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
) = pre_created_engine
else:
( (
engine_client, engine_client,
vllm_config, vllm_config,
...@@ -975,7 +1109,10 @@ async def init_ec_processor( ...@@ -975,7 +1109,10 @@ async def init_ec_processor(
async def init_multimodal_worker( async def init_multimodal_worker(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
pre_created_engine=None,
): ):
""" """
Initialize multimodal worker component. Initialize multimodal worker component.
...@@ -1013,6 +1150,15 @@ async def init_multimodal_worker( ...@@ -1013,6 +1150,15 @@ async def init_multimodal_worker(
config.engine_args.ec_transfer_config = ec_transfer_config config.engine_args.ec_transfer_config = ec_transfer_config
logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}") logger.info(f"Configured as ECConnector consumer with engine_id={engine_id}")
# Use pre-created engine if provided (checkpoint mode), otherwise create new
if pre_created_engine is not None:
(
engine_client,
vllm_config,
default_sampling_params,
prometheus_temp_dir,
) = pre_created_engine
else:
( (
engine_client, engine_client,
vllm_config, vllm_config,
......
# Binaries
bin/
*.exe
# Reference source repos (clone separately if needed)
containerd/
runc/
# Build artifacts
*.o
*.a
*.so
# IDE/Editor
.idea/
.vscode/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Test artifacts
*.test
coverage.out
*.prof
# Temporary files
*.tmp
*.tar
/tmp/
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Unified Dockerfile for chrek-agent and placeholder images.
#
# Build targets:
# docker build --target agent -t chrek-agent:latest .
# docker build --target placeholder --build-arg BASE_IMAGE=<app-image> -t placeholder:latest .
#
# Optional targets for CI:
# docker build --target linter . # Run linting
# docker build --target tester . # Run tests
# =============================================================================
# Build Arguments
# =============================================================================
ARG DOCKER_PROXY
ARG GO_VERSION=1.25
ARG CRIU_VERSION=v4.2
ARG AGENT_BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.11-cuda13.0-devel-ubuntu24.04
# For placeholder target only - this default allows agent builds to succeed,
# but placeholder builds MUST override it with --build-arg BASE_IMAGE=<image>
ARG BASE_IMAGE=placeholder-requires-base-image-arg
# =============================================================================
# Stage: Go base - Common setup for Go builds
# =============================================================================
FROM ${DOCKER_PROXY}golang:${GO_VERSION} AS go-base
ARG TARGETOS=linux
ARG TARGETARCH=amd64
RUN echo "Building for ${TARGETOS}/${TARGETARCH}"
RUN apt-get update && apt-get install -y --no-install-recommends git ca-certificates \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
COPY go.mod go.sum ./
RUN go mod download
COPY . .
# =============================================================================
# Stage: Linter - Run golangci-lint
# =============================================================================
FROM go-base AS linter
RUN go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.62.2
RUN golangci-lint run --timeout=5m
# =============================================================================
# Stage: Tester - Run tests
# =============================================================================
FROM go-base AS tester
RUN go test ./... -v
# =============================================================================
# Stage: Builder - Build Go binaries
# =============================================================================
FROM go-base AS builder
ARG TARGETOS=linux
ARG TARGETARCH=amd64
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-w -s" -o /chrek-agent ./cmd/agent
RUN CGO_ENABLED=0 GOOS=${TARGETOS} GOARCH=${TARGETARCH} go build -ldflags="-w -s" -o /restore-entrypoint ./cmd/restore-entrypoint
# =============================================================================
# Stage: CRIU Builder - Build CRIU with CUDA plugin
# =============================================================================
FROM ubuntu:24.04 AS criu-builder
ARG CRIU_VERSION
RUN apt-get update && apt-get install -y --no-install-recommends \
git \
ca-certificates \
build-essential \
pkg-config \
libbsd-dev \
libcap-dev \
libnet1-dev \
libnl-3-dev \
libnl-route-3-dev \
libprotobuf-dev \
libprotobuf-c-dev \
protobuf-c-compiler \
protobuf-compiler \
python3 \
python3-protobuf \
libgnutls28-dev \
libnftables-dev \
uuid-dev \
&& rm -rf /var/lib/apt/lists/*
RUN git clone --branch ${CRIU_VERSION} https://github.com/checkpoint-restore/criu.git /tmp/criu \
&& cd /tmp/criu \
&& make -j$(nproc) \
&& make DESTDIR=/criu-install install-criu install-lib install-cuda_plugin
RUN git clone https://github.com/NVIDIA/cuda-checkpoint.git /tmp/cuda-checkpoint
# =============================================================================
# Stage: Agent - Final chrek-agent image
# =============================================================================
FROM ${AGENT_BASE_IMAGE} AS agent
# Install CRIU runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libbsd0 \
libcap2 \
libnet1 \
libnl-3-200 \
libnl-route-3-200 \
libprotobuf-c1 \
libgnutls30t64 \
libnftables1 \
iproute2 \
iptables \
procps \
uuid-runtime \
tar \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy CRIU from builder
COPY --from=criu-builder /criu-install/usr/local /usr/local
RUN criu --version
# Copy cuda-checkpoint binary
COPY --from=criu-builder /tmp/cuda-checkpoint/bin/x86_64_Linux/cuda-checkpoint /usr/local/sbin/cuda-checkpoint
RUN chmod +x /usr/local/sbin/cuda-checkpoint
# Copy the built binaries
COPY --from=builder /chrek-agent /usr/local/bin/chrek-agent
COPY --from=builder /restore-entrypoint /restore-entrypoint
# Create checkpoint directory
RUN mkdir -p /checkpoints
# Set environment variables
ENV HOST_PROC=/host/proc \
CONTAINERD_SOCKET=/run/containerd/containerd.sock \
CHECKPOINT_DIR=/checkpoints \
LISTEN_ADDR=:8080
EXPOSE 8080
USER root
ENTRYPOINT ["/usr/local/bin/chrek-agent"]
# =============================================================================
# Stage: Placeholder - Restore placeholder image (requires BASE_IMAGE arg)
# =============================================================================
FROM ${BASE_IMAGE} AS placeholder
ARG BASE_IMAGE
ENV ORIGINAL_BASE_IMAGE=${BASE_IMAGE}
USER root
# Install CRIU runtime dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
libbsd0 \
libcap2 \
libnet1 \
libnl-3-200 \
libnl-route-3-200 \
libprotobuf-c1 \
libgnutls30 \
libnftables1 \
iproute2 \
iptables \
procps \
uuid-runtime \
tar \
ca-certificates \
&& rm -rf /var/lib/apt/lists/*
# Copy CRIU from builder
COPY --from=criu-builder /criu-install/usr/local /usr/local
RUN criu --version && echo "CRIU installed successfully"
# Copy cuda-checkpoint binary
COPY --from=criu-builder /tmp/cuda-checkpoint/bin/x86_64_Linux/cuda-checkpoint /usr/local/sbin/cuda-checkpoint
RUN chmod +x /usr/local/sbin/cuda-checkpoint
# Create directories
RUN mkdir -p /checkpoints /var/run/criu /tmp /var/criu-work
# Copy restore binaries
COPY --from=builder /restore-entrypoint /restore-entrypoint
RUN chmod +x /restore-entrypoint
COPY scripts/smart-entrypoint.sh /smart-entrypoint.sh
RUN chmod +x /smart-entrypoint.sh
# Set environment variables
ENV DYN_CHECKPOINT_PATH=/checkpoints \
RESTORE_TRIGGER=/tmp/restore-trigger \
RESTORE_WAIT_TIMEOUT=300 \
CRIU_LOG_LEVEL=4 \
WAIT_FOR_CHECKPOINT=0 \
CUDA_PLUGIN_DIR=/usr/local/lib/criu \
DEBUG=0
ENTRYPOINT ["/smart-entrypoint.sh"]
CMD []
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Image URL to use all building/pushing image targets
IMG ?= nvcr.io/nvidian/dynamo-dev/chrek-agent:latest
PLACEHOLDER_IMG ?= nvcr.io/nvidian/dynamo-dev/dynamo-vllm-placeholder:latest
# PLACEHOLDER_BASE_IMG must be provided when building placeholder (no default)
# Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set)
ifeq (,$(shell go env GOBIN))
GOBIN=$(shell go env GOPATH)/bin
else
GOBIN=$(shell go env GOBIN)
endif
# CONTAINER_TOOL defines the container tool to be used for building images.
CONTAINER_TOOL ?= docker
# Setting SHELL to bash allows bash commands to be executed by recipes.
SHELL = /usr/bin/env bash -o pipefail
.SHELLFLAGS = -ec
.PHONY: all
all: build
##@ General
.PHONY: help
help: ## Display this help.
@awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m<target>\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST)
##@ Development
.PHONY: fmt
fmt: ## Run go fmt against code.
go fmt ./...
.PHONY: vet
vet: ## Run go vet against code.
go vet ./...
.PHONY: test
test: fmt vet ## Run tests.
go test ./... -v -coverprofile cover.out
.PHONY: lint
lint: golangci-lint ## Run golangci-lint linter.
$(GOLANGCI_LINT) run --timeout=5m
.PHONY: lint-fix
lint-fix: golangci-lint ## Run golangci-lint linter and perform fixes.
$(GOLANGCI_LINT) run --fix --timeout=5m
##@ Build
.PHONY: build
build: fmt vet ## Build chrek-agent and restore-entrypoint binaries.
CGO_ENABLED=0 go build -ldflags="-w -s" -o bin/chrek-agent ./cmd/agent
CGO_ENABLED=0 go build -ldflags="-w -s" -o bin/restore-entrypoint ./cmd/restore-entrypoint
.PHONY: build-agent
build-agent: fmt vet ## Build chrek-agent binary only.
CGO_ENABLED=0 go build -ldflags="-w -s" -o bin/chrek-agent ./cmd/agent
.PHONY: build-restore
build-restore: fmt vet ## Build restore-entrypoint binary only.
CGO_ENABLED=0 go build -ldflags="-w -s" -o bin/restore-entrypoint ./cmd/restore-entrypoint
.PHONY: run
run: build ## Run chrek-agent from your host.
./bin/chrek-agent
.PHONY: clean
clean: ## Remove build artifacts.
rm -rf bin/
rm -f cover.out
##@ Docker
.PHONY: docker-build-agent
docker-build-agent: ## Build chrek-agent docker image.
$(CONTAINER_TOOL) build --target agent -t ${IMG} .
.PHONY: docker-build-agent-lint
docker-build-agent-lint: ## Build chrek-agent docker image up to lint stage.
$(CONTAINER_TOOL) build --target linter -t ${IMG}-lint .
.PHONY: docker-build-agent-test
docker-build-agent-test: ## Build chrek-agent docker image up to test stage.
$(CONTAINER_TOOL) build --target tester -t ${IMG}-test .
.PHONY: docker-build-placeholder
docker-build-placeholder: ## Build placeholder image for checkpoint restore. Requires PLACEHOLDER_BASE_IMG.
ifndef PLACEHOLDER_BASE_IMG
$(error PLACEHOLDER_BASE_IMG is required. Example: make docker-build-placeholder PLACEHOLDER_BASE_IMG=nvcr.io/nvidia/ai-dynamo/vllm-runtime:0.8.1-cuda13)
endif
$(CONTAINER_TOOL) build --target placeholder \
--build-arg BASE_IMAGE=${PLACEHOLDER_BASE_IMG} \
-t ${PLACEHOLDER_IMG} .
.PHONY: docker-push-agent
docker-push-agent: ## Push chrek-agent docker image.
$(CONTAINER_TOOL) push ${IMG}
.PHONY: docker-push-placeholder
docker-push-placeholder: ## Push placeholder docker image.
$(CONTAINER_TOOL) push ${PLACEHOLDER_IMG}
##@ Dependencies
## Location to install dependencies to
LOCALBIN ?= $(shell pwd)/bin
$(LOCALBIN):
mkdir -p $(LOCALBIN)
## Tool Binaries
GOLANGCI_LINT = $(LOCALBIN)/golangci-lint-$(GOLANGCI_LINT_VERSION)
## Tool Versions
GOLANGCI_LINT_VERSION ?= v1.62.2
.PHONY: golangci-lint
golangci-lint: $(GOLANGCI_LINT) ## Download golangci-lint locally if necessary.
$(GOLANGCI_LINT): $(LOCALBIN)
$(call go-install-tool,$(GOLANGCI_LINT),github.com/golangci/golangci-lint/cmd/golangci-lint,${GOLANGCI_LINT_VERSION})
# go-install-tool will 'go install' any package with custom target and name of binary, if it doesn't exist
# $1 - target path with name of binary (ideally with version)
# $2 - package url which can be installed
# $3 - specific version of package
define go-install-tool
@[ -f $(1) ] || { \
set -e; \
package=$(2)@$(3) ;\
echo "Downloading $${package}" ;\
GOBIN=$(LOCALBIN) go install $${package} ;\
mv "$$(echo "$(1)" | sed "s/-$(3)$$//")" $(1) ;\
}
endef
.PHONY: coverage
coverage: test ## Show test coverage report.
go tool cover -func=cover.out
// Package main provides the CRIU node agent with HTTP API and/or pod watching.
// The agent supports two modes that can be enabled independently:
// - HTTP API mode: Exposes REST endpoints for checkpoint/restore operations
// - Watcher mode: Automatically checkpoints pods with nvidia.com/checkpoint-source=true label
package main
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/watcher"
)
// CheckpointSignalSource determines how checkpoint operations are triggered
type CheckpointSignalSource string
const (
// SignalFromHTTP triggers checkpoints via HTTP API requests
SignalFromHTTP CheckpointSignalSource = "http"
// SignalFromWatcher triggers checkpoints automatically when pods become Ready
SignalFromWatcher CheckpointSignalSource = "watcher"
)
// Config holds the agent configuration
type Config struct {
// Common settings
ContainerdSocket string
CheckpointDir string
HostProc string
NodeName string
RestrictedNamespace string // Optional: restrict pod watching to this namespace
// Mode selection
SignalSource CheckpointSignalSource // "http" or "watcher"
// HTTP API mode settings (used when SignalSource = "http")
ListenAddr string
// CRIU settings (configurable options only; LeaveRunning, ShellJob, etc. are hardcoded in pkg/checkpoint/criu.go)
CUDAPluginDir string // Path to CRIU CUDA plugin directory
GhostLimit uint32 // CRIU ghost limit in bytes
Timeout uint32 // CRIU timeout in seconds
ExternalMounts []string // External mount mappings
}
// Server is the HTTP API server
type Server struct {
config Config
discoveryClient *checkpointk8s.DiscoveryClient
checkpointer *checkpoint.Checkpointer
}
// CheckpointRequest is the request body for checkpoint operations
type CheckpointRequest struct {
ContainerID string `json:"container_id"`
CheckpointID string `json:"checkpoint_id"`
PodName string `json:"pod_name,omitempty"`
PodNamespace string `json:"pod_namespace,omitempty"`
DisableCUDA bool `json:"disable_cuda,omitempty"` // Disable CUDA plugin for non-GPU workloads
}
// TriggerRestoreRequest is the request body for Option A self-restoring trigger
type TriggerRestoreRequest struct {
CheckpointID string `json:"checkpoint_id"`
PlaceholderContainerID string `json:"placeholder_container_id"`
SkipImageValidation bool `json:"skip_image_validation,omitempty"` // Skip image matching check
}
// TriggerRestoreResponse is the response for trigger restore operations
type TriggerRestoreResponse struct {
Success bool `json:"success"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
TriggerPath string `json:"trigger_path,omitempty"`
CheckpointImage string `json:"checkpoint_image,omitempty"`
PlaceholderImage string `json:"placeholder_image,omitempty"`
}
// CheckpointResponse is the response for checkpoint operations
type CheckpointResponse struct {
Success bool `json:"success"`
CheckpointID string `json:"checkpoint_id,omitempty"`
Message string `json:"message,omitempty"`
Error string `json:"error,omitempty"`
}
// CheckpointInfo represents information about a checkpoint
type CheckpointInfo struct {
ID string `json:"id"`
CreatedAt time.Time `json:"created_at"`
SourceNode string `json:"source_node"`
ContainerID string `json:"container_id"`
PodName string `json:"pod_name"`
PodNamespace string `json:"pod_namespace"`
Image string `json:"image"`
}
// ListCheckpointsResponse is the response for list checkpoints
type ListCheckpointsResponse struct {
Checkpoints []CheckpointInfo `json:"checkpoints"`
}
// HealthResponse is the response for health check
type HealthResponse struct {
Status string `json:"status"`
NodeName string `json:"node_name"`
}
func main() {
// Parse signal source - default to HTTP for backward compatibility
signalSource := CheckpointSignalSource(strings.ToLower(getEnv("CHECKPOINT_SIGNAL_FROM", "http")))
if signalSource != SignalFromHTTP && signalSource != SignalFromWatcher {
log.Fatalf("Invalid CHECKPOINT_SIGNAL_FROM value: %q (must be 'http' or 'watcher')", signalSource)
}
// Parse CRIU settings
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
}
// Parse external mounts (comma-separated)
var externalMounts []string
if v := os.Getenv("EXTERNAL_MOUNTS"); v != "" {
externalMounts = strings.Split(v, ",")
}
config := Config{
// Common settings
ContainerdSocket: getEnv("CONTAINERD_SOCKET", "/run/containerd/containerd.sock"),
CheckpointDir: getEnv("CHECKPOINT_DIR", "/checkpoints"),
HostProc: getEnv("HOST_PROC", "/host/proc"),
NodeName: getEnv("NODE_NAME", "unknown"),
RestrictedNamespace: os.Getenv("RESTRICTED_NAMESPACE"), // Optional: empty = cluster-wide watching
// Mode selection
SignalSource: signalSource,
// HTTP API settings
ListenAddr: getEnv("LISTEN_ADDR", ":8080"),
// CRIU settings
CUDAPluginDir: getEnv("CUDA_PLUGIN_DIR", ""),
GhostLimit: ghostLimit,
Timeout: timeout,
ExternalMounts: externalMounts,
}
// Create discovery client
discoveryClient, err := checkpointk8s.NewDiscoveryClient(config.ContainerdSocket)
if err != nil {
log.Fatalf("Failed to create discovery client: %v", err)
}
defer discoveryClient.Close()
// Create checkpointer
checkpointer := checkpoint.NewCheckpointer(discoveryClient, config.HostProc)
// Context for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Handle graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
log.Printf("CRIU Node Agent starting (node: %s)", config.NodeName)
log.Printf("Checkpoint directory: %s", config.CheckpointDir)
log.Printf("Signal source: %s", config.SignalSource)
switch config.SignalSource {
case SignalFromHTTP:
server := &Server{
config: config,
discoveryClient: discoveryClient,
checkpointer: checkpointer,
}
// Setup routes
mux := http.NewServeMux()
mux.HandleFunc("/health", server.handleHealth)
mux.HandleFunc("/checkpoint", server.handleCheckpoint)
mux.HandleFunc("/restore/trigger", server.handleTriggerRestore)
mux.HandleFunc("/checkpoints", server.handleListCheckpoints)
httpServer := &http.Server{
Addr: config.ListenAddr,
Handler: loggingMiddleware(mux),
ReadTimeout: 30 * time.Second,
WriteTimeout: 300 * time.Second,
IdleTimeout: 120 * time.Second,
}
// Handle graceful shutdown
go func() {
<-sigChan
log.Println("Shutting down HTTP server...")
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer shutdownCancel()
if err := httpServer.Shutdown(shutdownCtx); err != nil {
log.Printf("HTTP server shutdown error: %v", err)
}
}()
log.Printf("HTTP API server listening on %s", config.ListenAddr)
if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("HTTP server error: %v", err)
}
case SignalFromWatcher:
watcherConfig := watcher.Config{
NodeName: config.NodeName,
CheckpointDir: config.CheckpointDir,
HostProc: config.HostProc,
ListenAddr: config.ListenAddr, // For health check endpoint
RestrictedNamespace: config.RestrictedNamespace,
CUDAPluginDir: config.CUDAPluginDir,
GhostLimit: config.GhostLimit,
Timeout: config.Timeout,
ExternalMounts: config.ExternalMounts,
}
podWatcher, err := watcher.NewWatcher(watcherConfig, discoveryClient, checkpointer)
if err != nil {
log.Fatalf("Failed to create pod watcher: %v", err)
}
// Handle graceful shutdown
go func() {
<-sigChan
log.Println("Shutting down pod watcher...")
cancel()
}()
log.Printf("Pod watcher started (watching for label: nvidia.com/checkpoint-source=true)")
log.Printf("Health check endpoint: http://0.0.0.0%s/health", config.ListenAddr)
if err := podWatcher.Start(ctx); err != nil {
log.Printf("Pod watcher error: %v", err)
}
}
log.Println("Agent stopped")
}
func (s *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
resp := HealthResponse{
Status: "healthy",
NodeName: s.config.NodeName,
}
writeJSON(w, http.StatusOK, resp)
}
func (s *Server) handleCheckpoint(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req CheckpointRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: fmt.Sprintf("Invalid request body: %v", err),
})
return
}
if req.ContainerID == "" {
writeJSON(w, http.StatusBadRequest, CheckpointResponse{
Success: false,
Error: "container_id is required",
})
return
}
if req.CheckpointID == "" {
req.CheckpointID = fmt.Sprintf("ckpt-%d", time.Now().UnixNano())
}
// Determine CUDA plugin directory - only use if not explicitly disabled
cudaPluginDir := s.config.CUDAPluginDir
if req.DisableCUDA {
cudaPluginDir = ""
}
// Parse optional CRIU settings from environment
var ghostLimit, timeout uint32
if v := os.Getenv("CRIU_GHOST_LIMIT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
ghostLimit = uint32(parsed)
}
}
if v := os.Getenv("CRIU_TIMEOUT"); v != "" {
if parsed, err := strconv.ParseUint(v, 10, 32); err == nil {
timeout = uint32(parsed)
}
}
opts := checkpoint.Options{
ContainerID: req.ContainerID,
CheckpointID: req.CheckpointID,
CheckpointDir: s.config.CheckpointDir,
NodeName: s.config.NodeName,
PodName: req.PodName,
PodNamespace: req.PodNamespace,
GhostLimit: ghostLimit,
Timeout: timeout,
CUDAPluginDir: cudaPluginDir,
}
ctx := r.Context()
result, err := s.checkpointer.Checkpoint(ctx, opts)
if err != nil {
log.Printf("Checkpoint failed: %v", err)
writeJSON(w, http.StatusInternalServerError, CheckpointResponse{
Success: false,
Error: err.Error(),
})
return
}
log.Printf("Checkpoint successful: %s", result.CheckpointID)
writeJSON(w, http.StatusOK, CheckpointResponse{
Success: true,
CheckpointID: result.CheckpointID,
Message: fmt.Sprintf("Checkpoint created successfully at %s", result.CheckpointDir),
})
}
// handleTriggerRestore implements Option A from RESTORE_ANALYSIS.md
// It triggers a self-restoring placeholder container to start CRIU restore.
// The agent writes a trigger file to the placeholder's filesystem, which
// the placeholder's entrypoint script detects and uses to start restoration.
func (s *Server) handleTriggerRestore(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
var req TriggerRestoreRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Invalid request body: %v", err),
})
return
}
if req.CheckpointID == "" {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: "checkpoint_id is required",
})
return
}
if req.PlaceholderContainerID == "" {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: "placeholder_container_id is required",
})
return
}
// Verify checkpoint exists and load metadata
checkpointPath := common.GetCheckpointDir(s.config.CheckpointDir, req.CheckpointID)
checkpointMeta, err := common.LoadMetadata(checkpointPath)
if err != nil {
writeJSON(w, http.StatusNotFound, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Checkpoint not found: %v", err),
})
return
}
// Resolve placeholder container to get PID and image
ctx := r.Context()
containerInfo, err := s.discoveryClient.ResolveContainer(ctx, req.PlaceholderContainerID)
if err != nil {
writeJSON(w, http.StatusInternalServerError, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Failed to resolve placeholder container: %v", err),
})
return
}
// Validate that placeholder image matches checkpoint's original image
// This is critical because rootfs-diff.tar only contains upperdir modifications,
// not the base image layers (lowerdir). If images differ, CRIU restore will fail.
if !req.SkipImageValidation && checkpointMeta.Image != "" {
if !imagesCompatible(checkpointMeta.Image, containerInfo.Image) {
writeJSON(w, http.StatusBadRequest, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Image mismatch: checkpoint was from '%s' but placeholder uses '%s'. The placeholder must use the same base image. Use skip_image_validation=true to override.", checkpointMeta.Image, containerInfo.Image),
CheckpointImage: checkpointMeta.Image,
PlaceholderImage: containerInfo.Image,
})
return
}
log.Printf("Image validation passed: checkpoint=%s, placeholder=%s", checkpointMeta.Image, containerInfo.Image)
}
// Write trigger file to placeholder's filesystem via /proc/<pid>/root
// The trigger file contains the checkpoint path
triggerPath := fmt.Sprintf("%s/%d/root/tmp/restore-trigger", s.config.HostProc, containerInfo.PID)
// Write the checkpoint path to the trigger file
if err := os.WriteFile(triggerPath, []byte(checkpointPath), 0644); err != nil {
writeJSON(w, http.StatusInternalServerError, TriggerRestoreResponse{
Success: false,
Error: fmt.Sprintf("Failed to write trigger file: %v", err),
})
return
}
log.Printf("Triggered restore for placeholder %s (PID %d) from checkpoint %s",
req.PlaceholderContainerID, containerInfo.PID, req.CheckpointID)
writeJSON(w, http.StatusOK, TriggerRestoreResponse{
Success: true,
Message: fmt.Sprintf("Restore triggered for checkpoint %s", req.CheckpointID),
TriggerPath: triggerPath,
CheckpointImage: checkpointMeta.Image,
PlaceholderImage: containerInfo.Image,
})
}
func (s *Server) handleListCheckpoints(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodGet {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
checkpointIDs, err := common.ListCheckpoints(s.config.CheckpointDir)
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{
"error": err.Error(),
})
return
}
var checkpoints []CheckpointInfo
for _, id := range checkpointIDs {
meta, err := common.GetCheckpointInfo(s.config.CheckpointDir, id)
if err != nil {
continue
}
checkpoints = append(checkpoints, CheckpointInfo{
ID: meta.CheckpointID,
CreatedAt: meta.CreatedAt,
SourceNode: meta.SourceNode,
ContainerID: meta.ContainerID,
PodName: meta.PodName,
PodNamespace: meta.PodNamespace,
Image: meta.Image,
})
}
writeJSON(w, http.StatusOK, ListCheckpointsResponse{
Checkpoints: checkpoints,
})
}
func writeJSON(w http.ResponseWriter, status int, data interface{}) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
json.NewEncoder(w).Encode(data)
}
func loggingMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
log.Printf("Started %s %s", r.Method, r.URL.Path)
next.ServeHTTP(w, r)
log.Printf("Completed %s %s in %v", r.Method, r.URL.Path, time.Since(start))
})
}
func getEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// imagesCompatible checks if two container images are compatible for CRIU restore.
// The placeholder image must be based on the same image as the checkpoint.
// Handles various image name formats:
// - nginx:alpine vs nginx:alpine (exact match)
// - docker.io/library/nginx:alpine vs nginx:alpine (registry prefix)
// - criu-placeholder-nginx-alpine vs nginx:alpine (placeholder naming convention)
func imagesCompatible(checkpointImage, placeholderImage string) bool {
// Exact match
if checkpointImage == placeholderImage {
return true
}
// Normalize images by removing common registry prefixes
normalize := func(img string) string {
// Remove docker.io/library/ prefix
img = strings.TrimPrefix(img, "docker.io/library/")
// Remove docker.io/ prefix
img = strings.TrimPrefix(img, "docker.io/")
return img
}
normalizedCheckpoint := normalize(checkpointImage)
normalizedPlaceholder := normalize(placeholderImage)
if normalizedCheckpoint == normalizedPlaceholder {
return true
}
// Check if placeholder follows criu-placeholder-<image> naming convention
// e.g., criu-placeholder-nginx-alpine should match nginx:alpine
if strings.HasPrefix(normalizedPlaceholder, "criu-placeholder-") {
// Convert nginx:alpine to nginx-alpine for comparison
checkpointSanitized := strings.ReplaceAll(normalizedCheckpoint, ":", "-")
checkpointSanitized = strings.ReplaceAll(checkpointSanitized, "/", "-")
expectedPlaceholder := "criu-placeholder-" + checkpointSanitized
if normalizedPlaceholder == expectedPlaceholder ||
strings.HasPrefix(normalizedPlaceholder, expectedPlaceholder+":") {
return true
}
}
return false
}
// Package main provides the restore-entrypoint binary for self-restoring placeholder containers.
// This binary replaces the shell script restore-entrypoint.sh with a Go implementation
// that uses the go-criu library for CRIU operations.
package main
import (
"context"
"os"
"os/signal"
"syscall"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/restore"
)
func main() {
// Set up logging
log := logrus.New()
log.SetOutput(os.Stdout)
log.SetFormatter(&logrus.TextFormatter{
FullTimestamp: true,
TimestampFormat: "2006-01-02 15:04:05",
})
// Load configuration from environment
cfg := restore.ConfigFromEnv()
// Set log level based on DEBUG flag
if cfg.Debug {
log.SetLevel(logrus.DebugLevel)
} else {
log.SetLevel(logrus.InfoLevel)
}
entry := log.WithField("component", "restore-entrypoint")
// Set up context with signal handling for graceful shutdown
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Handle shutdown signals
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT)
go func() {
sig := <-sigChan
entry.WithField("signal", sig).Info("Received shutdown signal")
cancel()
}()
// Run the restore entrypoint
if err := restore.Run(ctx, cfg, entry); err != nil {
entry.WithError(err).Fatal("Restore entrypoint failed")
}
}
module github.com/ai-dynamo/dynamo/deploy/chrek
go 1.25
require (
github.com/checkpoint-restore/go-criu/v7 v7.2.0
github.com/containerd/containerd v1.7.11
github.com/opencontainers/runtime-spec v1.1.0
github.com/sirupsen/logrus v1.9.3
golang.org/x/sys v0.25.0
google.golang.org/protobuf v1.34.2
k8s.io/api v0.29.0
k8s.io/apimachinery v0.29.0
k8s.io/client-go v0.29.0
)
require (
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 // indirect
github.com/AdamKorcz/go-118-fuzz-build v0.0.0-20230306123547-8075edf89bb0 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/Microsoft/hcsshim v0.11.4 // indirect
github.com/containerd/cgroups v1.1.0 // indirect
github.com/containerd/continuity v0.4.2 // indirect
github.com/containerd/fifo v1.1.0 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/ttrpc v1.2.2 // indirect
github.com/containerd/typeurl/v2 v2.1.1 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/docker/go-events v0.0.0-20190806004212-e31b211e4f1c // indirect
github.com/emicklei/go-restful/v3 v3.11.0 // indirect
github.com/felixge/httpsnoop v1.0.3 // indirect
github.com/go-logr/logr v1.3.0 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/jsonpointer v0.19.6 // indirect
github.com/go-openapi/jsonreference v0.20.2 // indirect
github.com/go-openapi/swag v0.22.3 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/gnostic-models v0.6.8 // indirect
github.com/google/go-cmp v0.6.0 // indirect
github.com/google/gofuzz v1.2.0 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/imdario/mergo v0.3.13 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.16.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/moby/locker v1.0.1 // indirect
github.com/moby/sys/mountinfo v0.6.2 // indirect
github.com/moby/sys/sequential v0.5.0 // indirect
github.com/moby/sys/signal v0.7.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0-rc5 // indirect
github.com/opencontainers/runc v1.1.5 // indirect
github.com/opencontainers/selinux v1.11.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/spf13/pflag v1.0.5 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.45.0 // indirect
go.opentelemetry.io/otel v1.19.0 // indirect
go.opentelemetry.io/otel/metric v1.19.0 // indirect
go.opentelemetry.io/otel/trace v1.19.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/oauth2 v0.10.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/term v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect
golang.org/x/time v0.3.0 // indirect
golang.org/x/tools v0.12.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230920204549-e6e6cdab5c13 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20231002182017-d307bd883b97 // indirect
google.golang.org/grpc v1.58.3 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
k8s.io/klog/v2 v2.110.1 // indirect
k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 // indirect
k8s.io/utils v0.0.0-20230726121419-3b25d923346b // indirect
sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect
sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect
sigs.k8s.io/yaml v1.3.0 // indirect
)
This diff is collapsed.
// Package checkpoint provides CRIU checkpoint (dump) operations.
package checkpoint
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
criu "github.com/checkpoint-restore/go-criu/v7"
"github.com/sirupsen/logrus"
"google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// Options configures the checkpoint operation
type Options struct {
ContainerID string
ContainerName string // K8s container name (for K8s API volume type lookup)
CheckpointID string
CheckpointDir string
NodeName string
PodName string
PodNamespace string
// CRIU options (from environment variables)
GhostLimit uint32 // From CRIU_GHOST_LIMIT: ghost file size limit in bytes (0 = CRIU default)
Timeout uint32 // From CRIU_TIMEOUT: timeout in seconds (0 = no timeout)
// GPU/CUDA checkpoint options
CUDAPluginDir string // Path to CRIU CUDA plugin (e.g., /home/mmshin/work/criu/plugins/cuda)
ExternalMounts []string // Additional external mount mappings (e.g., "mnt[path]:path")
}
// Result contains the result of a checkpoint operation
type Result struct {
CheckpointID string
CheckpointDir string
Metadata *common.CheckpointMetadata
}
// Checkpointer performs CRIU checkpoint operations
type Checkpointer struct {
discoveryClient *checkpointk8s.DiscoveryClient
k8sClient *checkpointk8s.K8sClient // Optional: for accurate volume type discovery from K8s API
hostProc string
log *logrus.Entry
}
// NewCheckpointer creates a new checkpointer
func NewCheckpointer(discoveryClient *checkpointk8s.DiscoveryClient, hostProc string) *Checkpointer {
if hostProc == "" {
hostProc = os.Getenv("HOST_PROC")
if hostProc == "" {
hostProc = "/proc"
}
}
return &Checkpointer{
discoveryClient: discoveryClient,
hostProc: hostProc,
log: logrus.WithField("component", "checkpointer"),
}
}
// WithK8sClient sets an optional Kubernetes client for accurate volume type discovery.
// When set, volume types are fetched from the K8s API instead of being inferred from paths.
func (c *Checkpointer) WithK8sClient(client *checkpointk8s.K8sClient) *Checkpointer {
c.k8sClient = client
return c
}
// Checkpoint performs a CRIU dump of a container
func (c *Checkpointer) Checkpoint(ctx context.Context, opts Options) (*Result, error) {
checkpointStart := time.Now()
c.log.Info("=== Starting checkpoint operation ===")
// 1. Resolve container to get PID
resolveStart := time.Now()
containerInfo, err := c.discoveryClient.ResolveContainer(ctx, opts.ContainerID)
if err != nil {
return nil, fmt.Errorf("failed to resolve container: %w", err)
}
pid := int(containerInfo.PID)
c.log.WithField("duration", time.Since(resolveStart)).Info("Container resolution completed")
// 2. Create checkpoint directory
checkpointDir := common.GetCheckpointDir(opts.CheckpointDir, opts.CheckpointID)
if err := os.MkdirAll(checkpointDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create checkpoint directory: %w", err)
}
// 3. Introspect container state
introspectStart := time.Now()
rootFS, err := GetRootFS(pid, c.hostProc)
if err != nil {
return nil, fmt.Errorf("failed to get rootfs: %w", err)
}
mounts, err := GetKubernetesVolumeMounts(pid, c.hostProc)
if err != nil {
return nil, fmt.Errorf("failed to get mounts: %w", err)
}
namespaces, err := GetAllNamespaces(pid, c.hostProc)
if err != nil {
return nil, fmt.Errorf("failed to get namespaces: %w", err)
}
c.log.WithField("duration", time.Since(introspectStart)).Info("Container introspection completed")
// 4. Open image directory FD
imageDir, imageDirFD, err := OpenImageDir(checkpointDir)
if err != nil {
return nil, err
}
defer imageDir.Close()
// 5. Build CRIU options
criuOpts := BuildCRIUOptsFromCheckpointOpts(opts, pid, imageDirFD, rootFS)
// 6. Create CRIU config file for CUDA plugin (libdir is not available via RPC)
if opts.CUDAPluginDir != "" {
if opts.Timeout == 0 {
return nil, fmt.Errorf("CRIU_TIMEOUT environment variable must be set for CUDA checkpoints")
}
configPath := filepath.Join(checkpointDir, "criu.conf")
configContent := fmt.Sprintf(`enable-external-masters
libdir %s
tcp-close
link-remap
timeout %d
allow-uprobes
skip-in-flight
`, opts.CUDAPluginDir, opts.Timeout)
if err := os.WriteFile(configPath, []byte(configContent), 0644); err != nil {
return nil, fmt.Errorf("failed to write CRIU config file: %w", err)
}
criuOpts.ConfigFile = proto.String(configPath)
c.log.WithFields(logrus.Fields{
"config_path": configPath,
"plugin_dir": opts.CUDAPluginDir,
}).Info("Created CRIU config file for CUDA plugin")
}
// 7. Configure external mounts and namespaces
if err := ConfigureExternalMounts(criuOpts, pid, c.hostProc, containerInfo); err != nil {
return nil, err
}
netNsInode := ConfigureExternalNamespaces(criuOpts, namespaces, opts.ExternalMounts)
if netNsInode > 0 {
c.log.WithField("inode", netNsInode).Debug("Marked network namespace as external")
}
for _, extMount := range opts.ExternalMounts {
c.log.WithField("external", extMount).Debug("Added external mount mapping")
}
// 8. Get overlay upperdir for rootfs diff capture
upperDir, upperDirErr := GetOverlayUpperDir(pid, c.hostProc)
if upperDirErr != nil {
c.log.WithError(upperDirErr).Warn("Could not get overlay upperdir - rootfs diff will not be captured")
} else {
c.log.WithField("upperdir", upperDir).Debug("Found overlay upperdir")
}
// 9. Build and save initial metadata before dump
metaCfg := MetadataBuilderConfig{
CheckpointID: opts.CheckpointID,
NodeName: opts.NodeName,
ContainerID: opts.ContainerID,
ContainerName: opts.ContainerName,
PodName: opts.PodName,
PodNamespace: opts.PodNamespace,
PID: pid,
CUDAPluginDir: opts.CUDAPluginDir,
}
meta := BuildCheckpointMetadata(ctx, metaCfg, containerInfo, mounts, namespaces, c.k8sClient, c.log)
if upperDir != "" {
meta.UpperDir = upperDir
}
if err := common.SaveMetadata(checkpointDir, meta); err != nil {
return nil, fmt.Errorf("failed to save metadata: %w", err)
}
// 10. Remove semaphores from /dev/shm before checkpoint
// Semaphores cause CRIU restore to fail with "Can't link dev/shm/link_remap.X -> dev/shm/sem.Y"
if err := c.removeSemaphores(pid); err != nil {
return nil, fmt.Errorf("failed to remove semaphores: %w", err)
}
// 11. Execute CRIU dump via go-criu
criuDumpStart := time.Now()
criuClient := criu.MakeCriu()
if err := criuClient.Dump(criuOpts, nil); err != nil {
c.log.WithField("duration", time.Since(criuDumpStart)).Error("CRIU dump failed")
return nil, fmt.Errorf("CRIU dump failed: %w", err)
}
criuDumpDuration := time.Since(criuDumpStart)
c.log.WithField("duration", criuDumpDuration).Info("CRIU dump completed successfully")
// 12. Capture rootfs diff and deleted files
rootfsCaptureStart := time.Now()
CaptureRootfsState(upperDir, checkpointDir, meta, c.log)
c.log.WithField("duration", time.Since(rootfsCaptureStart)).Info("Rootfs capture completed")
totalDuration := time.Since(checkpointStart)
c.log.WithFields(logrus.Fields{
"total_duration": totalDuration,
"criu_dump_duration": criuDumpDuration,
}).Info("=== Checkpoint operation completed ===")
return &Result{
CheckpointID: opts.CheckpointID,
CheckpointDir: checkpointDir,
Metadata: meta,
}, nil
}
// removeSemaphores removes POSIX semaphores from the container's /dev/shm.
// Semaphores can cause issues during CRIU checkpoint/restore because they
// maintain kernel state that may not transfer correctly between processes.
// This accesses the container's filesystem via /proc/<pid>/root/dev/shm/.
func (c *Checkpointer) removeSemaphores(pid int) error {
shmPath := filepath.Join(c.hostProc, fmt.Sprintf("%d/root/dev/shm", pid))
entries, err := os.ReadDir(shmPath)
if err != nil {
// It's okay if /dev/shm doesn't exist (container may not have it)
c.log.WithError(err).Debug("Could not read container /dev/shm (may not exist)")
return nil
}
var removed []string
var errors []error
for _, entry := range entries {
name := entry.Name()
if strings.HasPrefix(name, "sem.") {
semPath := filepath.Join(shmPath, name)
if err := os.Remove(semPath); err != nil {
c.log.WithError(err).WithField("semaphore", name).Error("Failed to remove semaphore")
errors = append(errors, fmt.Errorf("failed to remove semaphore %s: %w", name, err))
} else {
removed = append(removed, name)
}
}
}
if len(errors) > 0 {
return fmt.Errorf("failed to remove %d semaphore(s): %v", len(errors), errors)
}
if len(removed) > 0 {
c.log.WithFields(logrus.Fields{
"count": len(removed),
"semaphores": removed,
}).Info("Removed semaphores from container /dev/shm before checkpoint")
} else {
c.log.Debug("No semaphores found in container /dev/shm")
}
return nil
}
// criu provides CRIU-specific configuration and utilities for checkpoint operations.
package checkpoint
import (
"fmt"
"os"
criurpc "github.com/checkpoint-restore/go-criu/v7/rpc"
"google.golang.org/protobuf/proto"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// CRIUConfig holds configuration for CRIU dump operations.
// Most options are always-on with safe defaults for K8s environments.
type CRIUConfig struct {
PID int
ImageDirFD int32
RootFS string
GhostLimit uint32 // From env CRIU_GHOST_LIMIT: max ghost file size (0 = CRIU default)
Timeout uint32 // From env CRIU_TIMEOUT: checkpoint timeout in seconds (0 = no timeout)
}
// OpenImageDir opens a checkpoint directory and prepares it for CRIU.
// Returns the opened file and its FD. The caller must close the file when done.
// The file descriptor has CLOEXEC cleared so it can be inherited by CRIU.
func OpenImageDir(checkpointDir string) (*os.File, int32, error) {
return common.OpenDirForCRIU(checkpointDir)
}
// BuildCRIUOpts creates CRIU options from a config struct.
// This sets up the base options; external mounts and namespaces are added separately.
//
// Always-on options for K8s:
// - LeaveRunning: always keep process running after checkpoint
// - ShellJob: containers are often session leaders
// - TcpClose: pod IPs change on restore/migration
// - FileLocks: applications use file locks
// - OrphanPtsMaster: containers with TTYs
// - ExtUnixSk: containers have external Unix sockets
// - ManageCgroups (IGNORE): let K8s manage cgroups
// - LinkRemap: handle deleted-but-open files (safe for all workloads)
// - ExtMasters: external bind mount masters (safe for all workloads)
func BuildCRIUOpts(cfg CRIUConfig) *criurpc.CriuOpts {
cgMode := criurpc.CriuCgMode_IGNORE
criuOpts := &criurpc.CriuOpts{
Pid: proto.Int32(int32(cfg.PID)),
ImagesDirFd: proto.Int32(cfg.ImageDirFD),
LogLevel: proto.Int32(4),
LogFile: proto.String("dump.log"),
Root: proto.String(cfg.RootFS),
ManageCgroups: proto.Bool(true),
ManageCgroupsMode: &cgMode,
// Always-on for K8s environments
LeaveRunning: proto.Bool(true),
ShellJob: proto.Bool(true),
TcpClose: proto.Bool(true),
FileLocks: proto.Bool(true),
OrphanPtsMaster: proto.Bool(true),
ExtUnixSk: proto.Bool(true),
LinkRemap: proto.Bool(true),
ExtMasters: proto.Bool(true),
}
// Optional: ghost limit from env (0 = use CRIU default)
if cfg.GhostLimit > 0 {
criuOpts.GhostLimit = proto.Uint32(cfg.GhostLimit)
}
// Optional: timeout from env (0 = no timeout)
if cfg.Timeout > 0 {
criuOpts.Timeout = proto.Uint32(cfg.Timeout)
}
return criuOpts
}
// AddExternalMounts adds mount points as external mounts to CRIU options.
// CRIU requires all mounts to be marked as external for successful restore.
func AddExternalMounts(criuOpts *criurpc.CriuOpts, mounts []AllMountInfo) {
addedMounts := make(map[string]bool)
for _, m := range mounts {
if addedMounts[m.MountPoint] {
continue
}
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{
Key: proto.String(m.MountPoint),
Val: proto.String(m.MountPoint),
})
addedMounts[m.MountPoint] = true
}
}
// AddExternalPaths adds additional paths (masked/readonly) as external mounts.
// These may not appear in mountinfo but CRIU still needs them marked as external.
func AddExternalPaths(criuOpts *criurpc.CriuOpts, paths []string) {
// Build set of existing mount points
existing := make(map[string]bool)
for _, m := range criuOpts.ExtMnt {
existing[m.GetKey()] = true
}
for _, path := range paths {
if existing[path] {
continue
}
criuOpts.ExtMnt = append(criuOpts.ExtMnt, &criurpc.ExtMountMap{
Key: proto.String(path),
Val: proto.String(path),
})
existing[path] = true
}
}
// AddExternalNamespace adds a namespace as external to CRIU options.
// Format: "<type>[<inode>]:<key>"
func AddExternalNamespace(criuOpts *criurpc.CriuOpts, nsType NamespaceType, inode uint64, key string) {
extNs := fmt.Sprintf("%s[%d]:%s", nsType, inode, key)
criuOpts.External = append(criuOpts.External, extNs)
}
// AddExternalStrings adds raw external strings to CRIU options.
// Used for additional external mount mappings (e.g., NVIDIA firmware files).
func AddExternalStrings(criuOpts *criurpc.CriuOpts, externals []string) {
criuOpts.External = append(criuOpts.External, externals...)
}
// ConfigureExternalMounts adds all required external mounts to CRIU options.
// This includes mounts from /proc/pid/mountinfo plus masked/readonly paths from OCI spec.
func ConfigureExternalMounts(criuOpts *criurpc.CriuOpts, pid int, hostProc string, containerInfo *checkpointk8s.ContainerInfo) error {
// Get all mounts from mountinfo - CRIU needs every mount marked as external
allMounts, err := GetAllMountsFromMountinfo(pid, hostProc)
if err != nil {
return fmt.Errorf("failed to get all mounts from mountinfo: %w", err)
}
// Add mounts from mountinfo
AddExternalMounts(criuOpts, allMounts)
// Add masked and readonly paths from OCI spec
AddExternalPaths(criuOpts, containerInfo.GetMaskedPaths())
AddExternalPaths(criuOpts, containerInfo.GetReadonlyPaths())
return nil
}
// ConfigureExternalNamespaces adds external namespaces to CRIU options.
// Returns the network namespace inode if found, for logging purposes.
func ConfigureExternalNamespaces(criuOpts *criurpc.CriuOpts, namespaces map[NamespaceType]*NamespaceInfo, externalMounts []string) uint64 {
var netNsInode uint64
// Mark network namespace as external for socket binding preservation
if netNs, ok := namespaces[NamespaceNet]; ok {
AddExternalNamespace(criuOpts, NamespaceNet, netNs.Inode, "extNetNs")
netNsInode = netNs.Inode
}
// Add additional external mounts (e.g., for NVIDIA firmware files)
AddExternalStrings(criuOpts, externalMounts)
return netNsInode
}
// BuildCRIUOptsFromCheckpointOpts constructs CRIU options from checkpoint Options.
// Returns the configured CriuOpts ready for external mount/namespace configuration.
func BuildCRIUOptsFromCheckpointOpts(opts Options, pid int, imageDirFD int32, rootFS string) *criurpc.CriuOpts {
cfg := CRIUConfig{
PID: pid,
ImageDirFD: imageDirFD,
RootFS: rootFS,
GhostLimit: opts.GhostLimit,
Timeout: opts.Timeout,
}
return BuildCRIUOpts(cfg)
}
// discovery provides container information resolution via containerd.
// This prefers containerd RPCs for configuration over /proc inspection,
// following the principle that configuration should come from the container runtime
// while runtime state (like namespace inodes) requires /proc.
package k8s
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/containerd/containerd"
"github.com/containerd/containerd/namespaces"
specs "github.com/opencontainers/runtime-spec/specs-go"
)
const (
// K8sNamespace is the containerd namespace used by Kubernetes
K8sNamespace = "k8s.io"
// DefaultSocket is the default containerd socket path
DefaultSocket = "/run/containerd/containerd.sock"
)
// ContainerInfo holds resolved container information from containerd.
// Configuration data comes from containerd RPCs, runtime state from /proc.
type ContainerInfo struct {
ContainerID string
PID uint32
RootFS string // Actual rootfs path (bundle path + spec.Root.Path)
BundlePath string // Path to container bundle directory
Image string
Spec *specs.Spec // OCI spec from containerd (mounts, namespaces config)
Labels map[string]string
}
// MountInfo represents a mount from the OCI spec.
type MountInfo struct {
Destination string // Mount point inside container
Source string // Source path on host
Type string // Filesystem type (bind, tmpfs, etc.)
Options []string // Mount options
}
// NamespaceConfig represents namespace configuration from OCI spec.
type NamespaceConfig struct {
Type string // Namespace type (network, pid, mount, etc.)
Path string // Path to namespace (empty for new namespace)
}
// DiscoveryClient wraps the containerd client for container discovery.
type DiscoveryClient struct {
client *containerd.Client
socket string
}
// NewDiscoveryClient creates a new discovery client.
func NewDiscoveryClient(socket string) (*DiscoveryClient, error) {
if socket == "" {
socket = DefaultSocket
}
client, err := containerd.New(socket)
if err != nil {
return nil, fmt.Errorf("failed to connect to containerd at %s: %w", socket, err)
}
return &DiscoveryClient{
client: client,
socket: socket,
}, nil
}
// Close closes the containerd client connection.
func (c *DiscoveryClient) Close() error {
if c.client != nil {
return c.client.Close()
}
return nil
}
// ResolveContainer resolves a container ID to its process information.
// This retrieves configuration from containerd RPCs (OCI spec, labels, image)
// and runtime paths from /proc (rootfs access path).
func (c *DiscoveryClient) ResolveContainer(ctx context.Context, containerID string) (*ContainerInfo, error) {
// Use the Kubernetes namespace for containerd
ctx = namespaces.WithNamespace(ctx, K8sNamespace)
// Load the container
container, err := c.client.LoadContainer(ctx, containerID)
if err != nil {
return nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
}
// Get the task (running process)
task, err := container.Task(ctx, nil)
if err != nil {
return nil, fmt.Errorf("failed to get task for container %s: %w", containerID, err)
}
// Get the PID
pid := task.Pid()
// Get container image
image, err := container.Image(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get image for container %s: %w", containerID, err)
}
// Get OCI spec from containerd - this contains mount config, namespace config, etc.
// This is preferred over parsing /proc for configuration data.
spec, err := container.Spec(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get spec for container %s: %w", containerID, err)
}
// Get container labels (includes K8s pod info)
labels, err := container.Labels(ctx)
if err != nil {
// Labels are optional, don't fail
labels = make(map[string]string)
}
// Construct the bundle path where containerd stores the container runtime files
// Standard containerd layout: /run/containerd/io.containerd.runtime.v2.task/<namespace>/<container_id>/
containerdRunRoot := os.Getenv("CONTAINERD_RUN_ROOT")
if containerdRunRoot == "" {
containerdRunRoot = "/run/containerd"
}
bundlePath := filepath.Join(containerdRunRoot, "io.containerd.runtime.v2.task", K8sNamespace, containerID)
// Get the rootfs path from the OCI spec (usually "rootfs" relative to bundle)
rootfsRelPath := "rootfs"
if spec.Root != nil && spec.Root.Path != "" {
rootfsRelPath = spec.Root.Path
}
// Construct full rootfs path
var rootFS string
if filepath.IsAbs(rootfsRelPath) {
rootFS = rootfsRelPath
} else {
rootFS = filepath.Join(bundlePath, rootfsRelPath)
}
return &ContainerInfo{
ContainerID: containerID,
PID: pid,
RootFS: rootFS,
BundlePath: bundlePath,
Image: image.Name(),
Spec: spec,
Labels: labels,
}, nil
}
// GetMounts returns the mount configuration from the OCI spec.
// This is preferred over parsing /proc/mountinfo for configuration,
// though /proc is still needed for runtime mount state.
func (info *ContainerInfo) GetMounts() []MountInfo {
if info.Spec == nil || info.Spec.Mounts == nil {
return nil
}
mounts := make([]MountInfo, len(info.Spec.Mounts))
for i, m := range info.Spec.Mounts {
mounts[i] = MountInfo{
Destination: m.Destination,
Source: m.Source,
Type: m.Type,
Options: m.Options,
}
}
return mounts
}
// GetNamespaces returns the namespace configuration from the OCI spec.
func (info *ContainerInfo) GetNamespaces() []NamespaceConfig {
if info.Spec == nil || info.Spec.Linux == nil {
return nil
}
namespaces := make([]NamespaceConfig, len(info.Spec.Linux.Namespaces))
for i, ns := range info.Spec.Linux.Namespaces {
namespaces[i] = NamespaceConfig{
Type: string(ns.Type),
Path: ns.Path,
}
}
return namespaces
}
// GetMaskedPaths returns the masked paths from the OCI spec.
func (info *ContainerInfo) GetMaskedPaths() []string {
if info.Spec == nil || info.Spec.Linux == nil {
return nil
}
return info.Spec.Linux.MaskedPaths
}
// GetReadonlyPaths returns the readonly paths from the OCI spec.
func (info *ContainerInfo) GetReadonlyPaths() []string {
if info.Spec == nil || info.Spec.Linux == nil {
return nil
}
return info.Spec.Linux.ReadonlyPaths
}
// GetRootfsPath returns the rootfs path from the OCI spec.
// Note: For CRIU, use info.RootFS which is the /proc/<pid>/root path.
func (info *ContainerInfo) GetRootfsPath() string {
if info.Spec == nil || info.Spec.Root == nil {
return ""
}
return info.Spec.Root.Path
}
// IsRootReadonly returns whether the root filesystem is readonly.
func (info *ContainerInfo) IsRootReadonly() bool {
if info.Spec == nil || info.Spec.Root == nil {
return false
}
return info.Spec.Root.Readonly
}
// GetHostname returns the container's hostname from the OCI spec.
func (info *ContainerInfo) GetHostname() string {
if info.Spec == nil {
return ""
}
return info.Spec.Hostname
}
// ListContainers lists all containers in the K8s namespace.
func (c *DiscoveryClient) ListContainers(ctx context.Context) ([]string, error) {
ctx = namespaces.WithNamespace(ctx, K8sNamespace)
containers, err := c.client.Containers(ctx)
if err != nil {
return nil, fmt.Errorf("failed to list containers: %w", err)
}
ids := make([]string, len(containers))
for i, container := range containers {
ids[i] = container.ID()
}
return ids, nil
}
// GetContainerLabels returns the labels for a container.
func (c *DiscoveryClient) GetContainerLabels(ctx context.Context, containerID string) (map[string]string, error) {
ctx = namespaces.WithNamespace(ctx, K8sNamespace)
container, err := c.client.LoadContainer(ctx, containerID)
if err != nil {
return nil, fmt.Errorf("failed to load container %s: %w", containerID, err)
}
labels, err := container.Labels(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get labels for container %s: %w", containerID, err)
}
return labels, nil
}
// Package k8s provides Kubernetes-specific functionality for checkpoint operations.
// This includes volume type discovery via K8s API and containerd container discovery.
package k8s
import (
"context"
"fmt"
"strings"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
)
// VolumeInfo contains Kubernetes volume information for a mount.
type VolumeInfo struct {
VolumeName string // Name from pod.spec.volumes[].name
VolumeType string // Type: emptyDir, configMap, secret, persistentVolumeClaim, etc.
MountPath string // Container path from volumeMounts[].mountPath
SubPath string // SubPath if specified
ReadOnly bool // Whether mount is read-only
// Type-specific details
ConfigMapName string // For configMap volumes
SecretName string // For secret volumes
PVCName string // For persistentVolumeClaim volumes
}
// K8sClient wraps the Kubernetes clientset for volume discovery.
type K8sClient struct {
clientset *kubernetes.Clientset
}
// NewK8sClient creates a new Kubernetes client.
// It attempts in-cluster config first, then falls back to kubeconfig.
func NewK8sClient() (*K8sClient, error) {
config, err := rest.InClusterConfig()
if err != nil {
// Fall back to kubeconfig for local development
config, err = clientcmd.BuildConfigFromFlags("", clientcmd.RecommendedHomeFile)
if err != nil {
return nil, fmt.Errorf("failed to create k8s config: %w", err)
}
}
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to create k8s clientset: %w", err)
}
return &K8sClient{clientset: clientset}, nil
}
// NewK8sClientWithConfig creates a client with explicit config.
func NewK8sClientWithConfig(config *rest.Config) (*K8sClient, error) {
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
return nil, fmt.Errorf("failed to create k8s clientset: %w", err)
}
return &K8sClient{clientset: clientset}, nil
}
// GetPodVolumes returns volume information for all mounts in a container.
// Returns a map from mount path to VolumeInfo.
func (c *K8sClient) GetPodVolumes(ctx context.Context, namespace, podName, containerName string) (map[string]*VolumeInfo, error) {
pod, err := c.clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{})
if err != nil {
return nil, fmt.Errorf("failed to get pod %s/%s: %w", namespace, podName, err)
}
return ExtractVolumeInfo(pod, containerName)
}
// ExtractVolumeInfo extracts volume information from a Pod spec.
// This is the core logic that maps volumeMounts to volumes and determines types.
func ExtractVolumeInfo(pod *corev1.Pod, containerName string) (map[string]*VolumeInfo, error) {
// Build volume name -> type mapping from pod.spec.volumes
volumeTypes := make(map[string]*volumeDetails)
for _, vol := range pod.Spec.Volumes {
volumeTypes[vol.Name] = getVolumeDetails(&vol)
}
// Find the target container
var container *corev1.Container
for i := range pod.Spec.Containers {
if pod.Spec.Containers[i].Name == containerName {
container = &pod.Spec.Containers[i]
break
}
}
if container == nil {
// Try init containers
for i := range pod.Spec.InitContainers {
if pod.Spec.InitContainers[i].Name == containerName {
container = &pod.Spec.InitContainers[i]
break
}
}
}
if container == nil {
return nil, fmt.Errorf("container %s not found in pod", containerName)
}
// Build mount path -> volume info mapping
result := make(map[string]*VolumeInfo)
for _, mount := range container.VolumeMounts {
details, ok := volumeTypes[mount.Name]
if !ok {
continue // Mount references unknown volume
}
result[mount.MountPath] = &VolumeInfo{
VolumeName: mount.Name,
VolumeType: details.volumeType,
MountPath: mount.MountPath,
SubPath: mount.SubPath,
ReadOnly: mount.ReadOnly,
ConfigMapName: details.configMapName,
SecretName: details.secretName,
PVCName: details.pvcName,
}
}
return result, nil
}
// volumeDetails holds extracted volume type information.
type volumeDetails struct {
volumeType string
configMapName string
secretName string
pvcName string
}
// getVolumeDetails extracts type and details from a Volume spec.
func getVolumeDetails(vol *corev1.Volume) *volumeDetails {
d := &volumeDetails{volumeType: "unknown"}
switch {
case vol.EmptyDir != nil:
d.volumeType = "emptyDir"
case vol.ConfigMap != nil:
d.volumeType = "configMap"
d.configMapName = vol.ConfigMap.Name
case vol.Secret != nil:
d.volumeType = "secret"
d.secretName = vol.Secret.SecretName
case vol.PersistentVolumeClaim != nil:
d.volumeType = "persistentVolumeClaim"
d.pvcName = vol.PersistentVolumeClaim.ClaimName
case vol.HostPath != nil:
d.volumeType = "hostPath"
case vol.Projected != nil:
d.volumeType = "projected"
case vol.DownwardAPI != nil:
d.volumeType = "downwardAPI"
case vol.CSI != nil:
d.volumeType = "csi"
case vol.NFS != nil:
d.volumeType = "nfs"
case vol.ISCSI != nil:
d.volumeType = "iscsi"
case vol.GCEPersistentDisk != nil:
d.volumeType = "gcePersistentDisk"
case vol.AWSElasticBlockStore != nil:
d.volumeType = "awsElasticBlockStore"
case vol.AzureDisk != nil:
d.volumeType = "azureDisk"
case vol.AzureFile != nil:
d.volumeType = "azureFile"
case vol.CephFS != nil:
d.volumeType = "cephfs"
case vol.Cinder != nil:
d.volumeType = "cinder"
case vol.FC != nil:
d.volumeType = "fc"
case vol.FlexVolume != nil:
d.volumeType = "flexVolume"
case vol.Flocker != nil:
d.volumeType = "flocker"
case vol.GitRepo != nil:
d.volumeType = "gitRepo"
case vol.Glusterfs != nil:
d.volumeType = "glusterfs"
case vol.PhotonPersistentDisk != nil:
d.volumeType = "photonPersistentDisk"
case vol.PortworxVolume != nil:
d.volumeType = "portworxVolume"
case vol.Quobyte != nil:
d.volumeType = "quobyte"
case vol.RBD != nil:
d.volumeType = "rbd"
case vol.ScaleIO != nil:
d.volumeType = "scaleIO"
case vol.StorageOS != nil:
d.volumeType = "storageos"
case vol.VsphereVolume != nil:
d.volumeType = "vsphereVolume"
case vol.Ephemeral != nil:
d.volumeType = "ephemeral"
}
return d
}
// DetectVolumeTypeFromPath attempts to identify volume type from kubelet path patterns.
// This is a best-effort fallback; accurate volume types require K8s API access via GetPodVolumes.
func DetectVolumeTypeFromPath(hostPath string) (volumeType, volumeName string) {
volumeType = "unknown"
volumeName = ""
// Map of path patterns to volume types
patterns := map[string]string{
"/kubernetes.io~empty-dir/": "emptyDir",
"/kubernetes.io~configmap/": "configMap",
"/kubernetes.io~secret/": "secret",
"/kubernetes.io~projected/": "projected",
"/kubernetes.io~downward-api/": "downwardAPI",
"/kubernetes.io~persistentvolumeclaim/": "persistentVolumeClaim",
"/kubernetes.io~hostpath/": "hostPath",
}
for pattern, vType := range patterns {
if strings.Contains(hostPath, pattern) {
volumeType = vType
// Extract volume name from path
parts := strings.Split(hostPath, pattern)
if len(parts) > 1 {
volumeName = strings.Split(parts[1], "/")[0]
}
break
}
}
return volumeType, volumeName
}
// metadata_builder provides checkpoint metadata construction.
package checkpoint
import (
"context"
"strings"
"github.com/sirupsen/logrus"
checkpointk8s "github.com/ai-dynamo/dynamo/deploy/chrek/pkg/checkpoint/k8s"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// MetadataBuilderConfig holds configuration for building checkpoint metadata.
type MetadataBuilderConfig struct {
CheckpointID string
NodeName string
ContainerID string
ContainerName string
PodName string
PodNamespace string
PID int
CUDAPluginDir string
}
// BuildCheckpointMetadata constructs checkpoint metadata from container state.
func BuildCheckpointMetadata(
ctx context.Context,
cfg MetadataBuilderConfig,
containerInfo *checkpointk8s.ContainerInfo,
mounts []MountMapping,
namespaces map[NamespaceType]*NamespaceInfo,
k8sClient *checkpointk8s.K8sClient,
log *logrus.Entry,
) *common.CheckpointMetadata {
meta := common.NewCheckpointMetadata(cfg.CheckpointID)
meta.SourceNode = cfg.NodeName
meta.ContainerID = cfg.ContainerID
meta.PodName = cfg.PodName
meta.PodNamespace = cfg.PodNamespace
meta.PID = cfg.PID
meta.Image = containerInfo.Image
// Populate OCI spec derived paths
meta.MaskedPaths = containerInfo.GetMaskedPaths()
meta.ReadonlyPaths = containerInfo.GetReadonlyPaths()
// Build mount metadata
ociMountByDest := buildOCIMountLookup(containerInfo, meta)
// Get K8s volume types if available
k8sVolumes := getK8sVolumes(ctx, k8sClient, cfg, log)
// Add mount metadata
for _, mount := range mounts {
mountMeta := buildMountMetadata(mount, k8sVolumes, ociMountByDest)
meta.Mounts = append(meta.Mounts, mountMeta)
}
// Add namespace metadata
for nsType, nsInfo := range namespaces {
meta.Namespaces = append(meta.Namespaces, common.NamespaceMetadata{
Type: string(nsType),
Inode: nsInfo.Inode,
IsExternal: nsInfo.IsExternal,
})
}
// Set CRIU options (hardcoded as always-on for K8s, stored for compatibility)
meta.CRIUOptions = common.CRIUOptionsMetadata{
TcpEstablished: false, // Always false - we close TCP connections
TcpClose: true, // Always true - pod IPs change on restore
ShellJob: true, // Always true - containers are session leaders
FileLocks: true, // Always true - apps use file locks
LeaveRunning: true, // Always true - keep process running after checkpoint
LinkRemap: true, // Always true - handle deleted-but-open files
ExtMasters: true, // Always true - external bind mount masters
}
return meta
}
// buildOCIMountLookup builds a lookup map from OCI mounts and populates bind mount destinations.
func buildOCIMountLookup(containerInfo *checkpointk8s.ContainerInfo, meta *common.CheckpointMetadata) map[string]checkpointk8s.MountInfo {
ociMounts := containerInfo.GetMounts()
ociMountByDest := make(map[string]checkpointk8s.MountInfo)
for _, m := range ociMounts {
ociMountByDest[m.Destination] = m
if m.Type == "bind" {
meta.BindMountDests = append(meta.BindMountDests, m.Destination)
}
}
return ociMountByDest
}
// getK8sVolumes fetches volume types from K8s API if available.
func getK8sVolumes(ctx context.Context, k8sClient *checkpointk8s.K8sClient, cfg MetadataBuilderConfig, log *logrus.Entry) map[string]*checkpointk8s.VolumeInfo {
if k8sClient == nil || cfg.PodNamespace == "" || cfg.PodName == "" || cfg.ContainerName == "" {
return nil
}
k8sVolumes, err := k8sClient.GetPodVolumes(ctx, cfg.PodNamespace, cfg.PodName, cfg.ContainerName)
if err != nil {
log.WithError(err).Warn("Failed to get volume types from K8s API, falling back to path-based detection")
return nil
}
log.WithField("volume_count", len(k8sVolumes)).Debug("Got volume types from K8s API")
return k8sVolumes
}
// buildMountMetadata constructs metadata for a single mount.
func buildMountMetadata(mount MountMapping, k8sVolumes map[string]*checkpointk8s.VolumeInfo, ociMountByDest map[string]checkpointk8s.MountInfo) common.MountMetadata {
var volumeType, volumeName string
// Try K8s API first for accurate volume types
if k8sVolumes != nil {
if volInfo, ok := k8sVolumes[mount.InsidePath]; ok {
volumeType = volInfo.VolumeType
volumeName = volInfo.VolumeName
}
}
// Fall back to path-based detection if K8s API didn't provide info
if volumeType == "" {
volumeType, volumeName = checkpointk8s.DetectVolumeTypeFromPath(mount.OutsidePath)
}
mountMeta := common.MountMetadata{
ContainerPath: mount.InsidePath,
HostPath: mount.OutsidePath,
VolumeType: volumeType,
VolumeName: volumeName,
FSType: mount.FSType,
ReadOnly: strings.Contains(mount.Options, "ro"),
}
// Cross-reference with OCI spec mount if available
if ociMount, ok := ociMountByDest[mount.InsidePath]; ok {
mountMeta.OCISource = ociMount.Source
mountMeta.OCIType = ociMount.Type
mountMeta.OCIOptions = ociMount.Options
}
return mountMeta
}
// mounts provides mount parsing from /proc for CRIU checkpoint.
// This is used for runtime mount state that requires /proc inspection.
package checkpoint
import (
"bufio"
"fmt"
"os"
"strings"
)
// MountMapping represents an external mount for CRIU
type MountMapping struct {
InsidePath string // Path inside container (mount point)
OutsidePath string // Path on host (source)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
}
// System mount types that should be filtered out
var systemMountTypes = map[string]bool{
"proc": true,
"sysfs": true,
"devpts": true,
"mqueue": true,
"tmpfs": true, // Note: some tmpfs mounts may need special handling
"cgroup": true,
"cgroup2": true,
"securityfs": true,
"debugfs": true,
"tracefs": true,
"fusectl": true,
"configfs": true,
"devtmpfs": true,
"hugetlbfs": true,
"pstore": true,
"bpf": true,
}
// System mount paths that should always be filtered
var systemMountPaths = map[string]bool{
"/proc": true,
"/sys": true,
"/dev": true,
"/dev/pts": true,
"/dev/shm": true,
"/dev/mqueue": true,
"/run": true,
"/run/secrets": true,
}
// ParseMountInfo parses /proc/<pid>/mountinfo and returns bind mounts
// that need to be handled by CRIU as external mounts
func ParseMountInfo(pid int, hostProc string) ([]MountMapping, error) {
if hostProc == "" {
hostProc = "/proc"
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
var mounts []MountMapping
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
mount, skip := parseMountInfoLine(line)
if skip {
continue
}
mounts = append(mounts, mount)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
}
return mounts, nil
}
// parseMountInfoLine parses a single line from mountinfo
// Returns the mount mapping and whether to skip this mount
//
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
//
// (1) mount ID
// (2) parent ID
// (3) major:minor
// (4) root: root of the mount within the filesystem (host-side path for bind mounts)
// (5) mount point: mount point relative to process's root
// (6) mount options
// (7) optional fields (terminated by single hyphen)
// (8) separator (hyphen)
// (9) filesystem type
// (10) mount source (device)
// (11) super options
func parseMountInfoLine(line string) (MountMapping, bool) {
fields := strings.Fields(line)
if len(fields) < 10 {
return MountMapping{}, true
}
root := fields[3] // Host-side path within the filesystem (important for bind mounts)
mountPoint := fields[4] // Container-side mount point
mountOptions := fields[5]
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return MountMapping{}, true
}
fsType := fields[sepIdx+1]
source := fields[sepIdx+2]
superOptions := ""
if sepIdx+3 < len(fields) {
superOptions = fields[sepIdx+3]
}
// Skip system mount types
if systemMountTypes[fsType] {
return MountMapping{}, true
}
// Skip system mount paths
if systemMountPaths[mountPoint] {
return MountMapping{}, true
}
// Skip /sys and /proc prefixed paths
if strings.HasPrefix(mountPoint, "/sys/") || strings.HasPrefix(mountPoint, "/proc/") {
return MountMapping{}, true
}
// Skip overlay (the root filesystem itself)
if fsType == "overlay" && mountPoint == "/" {
return MountMapping{}, true
}
// For bind mounts, the root field contains the actual host path
// Use root as OutsidePath since it gives us the host-side path for volume mounts
outsidePath := root
if root == "/" {
// If root is /, this isn't a bind mount from a subdirectory
outsidePath = source
}
return MountMapping{
InsidePath: mountPoint,
OutsidePath: outsidePath,
FSType: fsType,
Source: source,
Options: mountOptions + "," + superOptions,
}, false
}
// GetBindMounts returns only bind mounts (type "bind" or with bind option)
func GetBindMounts(pid int, hostProc string) ([]MountMapping, error) {
mounts, err := ParseMountInfo(pid, hostProc)
if err != nil {
return nil, err
}
var bindMounts []MountMapping
for _, m := range mounts {
// Bind mounts typically show the underlying filesystem type
// and have paths that look like kubelet volume paths
if strings.Contains(m.OutsidePath, "/var/lib/kubelet/pods/") ||
strings.Contains(m.OutsidePath, "/volumes/") ||
strings.Contains(m.Options, "bind") {
bindMounts = append(bindMounts, m)
}
}
return bindMounts, nil
}
// GetKubernetesVolumeMounts returns mounts that appear to be Kubernetes volumes
func GetKubernetesVolumeMounts(pid int, hostProc string) ([]MountMapping, error) {
mounts, err := ParseMountInfo(pid, hostProc)
if err != nil {
return nil, err
}
var k8sMounts []MountMapping
for _, m := range mounts {
// Kubernetes volumes are identified by:
// 1. Standard kubelet paths: /var/lib/kubelet/pods/
// 2. Minikube/Docker paths: /var/lib/docker/volumes/minikube/_data/lib/kubelet/pods/
// 3. Kubernetes volume markers: kubernetes.io~empty-dir, kubernetes.io~configmap, etc.
if strings.Contains(m.OutsidePath, "/kubelet/pods/") ||
strings.Contains(m.OutsidePath, "/kubernetes.io~") ||
strings.Contains(m.OutsidePath, "/containerd/io.containerd") {
k8sMounts = append(k8sMounts, m)
}
}
return k8sMounts, nil
}
// AllMountInfo represents a mount entry from /proc/<pid>/mountinfo
// This includes ALL mounts without filtering, which CRIU captures during checkpoint.
type AllMountInfo struct {
MountID string // Mount ID
ParentID string // Parent mount ID
MountPoint string // Mount point inside container (container-side path)
Root string // Root of mount within filesystem (host-side path for bind mounts)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
SuperOptions string // Super block options
}
// GetAllMountsFromMountinfo parses /proc/<pid>/mountinfo and returns ALL mounts.
// This is used for CRIU checkpoint to mark ALL mounts as external, since CRIU
// captures everything from mountinfo, not just the filtered subset.
// Without marking ALL mounts as external, CRIU restore fails with
// "No mapping for <mount_id>:(null) mountpoint" errors.
func GetAllMountsFromMountinfo(pid int, hostProc string) ([]AllMountInfo, error) {
if hostProc == "" {
hostProc = "/proc"
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
var mounts []AllMountInfo
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
mount, err := parseAllMountInfoLine(line)
if err != nil {
continue // Skip malformed lines
}
mounts = append(mounts, mount)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
}
return mounts, nil
}
// parseAllMountInfoLine parses a single line from mountinfo without filtering.
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
func parseAllMountInfoLine(line string) (AllMountInfo, error) {
fields := strings.Fields(line)
if len(fields) < 10 {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line: %s", line)
}
mountID := fields[0]
parentID := fields[1]
root := fields[3] // Host-side path within the filesystem
mountPoint := fields[4] // Container-side mount point
mountOptions := fields[5]
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return AllMountInfo{}, fmt.Errorf("malformed mountinfo line (no separator): %s", line)
}
fsType := fields[sepIdx+1]
source := fields[sepIdx+2]
superOptions := ""
if sepIdx+3 < len(fields) {
superOptions = fields[sepIdx+3]
}
return AllMountInfo{
MountID: mountID,
ParentID: parentID,
MountPoint: mountPoint,
Root: root,
FSType: fsType,
Source: source,
Options: mountOptions,
SuperOptions: superOptions,
}, nil
}
// namespaces provides Linux namespace introspection for CRIU checkpoint.
package checkpoint
import (
"fmt"
"os"
"strings"
"golang.org/x/sys/unix"
)
// NamespaceType represents a Linux namespace type
type NamespaceType string
const (
NamespaceNet NamespaceType = "net"
NamespacePID NamespaceType = "pid"
NamespaceMnt NamespaceType = "mnt"
NamespaceUTS NamespaceType = "uts"
NamespaceIPC NamespaceType = "ipc"
NamespaceUser NamespaceType = "user"
NamespaceCgroup NamespaceType = "cgroup"
)
// NamespaceInfo holds namespace identification information
type NamespaceInfo struct {
Type NamespaceType
Inode uint64
Path string
IsExternal bool // Whether NS is external (shared with pause container)
}
// GetNamespaceInode returns the inode number for a namespace
func GetNamespaceInode(pid int, nsType NamespaceType, hostProc string) (uint64, error) {
if hostProc == "" {
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
var stat unix.Stat_t
if err := unix.Stat(nsPath, &stat); err != nil {
return 0, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
}
return stat.Ino, nil
}
// GetNamespaceInfo returns detailed namespace information
func GetNamespaceInfo(pid int, nsType NamespaceType, hostProc string) (*NamespaceInfo, error) {
if hostProc == "" {
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
// Get inode
var stat unix.Stat_t
if err := unix.Stat(nsPath, &stat); err != nil {
return nil, fmt.Errorf("failed to stat namespace %s: %w", nsPath, err)
}
// Read the symlink to get the namespace identifier
link, err := os.Readlink(nsPath)
if err != nil {
return nil, fmt.Errorf("failed to readlink %s: %w", nsPath, err)
}
// Check if this is different from init's namespace (PID 1)
initNsPath := fmt.Sprintf("%s/1/ns/%s", hostProc, nsType)
var initStat unix.Stat_t
isExternal := false
if err := unix.Stat(initNsPath, &initStat); err == nil {
// If the inode is different from init's, it's an external namespace
isExternal = stat.Ino != initStat.Ino
}
return &NamespaceInfo{
Type: nsType,
Inode: stat.Ino,
Path: link,
IsExternal: isExternal,
}, nil
}
// GetAllNamespaces returns information about all namespaces for a process
func GetAllNamespaces(pid int, hostProc string) (map[NamespaceType]*NamespaceInfo, error) {
nsTypes := []NamespaceType{
NamespaceNet,
NamespacePID,
NamespaceMnt,
NamespaceUTS,
NamespaceIPC,
NamespaceUser,
NamespaceCgroup,
}
namespaces := make(map[NamespaceType]*NamespaceInfo)
for _, nsType := range nsTypes {
info, err := GetNamespaceInfo(pid, nsType, hostProc)
if err != nil {
// Some namespaces might not exist, skip them
continue
}
namespaces[nsType] = info
}
return namespaces, nil
}
// IsNetNamespaceExternal checks if the network namespace is external
// (i.e., shared with the pause container in Kubernetes)
func IsNetNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespaceNet, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// IsPIDNamespaceExternal checks if the PID namespace is external
func IsPIDNamespaceExternal(pid int, hostProc string) (bool, uint64, error) {
info, err := GetNamespaceInfo(pid, NamespacePID, hostProc)
if err != nil {
return false, 0, err
}
return info.IsExternal, info.Inode, nil
}
// OpenNamespaceFD opens a file descriptor to a namespace
// The caller is responsible for closing the returned file
func OpenNamespaceFD(pid int, nsType NamespaceType, hostProc string) (*os.File, error) {
if hostProc == "" {
hostProc = "/proc"
}
nsPath := fmt.Sprintf("%s/%d/ns/%s", hostProc, pid, nsType)
return os.Open(nsPath)
}
// FormatExternalNamespace formats namespace info for CRIU's External option
// Format: <type>[<inode>]:<key>
func FormatExternalNamespace(nsType NamespaceType, inode uint64) string {
key := formatNamespaceKey(nsType)
return fmt.Sprintf("%s[%d]:%s", nsType, inode, key)
}
// formatNamespaceKey creates the CRIU key for external namespaces
// Format: extRoot<Type>NS (e.g., extRootNetNS, extRootPidNS)
func formatNamespaceKey(nsType NamespaceType) string {
// Capitalize first letter of namespace type
nsName := string(nsType)
if len(nsName) > 0 {
nsName = strings.ToUpper(nsName[:1]) + nsName[1:]
}
return "extRoot" + nsName + "NS"
}
// GetNamespaceKey returns the CRIU key for a namespace type
func GetNamespaceKey(nsType NamespaceType) string {
return formatNamespaceKey(nsType)
}
// rootfs provides container rootfs introspection via /proc for CRIU checkpoint.
package checkpoint
import (
"bufio"
"encoding/json"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"github.com/sirupsen/logrus"
"github.com/ai-dynamo/dynamo/deploy/chrek/pkg/common"
)
// GetRootFS returns the container's root filesystem path
// For containers using overlayfs, this extracts the upperdir
func GetRootFS(pid int, hostProc string) (string, error) {
if hostProc == "" {
hostProc = "/proc"
}
// The rootfs is accessible via /proc/<pid>/root
// But for CRIU, we need the actual filesystem path
rootPath := fmt.Sprintf("%s/%d/root", hostProc, pid)
// Verify it exists
if _, err := os.Stat(rootPath); err != nil {
return "", fmt.Errorf("rootfs not accessible at %s: %w", rootPath, err)
}
return rootPath, nil
}
// GetOverlayUpperDir extracts the overlay upperdir from mountinfo
// This is the writable layer of the container's filesystem
func GetOverlayUpperDir(pid int, hostProc string) (string, error) {
if hostProc == "" {
hostProc = "/proc"
}
mountinfoPath := fmt.Sprintf("%s/%d/mountinfo", hostProc, pid)
file, err := os.Open(mountinfoPath)
if err != nil {
return "", fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
fields := strings.Fields(line)
// Look for the root mount (mount point is /)
// mountinfo format: id parent major:minor root mount-point options ... - fstype source super-options
if len(fields) < 5 {
continue
}
mountPoint := fields[4]
if mountPoint != "/" {
continue
}
// Find the separator (-) to get fstype and options
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
continue
}
fsType := fields[sepIdx+1]
if fsType != "overlay" {
continue
}
// Parse super options to find upperdir
superOptions := fields[sepIdx+3]
for _, opt := range strings.Split(superOptions, ",") {
if strings.HasPrefix(opt, "upperdir=") {
return strings.TrimPrefix(opt, "upperdir="), nil
}
}
}
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("error reading mountinfo: %w", err)
}
return "", fmt.Errorf("overlay upperdir not found for pid %d", pid)
}
// DefaultRootfsDiffExclusions are paths excluded from the rootfs diff capture.
// These directories are injected/bind-mounted by NVIDIA GPU Operator at container
// start time, so they already exist in the restore target and cause conflicts
// (especially socket files which cannot be overwritten).
var DefaultRootfsDiffExclusions = []string{
// NVIDIA GPU Operator injects drivers, libraries, and config here
"./usr",
"./etc",
"./opt",
"./var",
// NVIDIA GPU Operator creates runtime sockets and firmware mounts here
// Socket files cause fatal tar errors even with --keep-old-files
"./run",
}
// CaptureRootfsDiff captures the overlay upperdir to a tar file.
// The upperdir contains all filesystem modifications made by the container.
// Excludes bind mount destinations and system directories to avoid conflicts during restore.
// Returns the path to the tar file or empty string if capture failed.
func CaptureRootfsDiff(upperDir, checkpointDir string, excludePaths []string) (string, error) {
if upperDir == "" {
return "", fmt.Errorf("upperdir is empty")
}
rootfsDiffPath := filepath.Join(checkpointDir, "rootfs-diff.tar")
// Build tar arguments with xattrs and exclusions
tarArgs := []string{"--xattrs"}
// Add default exclusions for system directories and caches
for _, excl := range DefaultRootfsDiffExclusions {
tarArgs = append(tarArgs, "--exclude="+excl)
}
// Add bind mount exclusions passed from caller
for _, dest := range excludePaths {
// Convert absolute path to relative for tar (e.g., /etc/hosts -> ./etc/hosts)
tarArgs = append(tarArgs, "--exclude=."+dest)
}
tarArgs = append(tarArgs, "-C", upperDir, "-cf", rootfsDiffPath, ".")
cmd := exec.Command("tar", tarArgs...)
output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("tar failed: %w (output: %s)", err, string(output))
}
return rootfsDiffPath, nil
}
// CaptureDeletedFiles finds whiteout files and saves them to a JSON file.
// Returns true if deleted files were found and saved.
func CaptureDeletedFiles(upperDir, checkpointDir string) (bool, error) {
if upperDir == "" {
return false, nil
}
whiteouts, err := FindWhiteoutFiles(upperDir)
if err != nil {
return false, fmt.Errorf("failed to find whiteout files: %w", err)
}
if len(whiteouts) == 0 {
return false, nil
}
deletedFilesPath := filepath.Join(checkpointDir, "deleted-files.json")
data, err := json.Marshal(whiteouts)
if err != nil {
return false, fmt.Errorf("failed to marshal whiteouts: %w", err)
}
if err := os.WriteFile(deletedFilesPath, data, 0644); err != nil {
return false, fmt.Errorf("failed to write deleted files: %w", err)
}
return true, nil
}
// FindWhiteoutFiles finds overlay whiteout files in the upperdir.
// Overlay filesystems use .wh.<filename> to mark deleted files.
// Returns a list of paths that were deleted in the container.
func FindWhiteoutFiles(upperDir string) ([]string, error) {
var whiteouts []string
err := filepath.Walk(upperDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
name := info.Name()
if strings.HasPrefix(name, ".wh.") {
// Convert whiteout marker to actual deleted path
relPath, _ := filepath.Rel(upperDir, path)
dir := filepath.Dir(relPath)
deletedFile := strings.TrimPrefix(name, ".wh.")
if dir == "." {
whiteouts = append(whiteouts, deletedFile)
} else {
whiteouts = append(whiteouts, filepath.Join(dir, deletedFile))
}
}
return nil
})
return whiteouts, err
}
// CaptureRootfsState captures the overlay upperdir and deleted files after CRIU dump.
// Updates the metadata with rootfs diff information and saves it.
func CaptureRootfsState(upperDir, checkpointDir string, meta *common.CheckpointMetadata, log *logrus.Entry) {
if upperDir == "" {
return
}
// Capture rootfs diff
log.WithFields(logrus.Fields{
"default_exclusions": DefaultRootfsDiffExclusions,
"bind_mount_exclusions": meta.BindMountDests,
}).Debug("Rootfs diff exclusions")
rootfsDiffPath, err := CaptureRootfsDiff(upperDir, checkpointDir, meta.BindMountDests)
if err != nil {
log.WithError(err).Warn("Failed to capture rootfs diff")
} else {
meta.RootfsDiffPath = rootfsDiffPath
meta.HasRootfsDiff = true
log.WithFields(logrus.Fields{
"upperdir": upperDir,
"tar_path": rootfsDiffPath,
}).Info("Captured rootfs diff")
}
// Capture deleted files (whiteouts)
hasDeletedFiles, err := CaptureDeletedFiles(upperDir, checkpointDir)
if err != nil {
log.WithError(err).Warn("Failed to capture deleted files")
} else if hasDeletedFiles {
meta.HasDeletedFiles = true
log.Info("Recorded deleted files (whiteouts)")
}
// Update metadata with rootfs diff info
if err := common.SaveMetadata(checkpointDir, meta); err != nil {
log.WithError(err).Warn("Failed to update metadata with rootfs diff info")
}
}
// criu.go provides shared CRIU utilities used by both checkpoint and restore.
package common
import (
"bufio"
"fmt"
"os"
"strings"
"golang.org/x/sys/unix"
)
// OpenDirForCRIU opens a directory and clears the CLOEXEC flag so the FD
// can be inherited by CRIU child processes.
// Returns the opened file and its FD. Caller must close the file when done.
func OpenDirForCRIU(path string) (*os.File, int32, error) {
dir, err := os.Open(path)
if err != nil {
return nil, 0, fmt.Errorf("failed to open %s: %w", path, err)
}
// Clear CLOEXEC so the FD is inherited by CRIU child process.
// Go's os.Open() sets O_CLOEXEC by default, but go-criu's swrk mode
// requires the FD to be inherited.
if _, err := unix.FcntlInt(dir.Fd(), unix.F_SETFD, 0); err != nil {
dir.Close()
return nil, 0, fmt.Errorf("failed to clear CLOEXEC on %s: %w", path, err)
}
return dir, int32(dir.Fd()), nil
}
// DefaultMaskedPaths returns the standard OCI masked paths.
// These paths are typically masked (made inaccessible) in containers.
// Used as fallback when checkpoint metadata doesn't include OCI-derived paths.
func DefaultMaskedPaths() []string {
return []string{
"/proc/bus",
"/proc/fs",
"/proc/irq",
"/proc/sys",
"/proc/sysrq-trigger",
"/proc/acpi",
"/proc/kcore",
"/proc/keys",
"/proc/latency_stats",
"/proc/timer_list",
"/proc/scsi",
"/proc/interrupts",
"/proc/asound",
"/sys/firmware",
"/sys/devices/virtual/powercap",
}
}
// DefaultReadonlyPaths returns the standard OCI readonly paths.
// These paths are typically mounted read-only in containers.
func DefaultReadonlyPaths() []string {
return []string{
"/proc/bus",
"/proc/fs",
"/proc/irq",
"/proc/sys",
"/proc/sysrq-trigger",
}
}
// CRIUMountPoint represents a parsed mount point from /proc/pid/mountinfo.
type CRIUMountPoint struct {
MountID string // Mount ID
ParentID string // Parent mount ID
Path string // Mount point path (container-side)
Root string // Root within filesystem (host-side for bind mounts)
FSType string // Filesystem type
Source string // Mount source
Options string // Mount options
SuperOpts string // Super block options
}
// ParseMountInfoFile parses a mountinfo file and returns all mount points.
// This is used by both checkpoint (to mark mounts as external) and
// restore (to generate external mount mappings).
func ParseMountInfoFile(path string) ([]CRIUMountPoint, error) {
file, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("failed to open mountinfo: %w", err)
}
defer file.Close()
var mounts []CRIUMountPoint
scanner := bufio.NewScanner(file)
for scanner.Scan() {
mount, err := parseCRIUMountInfoLine(scanner.Text())
if err != nil {
continue // Skip malformed lines
}
mounts = append(mounts, mount)
}
if err := scanner.Err(); err != nil {
return nil, fmt.Errorf("error reading mountinfo: %w", err)
}
return mounts, nil
}
// GetMountPointPaths returns just the mount point paths from a mountinfo file.
// This is a convenience function when you only need the paths.
func GetMountPointPaths(path string) ([]string, error) {
mounts, err := ParseMountInfoFile(path)
if err != nil {
return nil, err
}
paths := make([]string, 0, len(mounts))
for _, m := range mounts {
paths = append(paths, m.Path)
}
return paths, nil
}
// parseCRIUMountInfoLine parses a single line from mountinfo.
// mountinfo format:
// 36 35 98:0 /mnt1 /mnt2 rw,noatime master:1 - ext3 /dev/root rw,errors=continue
// (1)(2)(3) (4) (5) (6) (7) (8) (9) (10) (11)
func parseCRIUMountInfoLine(line string) (CRIUMountPoint, error) {
fields := strings.Fields(line)
if len(fields) < 10 {
return CRIUMountPoint{}, fmt.Errorf("malformed mountinfo line")
}
// Find separator (-) to get fstype and source
sepIdx := -1
for i, f := range fields {
if f == "-" {
sepIdx = i
break
}
}
if sepIdx == -1 || sepIdx+2 >= len(fields) {
return CRIUMountPoint{}, fmt.Errorf("malformed mountinfo line (no separator)")
}
superOpts := ""
if sepIdx+3 < len(fields) {
superOpts = fields[sepIdx+3]
}
return CRIUMountPoint{
MountID: fields[0],
ParentID: fields[1],
Root: fields[3],
Path: fields[4],
Options: fields[5],
FSType: fields[sepIdx+1],
Source: fields[sepIdx+2],
SuperOpts: superOpts,
}, nil
}
// metadata.go handles checkpoint metadata for cross-node restore operations.
package common
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
)
const (
// MetadataFilename is the name of the metadata file in checkpoint directories
MetadataFilename = "metadata.json"
// DescriptorsFilename is the name of the file descriptors file
DescriptorsFilename = "descriptors.json"
)
// CheckpointMetadata stores information needed for cross-node restore
type CheckpointMetadata struct {
// Checkpoint identification
CheckpointID string `json:"checkpoint_id"`
CreatedAt time.Time `json:"created_at"`
// Source information
SourceNode string `json:"source_node"`
SourcePodIP string `json:"source_pod_ip,omitempty"` // For cross-node TCP detection
ContainerID string `json:"container_id"`
PodName string `json:"pod_name"`
PodNamespace string `json:"pod_namespace"`
Image string `json:"image"`
// Process information
PID int `json:"pid"`
// Filesystem information
RootfsDiffPath string `json:"rootfs_diff_path,omitempty"` // Path to rootfs-diff.tar
UpperDir string `json:"upper_dir,omitempty"` // Original overlay upperdir
HasRootfsDiff bool `json:"has_rootfs_diff"` // Whether rootfs diff was captured
HasDeletedFiles bool `json:"has_deleted_files"` // Whether deleted files were tracked
// Mount mappings from original container
Mounts []MountMetadata `json:"mounts"`
// OCI spec derived paths (populated from containerd, used at restore)
// These replace hardcoded values with runtime-discovered configuration
MaskedPaths []string `json:"masked_paths,omitempty"` // From OCI spec Linux.MaskedPaths
ReadonlyPaths []string `json:"readonly_paths,omitempty"` // From OCI spec Linux.ReadonlyPaths
BindMountDests []string `json:"bind_mount_dests,omitempty"` // Destinations of bind mounts (for tar exclusions)
// Namespace information
Namespaces []NamespaceMetadata `json:"namespaces"`
// CRIU options used during checkpoint (for restore compatibility)
CRIUOptions CRIUOptionsMetadata `json:"criu_options"`
}
// CRIUOptionsMetadata stores CRIU options used during checkpoint.
// This allows restore to use compatible options.
// Note: In our implementation, most options are hardcoded as always-on for K8s,
// but we store them for compatibility and debugging purposes.
type CRIUOptionsMetadata struct {
TcpEstablished bool `json:"tcp_established"`
TcpClose bool `json:"tcp_close"`
ShellJob bool `json:"shell_job"`
FileLocks bool `json:"file_locks"`
LeaveRunning bool `json:"leave_running"`
LinkRemap bool `json:"link_remap"`
ExtMasters bool `json:"ext_masters"`
}
// MountMetadata stores information about a mount for remapping during restore
type MountMetadata struct {
ContainerPath string `json:"container_path"` // Path inside container (e.g., /usr/share/nginx/html)
HostPath string `json:"host_path"` // Original host path from mountinfo
OCISource string `json:"oci_source,omitempty"` // Source path from OCI spec (may differ from HostPath)
OCIType string `json:"oci_type,omitempty"` // Mount type from OCI spec (bind, tmpfs, etc.)
OCIOptions []string `json:"oci_options,omitempty"` // Mount options from OCI spec
VolumeType string `json:"volume_type"` // emptyDir, pvc, configMap, secret, hostPath (best-effort)
VolumeName string `json:"volume_name"` // Kubernetes volume name (best-effort from path parsing)
FSType string `json:"fs_type"` // Filesystem type from mountinfo
ReadOnly bool `json:"read_only"` // Whether mount is read-only
}
// NamespaceMetadata stores namespace information
type NamespaceMetadata struct {
Type string `json:"type"` // net, pid, mnt, etc.
Inode uint64 `json:"inode"` // Namespace inode
IsExternal bool `json:"is_external"` // Whether namespace is external (shared)
}
// NewCheckpointMetadata creates a new metadata instance
func NewCheckpointMetadata(checkpointID string) *CheckpointMetadata {
return &CheckpointMetadata{
CheckpointID: checkpointID,
CreatedAt: time.Now().UTC(),
Mounts: make([]MountMetadata, 0),
Namespaces: make([]NamespaceMetadata, 0),
}
}
// SaveMetadata writes metadata to a JSON file in the checkpoint directory
func SaveMetadata(checkpointDir string, meta *CheckpointMetadata) error {
data, err := json.MarshalIndent(meta, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal metadata: %w", err)
}
metadataPath := filepath.Join(checkpointDir, MetadataFilename)
if err := os.WriteFile(metadataPath, data, 0644); err != nil {
return fmt.Errorf("failed to write metadata file: %w", err)
}
return nil
}
// LoadMetadata reads metadata from a checkpoint directory
func LoadMetadata(checkpointDir string) (*CheckpointMetadata, error) {
metadataPath := filepath.Join(checkpointDir, MetadataFilename)
data, err := os.ReadFile(metadataPath)
if err != nil {
return nil, fmt.Errorf("failed to read metadata file: %w", err)
}
var meta CheckpointMetadata
if err := json.Unmarshal(data, &meta); err != nil {
return nil, fmt.Errorf("failed to unmarshal metadata: %w", err)
}
return &meta, nil
}
// SaveDescriptors writes file descriptor information to the checkpoint directory
func SaveDescriptors(checkpointDir string, descriptors []string) error {
data, err := json.Marshal(descriptors)
if err != nil {
return fmt.Errorf("failed to marshal descriptors: %w", err)
}
descriptorsPath := filepath.Join(checkpointDir, DescriptorsFilename)
if err := os.WriteFile(descriptorsPath, data, 0600); err != nil {
return fmt.Errorf("failed to write descriptors file: %w", err)
}
return nil
}
// LoadDescriptors reads file descriptor information from checkpoint directory
func LoadDescriptors(checkpointDir string) ([]string, error) {
descriptorsPath := filepath.Join(checkpointDir, DescriptorsFilename)
data, err := os.ReadFile(descriptorsPath)
if err != nil {
return nil, fmt.Errorf("failed to read descriptors file: %w", err)
}
var descriptors []string
if err := json.Unmarshal(data, &descriptors); err != nil {
return nil, fmt.Errorf("failed to unmarshal descriptors: %w", err)
}
return descriptors, nil
}
// GetCheckpointDir returns the path to a checkpoint directory
func GetCheckpointDir(baseDir, checkpointID string) string {
return filepath.Join(baseDir, checkpointID)
}
// ListCheckpoints returns all checkpoint IDs in the base directory
func ListCheckpoints(baseDir string) ([]string, error) {
entries, err := os.ReadDir(baseDir)
if err != nil {
return nil, fmt.Errorf("failed to read checkpoint directory: %w", err)
}
var checkpoints []string
for _, entry := range entries {
if !entry.IsDir() {
continue
}
// Check if metadata file exists
metadataPath := filepath.Join(baseDir, entry.Name(), MetadataFilename)
if _, err := os.Stat(metadataPath); err == nil {
checkpoints = append(checkpoints, entry.Name())
}
}
return checkpoints, nil
}
// GetCheckpointInfo returns metadata for a specific checkpoint
func GetCheckpointInfo(baseDir, checkpointID string) (*CheckpointMetadata, error) {
checkpointDir := GetCheckpointDir(baseDir, checkpointID)
return LoadMetadata(checkpointDir)
}
// DeleteCheckpoint removes a checkpoint directory
func DeleteCheckpoint(baseDir, checkpointID string) error {
checkpointDir := GetCheckpointDir(baseDir, checkpointID)
return os.RemoveAll(checkpointDir)
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment