Unverified Commit a0ecd4dd authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix(snapshot): use watch-based waits in snapshotctl (#8024)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
parent e319cd2b
......@@ -17,13 +17,13 @@ from dynamo.common.utils.namespace import get_worker_namespace
logger = logging.getLogger(__name__)
PODINFO_ROOT = "/etc/podinfo"
REQUIRED_PODINFO_FILES = {
KUBERNETES_REQUIRED_PODINFO_FILES = {
"DYN_NAMESPACE": "dyn_namespace",
"DYN_COMPONENT": "dyn_component",
"DYN_PARENT_DGD_K8S_NAME": "dyn_parent_dgd_k8s_name",
"DYN_PARENT_DGD_K8S_NAMESPACE": "dyn_parent_dgd_k8s_namespace",
}
OPTIONAL_PODINFO_FILES = {
KUBERNETES_OPTIONAL_PODINFO_FILES = {
"DYN_NAMESPACE_WORKER_SUFFIX": "dyn_namespace_worker_suffix",
}
EngineT = TypeVar("EngineT")
......@@ -203,12 +203,29 @@ class EngineSnapshotController(Generic[EngineT]):
*self.quiesce_args,
)
def reload_restore_identity(self) -> tuple[str, str]:
return reload_snapshot_restore_identity()
def reload_restore_identity(
self,
namespace: str,
discovery_backend: str,
) -> tuple[str, str]:
return reload_snapshot_restore_identity(namespace, discovery_backend)
def reload_snapshot_restore_identity(
namespace: str,
discovery_backend: str,
) -> tuple[str, str]:
if discovery_backend != "kubernetes":
logger.info(
"Snapshot restore reusing configured discovery backend",
extra={
"dynamo_namespace": namespace,
"discovery_backend": discovery_backend,
},
)
return namespace, discovery_backend
def reload_snapshot_restore_identity() -> tuple[str, str]:
for env_name, podinfo_file in REQUIRED_PODINFO_FILES.items():
for env_name, podinfo_file in KUBERNETES_REQUIRED_PODINFO_FILES.items():
podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file)
if not os.path.isfile(podinfo_path):
raise RuntimeError(f"snapshot restore requires {podinfo_path}")
......@@ -220,7 +237,7 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
os.environ[env_name] = value
for env_name, podinfo_file in OPTIONAL_PODINFO_FILES.items():
for env_name, podinfo_file in KUBERNETES_OPTIONAL_PODINFO_FILES.items():
podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file)
if not os.path.isfile(podinfo_path):
os.environ.pop(env_name, None)
......@@ -234,7 +251,6 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
os.environ[env_name] = value
# Snapshot restore only runs in Kubernetes-managed pods, so discovery resets here.
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes"
......
......@@ -50,7 +50,10 @@ async def worker():
(
dynamo_args.namespace,
dynamo_args.discovery_backend,
) = snapshot_controller.reload_restore_identity()
) = snapshot_controller.reload_restore_identity(
dynamo_args.namespace,
dynamo_args.discovery_backend,
)
shutdown_event = asyncio.Event()
shutdown_endpoints: list = []
......
......@@ -146,7 +146,10 @@ async def worker() -> None:
(
config.namespace,
config.discovery_backend,
) = snapshot_controller.reload_restore_identity()
) = snapshot_controller.reload_restore_identity(
config.namespace,
config.discovery_backend,
)
# HEADLESS MODE: bypass DistributedRuntime entirely.
# Workers run vLLM only (no NATS, etcd, or dynamo endpoints).
......
......@@ -24,7 +24,7 @@ data:
restore:
nsRestorePath: {{ .Values.config.restore.nsRestorePath | quote }}
restoreReadyTimeoutSeconds: {{ .Values.config.restore.restoreReadyTimeoutSeconds }}
restoreTimeoutSeconds: {{ .Values.config.restore.restoreTimeoutSeconds }}
criu:
binaryPath: {{ .Values.config.criu.binaryPath | quote }}
......
......@@ -137,8 +137,8 @@ config:
restore:
# Path to the nsrestore binary in the placeholder image
nsRestorePath: /usr/local/bin/nsrestore
# Maximum seconds to wait for a restored pod to become Ready (0 = no timeout)
restoreReadyTimeoutSeconds: 0
# Maximum seconds to allow a restore attempt before snapshot-agent marks it failed
restoreTimeoutSeconds: 7200
criu:
# Path to the criu binary
......
......@@ -10,6 +10,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1"
......@@ -67,6 +68,9 @@ func buildCheckpointJob(
if podTemplate.Annotations == nil {
podTemplate.Annotations = make(map[string]string)
}
if podTemplate.Spec.ServiceAccountName == "" {
podTemplate.Spec.ServiceAccountName = discovery.GetK8sDiscoveryServiceAccountName(ckpt.Name)
}
checkpoint.EnsurePodInfoVolume(&podTemplate.Spec)
......
......@@ -25,6 +25,7 @@ import (
batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
......@@ -40,6 +41,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
)
......@@ -175,6 +177,11 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
logger := log.FromContext(ctx)
if err := r.reconcileK8sDiscoveryResources(ctx, ckpt); err != nil {
logger.Error(err, "Failed to reconcile K8s discovery resources for checkpoint")
return ctrl.Result{}, fmt.Errorf("failed to reconcile K8s discovery resources for checkpoint: %w", err)
}
hash := ckpt.Status.IdentityHash
if hash == "" {
var err error
......@@ -223,6 +230,47 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
return ctrl.Result{}, nil
}
// reconcileK8sDiscoveryResources syncs the ServiceAccount, Role, and
// RoleBinding that the discovery package builds for the checkpoint's name and
// namespace. The resources are synced in that order through
// commonController.SyncResource with ckpt passed as the parent object, and the
// first failure aborts the remaining steps.
//
// On failure the returned error wraps the underlying error and names the
// resource that was being synced. The error is intentionally not logged here:
// handlePending already logs the error it receives from this function, so
// logging it again at this layer would only duplicate the entry.
func (r *CheckpointReconciler) reconcileK8sDiscoveryResources(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (err error) {
	// resourceName tracks the step in progress so the deferred wrapper can
	// report which resource failed.
	resourceName := ""
	defer func() {
		if err != nil {
			err = fmt.Errorf("failed to sync checkpoint k8s discovery %s: %w", resourceName, err)
		}
	}()

	resourceName = "service account"
	serviceAccount := discovery.GetK8sDiscoveryServiceAccount(ckpt.Name, ckpt.Namespace)
	// NOTE(review): the `false` returned by each generator is presumably the
	// "delete this resource" flag of SyncResource — confirm against its signature.
	_, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*corev1.ServiceAccount, bool, error) {
		return serviceAccount, false, nil
	})
	if err != nil {
		return err
	}

	resourceName = "role"
	role := discovery.GetK8sDiscoveryRole(ckpt.Name, ckpt.Namespace)
	_, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*rbacv1.Role, bool, error) {
		return role, false, nil
	})
	if err != nil {
		return err
	}

	resourceName = "role binding"
	roleBinding := discovery.GetK8sDiscoveryRoleBinding(ckpt.Name, ckpt.Namespace)
	_, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*rbacv1.RoleBinding, bool, error) {
		return roleBinding, false, nil
	})
	return err
}
func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
logger := log.FromContext(ctx)
......
......@@ -32,6 +32,7 @@ import (
batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
......@@ -67,6 +68,7 @@ func checkpointTestScheme() *runtime.Scheme {
_ = corev1.AddToScheme(s)
_ = batchv1.AddToScheme(s)
_ = coordinationv1.AddToScheme(s)
_ = rbacv1.AddToScheme(s)
return s
}
......
......@@ -2,6 +2,7 @@ package main
import (
"context"
"errors"
"fmt"
"strings"
"time"
......@@ -10,7 +11,8 @@ import (
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
......@@ -108,69 +110,87 @@ func runCheckpointFlow(ctx context.Context, opts checkpointOptions) (*result, er
func waitForCheckpoint(ctx context.Context, clientset kubernetes.Interface, namespace string, jobName string) (string, error) {
var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) {
job, err := clientset.BatchV1().Jobs(namespace).Get(ctx, jobName, metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
return false, nil
err := watchNamedObject(
ctx,
jobName,
&batchv1.Job{},
func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
return clientset.BatchV1().Jobs(namespace).List(ctx, options)
},
func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
return clientset.BatchV1().Jobs(namespace).Watch(ctx, options)
},
func(event watch.Event) (bool, error) {
if event.Type == watch.Error {
return false, apierrors.FromObject(event.Object)
}
return false, fmt.Errorf("get checkpoint job %s/%s: %w", namespace, jobName, err)
}
status = strings.TrimSpace(job.Annotations[snapshotprotocol.CheckpointStatusAnnotation])
if status == snapshotprotocol.CheckpointStatusCompleted {
return true, nil
}
if status == snapshotprotocol.CheckpointStatusFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
if job.Status.Failed > 0 {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
job, ok := event.Object.(*batchv1.Job)
if !ok {
return false, fmt.Errorf("unexpected checkpoint watch object %T", event.Object)
}
if condition.Type == batchv1.JobFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed: %s", namespace, jobName, strings.TrimSpace(condition.Message))
}
}
return false, nil
}); err != nil {
if !wait.Interrupted(err) {
return "", err
}
summaryCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
pods, err := clientset.CoreV1().Pods(namespace).List(summaryCtx, metav1.ListOptions{
LabelSelector: "batch.kubernetes.io/job-name=" + jobName,
})
summary := "no checkpoint pod created yet"
if err != nil {
summary = "unable to list checkpoint pod: " + err.Error()
} else if len(pods.Items) != 0 {
pod := pods.Items[0]
parts := []string{
fmt.Sprintf("job_status=%q", status),
fmt.Sprintf("pod=%s phase=%s", pod.Name, pod.Status.Phase),
status = strings.TrimSpace(job.Annotations[snapshotprotocol.CheckpointStatusAnnotation])
if status == snapshotprotocol.CheckpointStatusCompleted {
return true, nil
}
for _, condition := range pod.Status.Conditions {
if condition.Status == corev1.ConditionTrue || condition.Status == corev1.ConditionFalse {
parts = append(parts, fmt.Sprintf("%s=%s", condition.Type, condition.Status))
}
if status == snapshotprotocol.CheckpointStatusFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
if job.Status.Failed > 0 {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
for _, status := range pod.Status.ContainerStatuses {
if status.State.Waiting != nil {
parts = append(parts, fmt.Sprintf("container=%s waiting=%s", status.Name, status.State.Waiting.Reason))
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
}
if status.State.Terminated != nil {
parts = append(parts, fmt.Sprintf("container=%s terminated=%s", status.Name, status.State.Terminated.Reason))
if condition.Type == batchv1.JobFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed: %s", namespace, jobName, strings.TrimSpace(condition.Message))
}
}
summary = strings.Join(parts, " ")
return false, nil
},
)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
return "", err
}
return "", fmt.Errorf("checkpoint job %s/%s timed out: %s", namespace, jobName, summary)
return "", fmt.Errorf("checkpoint job %s/%s timed out: %s", namespace, jobName, checkpointTimeoutSummary(clientset, namespace, jobName, status))
}
return status, nil
}
// checkpointTimeoutSummary builds a one-line diagnostic describing the pod that
// belongs to the checkpoint job, for inclusion in a timeout error message.
//
// It lists pods labelled batch.kubernetes.io/job-name=<jobName> using a fresh
// 5-second context derived from context.Background() rather than a caller
// context, and summarizes only the first pod returned: the supplied job
// status, the pod name and phase, every condition whose status is True or
// False, and the waiting/terminated reason of each container status.
// If the list call fails or returns no pods, a short explanatory string is
// returned instead.
func checkpointTimeoutSummary(clientset kubernetes.Interface, namespace string, jobName string, status string) string {
	listCtx, stop := context.WithTimeout(context.Background(), 5*time.Second)
	defer stop()

	listOpts := metav1.ListOptions{LabelSelector: "batch.kubernetes.io/job-name=" + jobName}
	podList, listErr := clientset.CoreV1().Pods(namespace).List(listCtx, listOpts)
	switch {
	case listErr != nil:
		return "unable to list checkpoint pod: " + listErr.Error()
	case len(podList.Items) == 0:
		return "no checkpoint pod created yet"
	}

	first := podList.Items[0]
	summary := []string{
		fmt.Sprintf("job_status=%q", status),
		fmt.Sprintf("pod=%s phase=%s", first.Name, first.Status.Phase),
	}

	// Conditions with an Unknown (or empty) status are left out of the summary.
	for _, cond := range first.Status.Conditions {
		switch cond.Status {
		case corev1.ConditionTrue, corev1.ConditionFalse:
			summary = append(summary, fmt.Sprintf("%s=%s", cond.Type, cond.Status))
		}
	}

	for _, cs := range first.Status.ContainerStatuses {
		if waiting := cs.State.Waiting; waiting != nil {
			summary = append(summary, fmt.Sprintf("container=%s waiting=%s", cs.Name, waiting.Reason))
		}
		if terminated := cs.State.Terminated; terminated != nil {
			summary = append(summary, fmt.Sprintf("container=%s terminated=%s", cs.Name, terminated.Reason))
		}
	}

	return strings.Join(summary, " ")
}
......@@ -3,6 +3,7 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
......@@ -10,8 +11,9 @@ import (
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
......@@ -155,28 +157,79 @@ func runRestoreFlow(ctx context.Context, opts restoreOptions) (*result, error) {
func waitForRestore(ctx context.Context, clientset kubernetes.Interface, namespace string, podName string) (string, error) {
var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) {
pod, err := clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{})
if err != nil {
return false, fmt.Errorf("get restore pod %s/%s: %w", namespace, podName, err)
err := watchNamedObject(
ctx,
podName,
&corev1.Pod{},
func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
return clientset.CoreV1().Pods(namespace).List(ctx, options)
},
func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
return clientset.CoreV1().Pods(namespace).Watch(ctx, options)
},
func(event watch.Event) (bool, error) {
if event.Type == watch.Error {
return false, apierrors.FromObject(event.Object)
}
pod, ok := event.Object.(*corev1.Pod)
if !ok {
return false, fmt.Errorf("unexpected restore watch object %T", event.Object)
}
status = strings.TrimSpace(pod.Annotations[snapshotprotocol.RestoreStatusAnnotation])
if status == snapshotprotocol.RestoreStatusCompleted {
return true, nil
}
if status == snapshotprotocol.RestoreStatusFailed {
return false, fmt.Errorf("restore pod %s/%s failed", namespace, podName)
}
if pod.Status.Phase == corev1.PodFailed {
return false, fmt.Errorf("restore pod %s/%s entered phase Failed (%s)", namespace, podName, pod.Status.Reason)
}
return false, nil
},
)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
return "", err
}
return "", fmt.Errorf("restore pod %s/%s timed out: %s", namespace, podName, restoreTimeoutSummary(clientset, namespace, podName, status))
}
return status, nil
}
func restoreTimeoutSummary(clientset kubernetes.Interface, namespace string, podName string, status string) string {
summaryCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
status = strings.TrimSpace(pod.Annotations[snapshotprotocol.RestoreStatusAnnotation])
if status == snapshotprotocol.RestoreStatusCompleted {
return true, nil
pod, err := clientset.CoreV1().Pods(namespace).Get(summaryCtx, podName, metav1.GetOptions{})
if err != nil {
if apierrors.IsNotFound(err) {
return fmt.Sprintf("restore_status=%q pod not found", status)
}
if status == snapshotprotocol.RestoreStatusFailed {
return false, fmt.Errorf("restore pod %s/%s failed", namespace, podName)
return "unable to get restore pod: " + err.Error()
}
parts := []string{
fmt.Sprintf("restore_status=%q", status),
fmt.Sprintf("pod=%s phase=%s", pod.Name, pod.Status.Phase),
}
if reason := strings.TrimSpace(pod.Status.Reason); reason != "" {
parts = append(parts, fmt.Sprintf("reason=%s", reason))
}
for _, condition := range pod.Status.Conditions {
if condition.Status == corev1.ConditionTrue || condition.Status == corev1.ConditionFalse {
parts = append(parts, fmt.Sprintf("%s=%s", condition.Type, condition.Status))
}
if pod.Status.Phase == corev1.PodFailed {
return false, fmt.Errorf("restore pod %s/%s entered phase Failed (%s)", namespace, podName, pod.Status.Reason)
}
for _, containerStatus := range pod.Status.ContainerStatuses {
if containerStatus.State.Waiting != nil {
parts = append(parts, fmt.Sprintf("container=%s waiting=%s", containerStatus.Name, containerStatus.State.Waiting.Reason))
}
return false, nil
}); err != nil {
if !wait.Interrupted(err) {
return "", err
if containerStatus.State.Terminated != nil {
parts = append(parts, fmt.Sprintf("container=%s terminated=%s", containerStatus.Name, containerStatus.State.Terminated.Reason))
}
return "", fmt.Errorf("restore pod %s/%s timed out: status=%q", namespace, podName, status)
}
return status, nil
return strings.Join(parts, " ")
}
package main
import (
"context"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/tools/cache"
watchtools "k8s.io/client-go/tools/watch"
)
// watchNamedObject blocks until condition reports completion for the object
// called name, or until ctx is done.
//
// listFn and watchFn are wrapped in a cache.ListWatch that forces a
// metadata.name=<name> field selector onto every request, and the result is
// handed to watchtools.UntilWithSync together with objType and condition.
//
// Return values:
//   - nil when condition returned true. A context that happens to expire at
//     the same moment does not turn that success into a timeout.
//   - ctx.Err() when the wait failed and ctx is already done, so callers can
//     test the result with errors.Is(err, context.DeadlineExceeded).
//   - otherwise, the error reported by UntilWithSync (including any error
//     returned by condition).
func watchNamedObject(
	ctx context.Context,
	name string,
	objType runtime.Object,
	listFn func(context.Context, metav1.ListOptions) (runtime.Object, error),
	watchFn func(context.Context, metav1.ListOptions) (watch.Interface, error),
	condition func(watch.Event) (bool, error),
) error {
	fieldSelector := fields.OneTermEqualSelector("metadata.name", name).String()
	lw := &cache.ListWatch{
		ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
			options.FieldSelector = fieldSelector
			return listFn(ctx, options)
		},
		WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
			options.FieldSelector = fieldSelector
			return watchFn(ctx, options)
		},
	}

	_, err := watchtools.UntilWithSync(ctx, lw, objType, nil, condition)
	if err == nil {
		// The condition was met; report success even if ctx expired meanwhile.
		return nil
	}
	// Normalize the failure to the context's own error when ctx is done.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}
	return err
}
......@@ -291,7 +291,7 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
annotationStatus := pod.Annotations[snapshotprotocol.RestoreStatusAnnotation]
annotationContainerID := pod.Annotations[snapshotprotocol.RestoreContainerIDAnnotation]
if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusInProgress) {
if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusFailed) {
return
}
......@@ -468,6 +468,12 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
w.release(restoreAttemptKey)
}
}()
restoreCtx := ctx
if timeout := w.config.Restore.RestoreTimeout(); timeout > 0 {
var cancel context.CancelFunc
restoreCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID, "container_id", containerID)
setRestoreStatus := func(value string) error {
......@@ -476,7 +482,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
snapshotprotocol.RestoreContainerIDAnnotation: containerID,
}
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
if value == snapshotprotocol.RestoreStatusCompleted {
if value == snapshotprotocol.RestoreStatusCompleted || value == snapshotprotocol.RestoreStatusFailed {
releaseOnExit = false
return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
}
......@@ -485,10 +491,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
return nil
}
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{
snapshotprotocol.RestoreStatusAnnotation: snapshotprotocol.RestoreStatusInProgress,
snapshotprotocol.RestoreContainerIDAnnotation: containerID,
}); err != nil {
if err := setRestoreStatus(snapshotprotocol.RestoreStatusInProgress); err != nil {
return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
}
......@@ -503,17 +506,18 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
ContainerName: containerName,
Clientset: w.clientset,
}
restoredPID, err := executor.Restore(ctx, w.containerd, log, req)
restoredPID, err := executor.Restore(restoreCtx, w.containerd, log, req)
if err != nil {
log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if pidErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr)
}
return nil
......@@ -524,33 +528,33 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
releaseOnExit = false
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
}
if err := snapshotruntime.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
if err := snapshotruntime.SendSignalViaPIDNamespace(restoreCtx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
}
releaseOnExit = false
return fmt.Errorf("failed to signal restored runtime process: %w", err)
}
// Step 3: Wait for the pod to become Ready
readyCtx := ctx
if timeout := w.config.Restore.RestoreReadyTimeout(); timeout > 0 {
var cancel context.CancelFunc
readyCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
if err := waitForPodReady(restoreCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
}
releaseOnExit = false
return fmt.Errorf("restore post-signal readiness check failed: %w", err)
}
......
......@@ -324,7 +324,7 @@ func TestReconcileRestorePod(t *testing.T) {
want: false,
},
{
name: "already in progress for same container",
name: "in progress for same container retries after restart",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
......@@ -332,6 +332,17 @@ func TestReconcileRestorePod(t *testing.T) {
annotationStatus: "in_progress",
annotationContainerID: testContainerID,
createDir: true,
want: true,
},
{
name: "already failed for same container",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
hash: "abc123",
annotationStatus: "failed",
annotationContainerID: testContainerID,
createDir: true,
want: false,
},
{
......@@ -345,6 +356,17 @@ func TestReconcileRestorePod(t *testing.T) {
createDir: true,
want: true,
},
{
name: "failed for previous container retries",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
hash: "abc123",
annotationStatus: "failed",
annotationContainerID: "old-container",
createDir: true,
want: true,
},
{
name: "in progress for previous container retries",
nodeName: testNodeName,
......
......@@ -56,21 +56,24 @@ type StorageSpec struct {
// RestoreSpec holds settings for the CRIU restore process.
type RestoreSpec struct {
NSRestorePath string `yaml:"nsRestorePath"`
RestoreReadyTimeoutSeconds int `yaml:"restoreReadyTimeoutSeconds"`
NSRestorePath string `yaml:"nsRestorePath"`
RestoreTimeoutSeconds int `yaml:"restoreTimeoutSeconds"`
}
func (c *RestoreSpec) RestoreReadyTimeout() time.Duration {
if c.RestoreReadyTimeoutSeconds <= 0 {
func (c *RestoreSpec) RestoreTimeout() time.Duration {
if c.RestoreTimeoutSeconds <= 0 {
return 0
}
return time.Duration(c.RestoreReadyTimeoutSeconds) * time.Second
return time.Duration(c.RestoreTimeoutSeconds) * time.Second
}
// Validate reports the first invalid restore setting as a *ConfigError, or nil
// when the spec is acceptable. NSRestorePath must be non-empty and
// RestoreTimeoutSeconds must be strictly positive.
func (c *RestoreSpec) Validate() error {
	switch {
	case c.NSRestorePath == "":
		return &ConfigError{Field: "nsRestorePath", Message: "nsRestorePath is required"}
	case c.RestoreTimeoutSeconds <= 0:
		return &ConfigError{Field: "restoreTimeoutSeconds", Message: "restoreTimeoutSeconds must be greater than zero"}
	default:
		return nil
	}
}
......
......@@ -6,6 +6,7 @@ package protocol
import (
"context"
"fmt"
"math"
"strings"
appsv1 "k8s.io/api/apps/v1"
......@@ -61,12 +62,28 @@ func PrepareRestorePodSpec(
if isCheckpointReady {
container.Command = []string{"sleep", "infinity"}
container.Args = nil
container.StartupProbe = nil
container.LivenessProbe = nil
container.ReadinessProbe = nil
ensureRestoreStartupProbe(container)
}
}
// ensureRestoreStartupProbe installs a startup probe on container that uses
// the largest failure threshold an int32 can hold.
//
// The probe is a deep copy of the first non-nil probe among the container's
// startup, liveness, and readiness probes (in that order), with
// FailureThreshold set to math.MaxInt32 and SuccessThreshold set to 1. The
// liveness and readiness probes themselves are not modified. When the
// container has none of the three probes, it is left untouched.
func ensureRestoreStartupProbe(container *corev1.Container) {
	var source *corev1.Probe
	candidates := []*corev1.Probe{
		container.StartupProbe,
		container.LivenessProbe,
		container.ReadinessProbe,
	}
	for _, candidate := range candidates {
		if candidate != nil {
			source = candidate
			break
		}
	}
	if source == nil {
		return
	}

	// Work on a copy so a probe borrowed from liveness/readiness keeps its
	// own thresholds.
	probe := source.DeepCopy()
	probe.FailureThreshold = math.MaxInt32
	probe.SuccessThreshold = 1
	container.StartupProbe = probe
}
func ValidateRestorePodSpec(
podSpec *corev1.PodSpec,
storage Storage,
......
package protocol
import (
"math"
"testing"
appsv1 "k8s.io/api/apps/v1"
......@@ -9,6 +10,9 @@ import (
)
func TestNewRestorePod(t *testing.T) {
readinessProbe := &corev1.Probe{PeriodSeconds: 7, TimeoutSeconds: 3}
livenessProbe := &corev1.Probe{InitialDelaySeconds: 11}
startupProbe := &corev1.Probe{FailureThreshold: 120}
restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{
Name: "worker",
......@@ -22,9 +26,9 @@ func TestNewRestorePod(t *testing.T) {
Image: "test:latest",
Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"},
ReadinessProbe: &corev1.Probe{},
LivenessProbe: &corev1.Probe{},
StartupProbe: &corev1.Probe{},
ReadinessProbe: readinessProbe.DeepCopy(),
LivenessProbe: livenessProbe.DeepCopy(),
StartupProbe: startupProbe.DeepCopy(),
}},
},
}, PodOptions{
......@@ -60,14 +64,23 @@ func TestNewRestorePod(t *testing.T) {
if restorePod.Spec.Containers[0].Args != nil {
t.Fatalf("expected restore args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
}
if restorePod.Spec.Containers[0].ReadinessProbe != nil {
t.Fatalf("expected readiness probe to be cleared: %#v", restorePod.Spec.Containers[0].ReadinessProbe)
if restorePod.Spec.Containers[0].ReadinessProbe == nil {
t.Fatalf("expected readiness probe to be preserved")
}
if restorePod.Spec.Containers[0].LivenessProbe != nil {
t.Fatalf("expected liveness probe to be cleared: %#v", restorePod.Spec.Containers[0].LivenessProbe)
if got := restorePod.Spec.Containers[0].ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
}
if restorePod.Spec.Containers[0].StartupProbe != nil {
t.Fatalf("expected startup probe to be cleared: %#v", restorePod.Spec.Containers[0].StartupProbe)
if restorePod.Spec.Containers[0].LivenessProbe == nil {
t.Fatalf("expected liveness probe to be preserved")
}
if got := restorePod.Spec.Containers[0].LivenessProbe.InitialDelaySeconds; got != livenessProbe.InitialDelaySeconds {
t.Fatalf("expected liveness initial delay %d, got %d", livenessProbe.InitialDelaySeconds, got)
}
if restorePod.Spec.Containers[0].StartupProbe == nil {
t.Fatalf("expected startup probe to be preserved")
}
if got := restorePod.Spec.Containers[0].StartupProbe.FailureThreshold; got != math.MaxInt32 {
t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
}
if restorePod.Spec.SecurityContext == nil || restorePod.Spec.SecurityContext.SeccompProfile == nil {
t.Fatalf("expected seccomp profile to be injected: %#v", restorePod.Spec.SecurityContext)
......@@ -82,12 +95,15 @@ func TestNewRestorePod(t *testing.T) {
func TestPrepareRestorePodSpec(t *testing.T) {
podSpec := corev1.PodSpec{}
readinessProbe := &corev1.Probe{PeriodSeconds: 13, SuccessThreshold: 1}
livenessProbe := &corev1.Probe{TimeoutSeconds: 5}
startupProbe := &corev1.Probe{FailureThreshold: 60}
container := corev1.Container{
Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"},
ReadinessProbe: &corev1.Probe{},
LivenessProbe: &corev1.Probe{},
StartupProbe: &corev1.Probe{},
ReadinessProbe: readinessProbe.DeepCopy(),
LivenessProbe: livenessProbe.DeepCopy(),
StartupProbe: startupProbe.DeepCopy(),
}
storage := Storage{
......@@ -113,8 +129,99 @@ func TestPrepareRestorePodSpec(t *testing.T) {
if container.Args != nil {
t.Fatalf("expected restore args to be cleared: %#v", container.Args)
}
if container.ReadinessProbe != nil || container.LivenessProbe != nil || container.StartupProbe != nil {
t.Fatalf("expected probes to be cleared: %#v %#v %#v", container.ReadinessProbe, container.LivenessProbe, container.StartupProbe)
if container.ReadinessProbe == nil {
t.Fatalf("expected readiness probe to be preserved")
}
if got := container.ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
}
if container.LivenessProbe == nil {
t.Fatalf("expected liveness probe to be preserved")
}
if got := container.LivenessProbe.TimeoutSeconds; got != livenessProbe.TimeoutSeconds {
t.Fatalf("expected liveness timeout %d, got %d", livenessProbe.TimeoutSeconds, got)
}
if container.StartupProbe == nil {
t.Fatalf("expected startup probe to be preserved")
}
if got := container.StartupProbe.FailureThreshold; got != math.MaxInt32 {
t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
}
if got := container.StartupProbe.SuccessThreshold; got != 1 {
t.Fatalf("expected startup success threshold 1, got %d", got)
}
}
// TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness checks that a
// container with only a liveness probe keeps that probe and gains a startup
// probe carrying the same HTTP handler, a math.MaxInt32 failure threshold, and
// a success threshold of 1.
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T) {
	container := corev1.Container{
		Command: []string{"python3", "-m", "dynamo.vllm"},
		Args:    []string{"--model", "Qwen"},
		LivenessProbe: &corev1.Probe{
			ProbeHandler: corev1.ProbeHandler{
				HTTPGet: &corev1.HTTPGetAction{Path: "/livez"},
			},
			PeriodSeconds:    5,
			TimeoutSeconds:   4,
			FailureThreshold: 2,
		},
	}
	podSpec := corev1.PodSpec{}

	PrepareRestorePodSpec(&podSpec, &container, Storage{}, "", true)

	if container.LivenessProbe == nil {
		t.Fatalf("expected liveness probe to be preserved")
	}

	startup := container.StartupProbe
	if startup == nil {
		t.Fatalf("expected startup probe to be synthesized")
	}
	if httpGet := startup.HTTPGet; httpGet == nil || httpGet.Path != "/livez" {
		t.Fatalf("expected startup probe HTTP path /livez, got %#v", startup.HTTPGet)
	}
	if startup.FailureThreshold != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, startup.FailureThreshold)
	}
	if startup.SuccessThreshold != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", startup.SuccessThreshold)
	}
}
// TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness checks that a
// container with only a readiness probe keeps that probe unchanged and gains a
// startup probe carrying the same exec handler, a math.MaxInt32 failure
// threshold, and a success threshold of 1.
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) {
	readinessProbe := &corev1.Probe{
		ProbeHandler: corev1.ProbeHandler{
			Exec: &corev1.ExecAction{Command: []string{"cat", "/tmp/ready"}},
		},
		PeriodSeconds:    13,
		SuccessThreshold: 3,
		FailureThreshold: 4,
	}
	container := corev1.Container{
		Command:        []string{"python3", "-m", "dynamo.vllm"},
		Args:           []string{"--model", "Qwen"},
		ReadinessProbe: readinessProbe.DeepCopy(),
	}
	podSpec := corev1.PodSpec{}

	PrepareRestorePodSpec(&podSpec, &container, Storage{}, "", true)

	readiness := container.ReadinessProbe
	if readiness == nil {
		t.Fatalf("expected readiness probe to be preserved")
	}
	// The readiness probe's own success threshold must survive the synthesis.
	if readiness.SuccessThreshold != readinessProbe.SuccessThreshold {
		t.Fatalf("expected readiness success threshold %d, got %d", readinessProbe.SuccessThreshold, readiness.SuccessThreshold)
	}

	startup := container.StartupProbe
	if startup == nil {
		t.Fatalf("expected startup probe to be synthesized")
	}
	execAction := startup.Exec
	if execAction == nil || len(execAction.Command) != 2 || execAction.Command[0] != "cat" || execAction.Command[1] != "/tmp/ready" {
		t.Fatalf("expected startup probe exec command to match readiness probe: %#v", startup.Exec)
	}
	if startup.FailureThreshold != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, startup.FailureThreshold)
	}
	if startup.SuccessThreshold != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", startup.SuccessThreshold)
	}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment