Unverified commit a0ecd4dd, authored by Schwinn Saereesitthipitak and committed by GitHub
Browse files

fix(snapshot): use watch-based waits in snapshotctl (#8024)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
parent e319cd2b
...@@ -17,13 +17,13 @@ from dynamo.common.utils.namespace import get_worker_namespace ...@@ -17,13 +17,13 @@ from dynamo.common.utils.namespace import get_worker_namespace
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PODINFO_ROOT = "/etc/podinfo" PODINFO_ROOT = "/etc/podinfo"
REQUIRED_PODINFO_FILES = { KUBERNETES_REQUIRED_PODINFO_FILES = {
"DYN_NAMESPACE": "dyn_namespace", "DYN_NAMESPACE": "dyn_namespace",
"DYN_COMPONENT": "dyn_component", "DYN_COMPONENT": "dyn_component",
"DYN_PARENT_DGD_K8S_NAME": "dyn_parent_dgd_k8s_name", "DYN_PARENT_DGD_K8S_NAME": "dyn_parent_dgd_k8s_name",
"DYN_PARENT_DGD_K8S_NAMESPACE": "dyn_parent_dgd_k8s_namespace", "DYN_PARENT_DGD_K8S_NAMESPACE": "dyn_parent_dgd_k8s_namespace",
} }
OPTIONAL_PODINFO_FILES = { KUBERNETES_OPTIONAL_PODINFO_FILES = {
"DYN_NAMESPACE_WORKER_SUFFIX": "dyn_namespace_worker_suffix", "DYN_NAMESPACE_WORKER_SUFFIX": "dyn_namespace_worker_suffix",
} }
EngineT = TypeVar("EngineT") EngineT = TypeVar("EngineT")
...@@ -203,12 +203,29 @@ class EngineSnapshotController(Generic[EngineT]): ...@@ -203,12 +203,29 @@ class EngineSnapshotController(Generic[EngineT]):
*self.quiesce_args, *self.quiesce_args,
) )
def reload_restore_identity(self) -> tuple[str, str]: def reload_restore_identity(
return reload_snapshot_restore_identity() self,
namespace: str,
discovery_backend: str,
) -> tuple[str, str]:
return reload_snapshot_restore_identity(namespace, discovery_backend)
def reload_snapshot_restore_identity(
namespace: str,
discovery_backend: str,
) -> tuple[str, str]:
if discovery_backend != "kubernetes":
logger.info(
"Snapshot restore reusing configured discovery backend",
extra={
"dynamo_namespace": namespace,
"discovery_backend": discovery_backend,
},
)
return namespace, discovery_backend
def reload_snapshot_restore_identity() -> tuple[str, str]: for env_name, podinfo_file in KUBERNETES_REQUIRED_PODINFO_FILES.items():
for env_name, podinfo_file in REQUIRED_PODINFO_FILES.items():
podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file) podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file)
if not os.path.isfile(podinfo_path): if not os.path.isfile(podinfo_path):
raise RuntimeError(f"snapshot restore requires {podinfo_path}") raise RuntimeError(f"snapshot restore requires {podinfo_path}")
...@@ -220,7 +237,7 @@ def reload_snapshot_restore_identity() -> tuple[str, str]: ...@@ -220,7 +237,7 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
os.environ[env_name] = value os.environ[env_name] = value
for env_name, podinfo_file in OPTIONAL_PODINFO_FILES.items(): for env_name, podinfo_file in KUBERNETES_OPTIONAL_PODINFO_FILES.items():
podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file) podinfo_path = os.path.join(PODINFO_ROOT, podinfo_file)
if not os.path.isfile(podinfo_path): if not os.path.isfile(podinfo_path):
os.environ.pop(env_name, None) os.environ.pop(env_name, None)
...@@ -234,7 +251,6 @@ def reload_snapshot_restore_identity() -> tuple[str, str]: ...@@ -234,7 +251,6 @@ def reload_snapshot_restore_identity() -> tuple[str, str]:
os.environ[env_name] = value os.environ[env_name] = value
# Snapshot restore only runs in Kubernetes-managed pods, so discovery resets here.
os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes" os.environ["DYN_DISCOVERY_BACKEND"] = "kubernetes"
return get_worker_namespace(), "kubernetes" return get_worker_namespace(), "kubernetes"
......
...@@ -50,7 +50,10 @@ async def worker(): ...@@ -50,7 +50,10 @@ async def worker():
( (
dynamo_args.namespace, dynamo_args.namespace,
dynamo_args.discovery_backend, dynamo_args.discovery_backend,
) = snapshot_controller.reload_restore_identity() ) = snapshot_controller.reload_restore_identity(
dynamo_args.namespace,
dynamo_args.discovery_backend,
)
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()
shutdown_endpoints: list = [] shutdown_endpoints: list = []
......
...@@ -146,7 +146,10 @@ async def worker() -> None: ...@@ -146,7 +146,10 @@ async def worker() -> None:
( (
config.namespace, config.namespace,
config.discovery_backend, config.discovery_backend,
) = snapshot_controller.reload_restore_identity() ) = snapshot_controller.reload_restore_identity(
config.namespace,
config.discovery_backend,
)
# HEADLESS MODE: bypass DistributedRuntime entirely. # HEADLESS MODE: bypass DistributedRuntime entirely.
# Workers run vLLM only (no NATS, etcd, or dynamo endpoints). # Workers run vLLM only (no NATS, etcd, or dynamo endpoints).
......
...@@ -24,7 +24,7 @@ data: ...@@ -24,7 +24,7 @@ data:
restore: restore:
nsRestorePath: {{ .Values.config.restore.nsRestorePath | quote }} nsRestorePath: {{ .Values.config.restore.nsRestorePath | quote }}
restoreReadyTimeoutSeconds: {{ .Values.config.restore.restoreReadyTimeoutSeconds }} restoreTimeoutSeconds: {{ .Values.config.restore.restoreTimeoutSeconds }}
criu: criu:
binaryPath: {{ .Values.config.criu.binaryPath | quote }} binaryPath: {{ .Values.config.criu.binaryPath | quote }}
......
...@@ -137,8 +137,8 @@ config: ...@@ -137,8 +137,8 @@ config:
restore: restore:
# Path to the nsrestore binary in the placeholder image # Path to the nsrestore binary in the placeholder image
nsRestorePath: /usr/local/bin/nsrestore nsRestorePath: /usr/local/bin/nsrestore
# Maximum seconds to wait for a restored pod to become Ready (0 = no timeout) # Maximum seconds to allow a restore attempt before snapshot-agent marks it failed
restoreReadyTimeoutSeconds: 0 restoreTimeoutSeconds: 7200
criu: criu:
# Path to the criu binary # Path to the criu binary
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/consts" "github.com/ai-dynamo/dynamo/deploy/operator/internal/consts"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo" "github.com/ai-dynamo/dynamo/deploy/operator/internal/dynamo"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
...@@ -67,6 +68,9 @@ func buildCheckpointJob( ...@@ -67,6 +68,9 @@ func buildCheckpointJob(
if podTemplate.Annotations == nil { if podTemplate.Annotations == nil {
podTemplate.Annotations = make(map[string]string) podTemplate.Annotations = make(map[string]string)
} }
if podTemplate.Spec.ServiceAccountName == "" {
podTemplate.Spec.ServiceAccountName = discovery.GetK8sDiscoveryServiceAccountName(ckpt.Name)
}
checkpoint.EnsurePodInfoVolume(&podTemplate.Spec) checkpoint.EnsurePodInfoVolume(&podTemplate.Spec)
......
...@@ -25,6 +25,7 @@ import ( ...@@ -25,6 +25,7 @@ import (
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1" coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/api/meta"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
...@@ -40,6 +41,7 @@ import ( ...@@ -40,6 +41,7 @@ import (
nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1" nvidiacomv1alpha1 "github.com/ai-dynamo/dynamo/deploy/operator/api/v1alpha1"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint" "github.com/ai-dynamo/dynamo/deploy/operator/internal/checkpoint"
commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common" commonController "github.com/ai-dynamo/dynamo/deploy/operator/internal/controller_common"
"github.com/ai-dynamo/dynamo/deploy/operator/internal/discovery"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
) )
...@@ -175,6 +177,11 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request) ...@@ -175,6 +177,11 @@ func (r *CheckpointReconciler) Reconcile(ctx context.Context, req ctrl.Request)
func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) { func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
if err := r.reconcileK8sDiscoveryResources(ctx, ckpt); err != nil {
logger.Error(err, "Failed to reconcile K8s discovery resources for checkpoint")
return ctrl.Result{}, fmt.Errorf("failed to reconcile K8s discovery resources for checkpoint: %w", err)
}
hash := ckpt.Status.IdentityHash hash := ckpt.Status.IdentityHash
if hash == "" { if hash == "" {
var err error var err error
...@@ -223,6 +230,47 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco ...@@ -223,6 +230,47 @@ func (r *CheckpointReconciler) handlePending(ctx context.Context, ckpt *nvidiaco
return ctrl.Result{}, nil return ctrl.Result{}, nil
} }
// reconcileK8sDiscoveryResources syncs the three Kubernetes discovery objects
// for a checkpoint, in order: service account, role, role binding.
//
// Each object is built by the discovery package from the checkpoint's name and
// namespace and handed to commonController.SyncResource together with ckpt.
// The generator passed to SyncResource always returns false as its second
// value (NOTE(review): presumably "do not delete" — confirm against
// SyncResource).
//
// The first failing sync stops the sequence. The failure is logged with the
// kind of resource that failed and returned wrapped as
// "failed to sync checkpoint k8s discovery <kind>: <cause>". Returns nil when
// all three syncs succeed.
func (r *CheckpointReconciler) reconcileK8sDiscoveryResources(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (err error) {
	logger := log.FromContext(ctx)

	// fail logs the cause once and returns it wrapped with the resource kind.
	fail := func(kind string, cause error) error {
		logger.Error(cause, "failed to sync checkpoint k8s discovery resource", "resource", kind)
		return fmt.Errorf("failed to sync checkpoint k8s discovery %s: %w", kind, cause)
	}

	sa := discovery.GetK8sDiscoveryServiceAccount(ckpt.Name, ckpt.Namespace)
	if _, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*corev1.ServiceAccount, bool, error) {
		return sa, false, nil
	}); err != nil {
		return fail("service account", err)
	}

	discoveryRole := discovery.GetK8sDiscoveryRole(ckpt.Name, ckpt.Namespace)
	if _, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*rbacv1.Role, bool, error) {
		return discoveryRole, false, nil
	}); err != nil {
		return fail("role", err)
	}

	binding := discovery.GetK8sDiscoveryRoleBinding(ckpt.Name, ckpt.Namespace)
	if _, _, err = commonController.SyncResource(ctx, r, ckpt, func(ctx context.Context) (*rbacv1.RoleBinding, bool, error) {
		return binding, false, nil
	}); err != nil {
		return fail("role binding", err)
	}

	return nil
}
func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) { func (r *CheckpointReconciler) handleCreating(ctx context.Context, ckpt *nvidiacomv1alpha1.DynamoCheckpoint) (ctrl.Result, error) {
logger := log.FromContext(ctx) logger := log.FromContext(ctx)
......
...@@ -32,6 +32,7 @@ import ( ...@@ -32,6 +32,7 @@ import (
batchv1 "k8s.io/api/batch/v1" batchv1 "k8s.io/api/batch/v1"
coordinationv1 "k8s.io/api/coordination/v1" coordinationv1 "k8s.io/api/coordination/v1"
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
rbacv1 "k8s.io/api/rbac/v1"
"k8s.io/apimachinery/pkg/api/resource" "k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime"
...@@ -67,6 +68,7 @@ func checkpointTestScheme() *runtime.Scheme { ...@@ -67,6 +68,7 @@ func checkpointTestScheme() *runtime.Scheme {
_ = corev1.AddToScheme(s) _ = corev1.AddToScheme(s)
_ = batchv1.AddToScheme(s) _ = batchv1.AddToScheme(s)
_ = coordinationv1.AddToScheme(s) _ = coordinationv1.AddToScheme(s)
_ = rbacv1.AddToScheme(s)
return s return s
} }
......
...@@ -2,6 +2,7 @@ package main ...@@ -2,6 +2,7 @@ package main
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -10,7 +11,8 @@ import ( ...@@ -10,7 +11,8 @@ import (
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
...@@ -108,69 +110,87 @@ func runCheckpointFlow(ctx context.Context, opts checkpointOptions) (*result, er ...@@ -108,69 +110,87 @@ func runCheckpointFlow(ctx context.Context, opts checkpointOptions) (*result, er
func waitForCheckpoint(ctx context.Context, clientset kubernetes.Interface, namespace string, jobName string) (string, error) { func waitForCheckpoint(ctx context.Context, clientset kubernetes.Interface, namespace string, jobName string) (string, error) {
var status string var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) { err := watchNamedObject(
job, err := clientset.BatchV1().Jobs(namespace).Get(ctx, jobName, metav1.GetOptions{}) ctx,
if err != nil { jobName,
if apierrors.IsNotFound(err) { &batchv1.Job{},
return false, nil func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
return clientset.BatchV1().Jobs(namespace).List(ctx, options)
},
func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
return clientset.BatchV1().Jobs(namespace).Watch(ctx, options)
},
func(event watch.Event) (bool, error) {
if event.Type == watch.Error {
return false, apierrors.FromObject(event.Object)
} }
return false, fmt.Errorf("get checkpoint job %s/%s: %w", namespace, jobName, err)
}
status = strings.TrimSpace(job.Annotations[snapshotprotocol.CheckpointStatusAnnotation]) job, ok := event.Object.(*batchv1.Job)
if status == snapshotprotocol.CheckpointStatusCompleted { if !ok {
return true, nil return false, fmt.Errorf("unexpected checkpoint watch object %T", event.Object)
}
if status == snapshotprotocol.CheckpointStatusFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
if job.Status.Failed > 0 {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
}
for _, condition := range job.Status.Conditions {
if condition.Status != corev1.ConditionTrue {
continue
} }
if condition.Type == batchv1.JobFailed {
return false, fmt.Errorf("checkpoint job %s/%s failed: %s", namespace, jobName, strings.TrimSpace(condition.Message))
}
}
return false, nil
}); err != nil {
if !wait.Interrupted(err) {
return "", err
}
summaryCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) status = strings.TrimSpace(job.Annotations[snapshotprotocol.CheckpointStatusAnnotation])
defer cancel() if status == snapshotprotocol.CheckpointStatusCompleted {
pods, err := clientset.CoreV1().Pods(namespace).List(summaryCtx, metav1.ListOptions{ return true, nil
LabelSelector: "batch.kubernetes.io/job-name=" + jobName,
})
summary := "no checkpoint pod created yet"
if err != nil {
summary = "unable to list checkpoint pod: " + err.Error()
} else if len(pods.Items) != 0 {
pod := pods.Items[0]
parts := []string{
fmt.Sprintf("job_status=%q", status),
fmt.Sprintf("pod=%s phase=%s", pod.Name, pod.Status.Phase),
} }
for _, condition := range pod.Status.Conditions { if status == snapshotprotocol.CheckpointStatusFailed {
if condition.Status == corev1.ConditionTrue || condition.Status == corev1.ConditionFalse { return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
parts = append(parts, fmt.Sprintf("%s=%s", condition.Type, condition.Status)) }
} if job.Status.Failed > 0 {
return false, fmt.Errorf("checkpoint job %s/%s failed", namespace, jobName)
} }
for _, status := range pod.Status.ContainerStatuses { for _, condition := range job.Status.Conditions {
if status.State.Waiting != nil { if condition.Status != corev1.ConditionTrue {
parts = append(parts, fmt.Sprintf("container=%s waiting=%s", status.Name, status.State.Waiting.Reason)) continue
} }
if status.State.Terminated != nil { if condition.Type == batchv1.JobFailed {
parts = append(parts, fmt.Sprintf("container=%s terminated=%s", status.Name, status.State.Terminated.Reason)) return false, fmt.Errorf("checkpoint job %s/%s failed: %s", namespace, jobName, strings.TrimSpace(condition.Message))
} }
} }
summary = strings.Join(parts, " ") return false, nil
},
)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
return "", err
} }
return "", fmt.Errorf("checkpoint job %s/%s timed out: %s", namespace, jobName, summary) return "", fmt.Errorf("checkpoint job %s/%s timed out: %s", namespace, jobName, checkpointTimeoutSummary(clientset, namespace, jobName, status))
} }
return status, nil return status, nil
} }
// checkpointTimeoutSummary builds a one-line, space-separated description of
// the pod behind a checkpoint job, for inclusion in a timeout error message.
//
// It lists pods in namespace labelled batch.kubernetes.io/job-name=<jobName>
// using its own background context limited to 5 seconds, so it does not depend
// on any caller context. Only the first pod returned is described.
//
// Returns:
//   - "unable to list checkpoint pod: <err>" when the list call fails;
//   - "no checkpoint pod created yet" when no pod matches;
//   - otherwise job_status, pod name and phase, every pod condition whose
//     status is True or False, and the waiting/terminated reason of each
//     container status.
func checkpointTimeoutSummary(clientset kubernetes.Interface, namespace string, jobName string, status string) string {
	listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	selector := "batch.kubernetes.io/job-name=" + jobName
	podList, listErr := clientset.CoreV1().Pods(namespace).List(listCtx, metav1.ListOptions{
		LabelSelector: selector,
	})
	switch {
	case listErr != nil:
		return "unable to list checkpoint pod: " + listErr.Error()
	case len(podList.Items) == 0:
		return "no checkpoint pod created yet"
	}

	first := &podList.Items[0]
	fragments := []string{
		fmt.Sprintf("job_status=%q", status),
		fmt.Sprintf("pod=%s phase=%s", first.Name, first.Status.Phase),
	}

	// Conditions with status Unknown are left out of the summary.
	for i := range first.Status.Conditions {
		cond := &first.Status.Conditions[i]
		switch cond.Status {
		case corev1.ConditionTrue, corev1.ConditionFalse:
			fragments = append(fragments, fmt.Sprintf("%s=%s", cond.Type, cond.Status))
		}
	}

	// A container contributes one fragment per populated state.
	for i := range first.Status.ContainerStatuses {
		cs := &first.Status.ContainerStatuses[i]
		if waiting := cs.State.Waiting; waiting != nil {
			fragments = append(fragments, fmt.Sprintf("container=%s waiting=%s", cs.Name, waiting.Reason))
		}
		if terminated := cs.State.Terminated; terminated != nil {
			fragments = append(fragments, fmt.Sprintf("container=%s terminated=%s", cs.Name, terminated.Reason))
		}
	}

	return strings.Join(fragments, " ")
}
...@@ -3,6 +3,7 @@ package main ...@@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strings" "strings"
"time" "time"
...@@ -10,8 +11,9 @@ import ( ...@@ -10,8 +11,9 @@ import (
corev1 "k8s.io/api/core/v1" corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/kubernetes" "k8s.io/client-go/kubernetes"
snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol" snapshotprotocol "github.com/ai-dynamo/dynamo/deploy/snapshot/protocol"
...@@ -155,28 +157,79 @@ func runRestoreFlow(ctx context.Context, opts restoreOptions) (*result, error) { ...@@ -155,28 +157,79 @@ func runRestoreFlow(ctx context.Context, opts restoreOptions) (*result, error) {
func waitForRestore(ctx context.Context, clientset kubernetes.Interface, namespace string, podName string) (string, error) { func waitForRestore(ctx context.Context, clientset kubernetes.Interface, namespace string, podName string) (string, error) {
var status string var status string
if err := wait.PollUntilContextCancel(ctx, 2*time.Second, true, func(ctx context.Context) (bool, error) { err := watchNamedObject(
pod, err := clientset.CoreV1().Pods(namespace).Get(ctx, podName, metav1.GetOptions{}) ctx,
if err != nil { podName,
return false, fmt.Errorf("get restore pod %s/%s: %w", namespace, podName, err) &corev1.Pod{},
func(ctx context.Context, options metav1.ListOptions) (runtime.Object, error) {
return clientset.CoreV1().Pods(namespace).List(ctx, options)
},
func(ctx context.Context, options metav1.ListOptions) (watch.Interface, error) {
return clientset.CoreV1().Pods(namespace).Watch(ctx, options)
},
func(event watch.Event) (bool, error) {
if event.Type == watch.Error {
return false, apierrors.FromObject(event.Object)
}
pod, ok := event.Object.(*corev1.Pod)
if !ok {
return false, fmt.Errorf("unexpected restore watch object %T", event.Object)
}
status = strings.TrimSpace(pod.Annotations[snapshotprotocol.RestoreStatusAnnotation])
if status == snapshotprotocol.RestoreStatusCompleted {
return true, nil
}
if status == snapshotprotocol.RestoreStatusFailed {
return false, fmt.Errorf("restore pod %s/%s failed", namespace, podName)
}
if pod.Status.Phase == corev1.PodFailed {
return false, fmt.Errorf("restore pod %s/%s entered phase Failed (%s)", namespace, podName, pod.Status.Reason)
}
return false, nil
},
)
if err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
return "", err
} }
return "", fmt.Errorf("restore pod %s/%s timed out: %s", namespace, podName, restoreTimeoutSummary(clientset, namespace, podName, status))
}
return status, nil
}
func restoreTimeoutSummary(clientset kubernetes.Interface, namespace string, podName string, status string) string {
summaryCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
status = strings.TrimSpace(pod.Annotations[snapshotprotocol.RestoreStatusAnnotation]) pod, err := clientset.CoreV1().Pods(namespace).Get(summaryCtx, podName, metav1.GetOptions{})
if status == snapshotprotocol.RestoreStatusCompleted { if err != nil {
return true, nil if apierrors.IsNotFound(err) {
return fmt.Sprintf("restore_status=%q pod not found", status)
} }
if status == snapshotprotocol.RestoreStatusFailed { return "unable to get restore pod: " + err.Error()
return false, fmt.Errorf("restore pod %s/%s failed", namespace, podName) }
parts := []string{
fmt.Sprintf("restore_status=%q", status),
fmt.Sprintf("pod=%s phase=%s", pod.Name, pod.Status.Phase),
}
if reason := strings.TrimSpace(pod.Status.Reason); reason != "" {
parts = append(parts, fmt.Sprintf("reason=%s", reason))
}
for _, condition := range pod.Status.Conditions {
if condition.Status == corev1.ConditionTrue || condition.Status == corev1.ConditionFalse {
parts = append(parts, fmt.Sprintf("%s=%s", condition.Type, condition.Status))
} }
if pod.Status.Phase == corev1.PodFailed { }
return false, fmt.Errorf("restore pod %s/%s entered phase Failed (%s)", namespace, podName, pod.Status.Reason) for _, containerStatus := range pod.Status.ContainerStatuses {
if containerStatus.State.Waiting != nil {
parts = append(parts, fmt.Sprintf("container=%s waiting=%s", containerStatus.Name, containerStatus.State.Waiting.Reason))
} }
return false, nil if containerStatus.State.Terminated != nil {
}); err != nil { parts = append(parts, fmt.Sprintf("container=%s terminated=%s", containerStatus.Name, containerStatus.State.Terminated.Reason))
if !wait.Interrupted(err) {
return "", err
} }
return "", fmt.Errorf("restore pod %s/%s timed out: status=%q", namespace, podName, status)
} }
return status, nil return strings.Join(parts, " ")
} }
package main
import (
"context"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/fields"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/watch"
"k8s.io/client-go/tools/cache"
watchtools "k8s.io/client-go/tools/watch"
)
// watchNamedObject blocks until condition reports completion for the single
// object called name, or until the wait fails or ctx ends.
//
// It wraps listFn and watchFn in a cache.ListWatch that forces a
// metadata.name=<name> field selector onto every list and watch request, then
// hands that ListWatch, objType and condition to watchtools.UntilWithSync.
//
// Returns:
//   - nil when condition returned true;
//   - ctx.Err() when the wait ended with an error and ctx is already done, so
//     callers can detect context.DeadlineExceeded or context.Canceled with
//     errors.Is;
//   - otherwise the error from UntilWithSync (including any error returned by
//     condition) unchanged.
func watchNamedObject(
	ctx context.Context,
	name string,
	objType runtime.Object,
	listFn func(context.Context, metav1.ListOptions) (runtime.Object, error),
	watchFn func(context.Context, metav1.ListOptions) (watch.Interface, error),
	condition func(watch.Event) (bool, error),
) error {
	fieldSelector := fields.OneTermEqualSelector("metadata.name", name).String()
	lw := &cache.ListWatch{
		ListFunc: func(options metav1.ListOptions) (runtime.Object, error) {
			options.FieldSelector = fieldSelector
			return listFn(ctx, options)
		},
		WatchFunc: func(options metav1.ListOptions) (watch.Interface, error) {
			options.FieldSelector = fieldSelector
			return watchFn(ctx, options)
		},
	}

	_, err := watchtools.UntilWithSync(ctx, lw, objType, nil, condition)
	if err == nil {
		// The condition was satisfied. A context that expires right after
		// must not turn this success into a reported timeout.
		return nil
	}
	// The wait failed; prefer the context's own error when it is done so the
	// caller sees DeadlineExceeded/Canceled rather than the watch error.
	if ctxErr := ctx.Err(); ctxErr != nil {
		return ctxErr
	}
	return err
}
...@@ -291,7 +291,7 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po ...@@ -291,7 +291,7 @@ func (w *NodeController) reconcileRestorePod(ctx context.Context, pod *corev1.Po
annotationStatus := pod.Annotations[snapshotprotocol.RestoreStatusAnnotation] annotationStatus := pod.Annotations[snapshotprotocol.RestoreStatusAnnotation]
annotationContainerID := pod.Annotations[snapshotprotocol.RestoreContainerIDAnnotation] annotationContainerID := pod.Annotations[snapshotprotocol.RestoreContainerIDAnnotation]
if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusInProgress) { if annotationContainerID == containerID && (annotationStatus == snapshotprotocol.RestoreStatusCompleted || annotationStatus == snapshotprotocol.RestoreStatusFailed) {
return return
} }
...@@ -468,6 +468,12 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -468,6 +468,12 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
w.release(restoreAttemptKey) w.release(restoreAttemptKey)
} }
}() }()
restoreCtx := ctx
if timeout := w.config.Restore.RestoreTimeout(); timeout > 0 {
var cancel context.CancelFunc
restoreCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name) podKey := fmt.Sprintf("%s/%s", pod.Namespace, pod.Name)
log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID, "container_id", containerID) log := w.log.WithValues("pod", podKey, "checkpoint_id", checkpointID, "container_id", containerID)
setRestoreStatus := func(value string) error { setRestoreStatus := func(value string) error {
...@@ -476,7 +482,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -476,7 +482,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
snapshotprotocol.RestoreContainerIDAnnotation: containerID, snapshotprotocol.RestoreContainerIDAnnotation: containerID,
} }
if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil { if err := annotatePod(ctx, w.clientset, log, pod, annotations); err != nil {
if value == snapshotprotocol.RestoreStatusCompleted { if value == snapshotprotocol.RestoreStatusCompleted || value == snapshotprotocol.RestoreStatusFailed {
releaseOnExit = false releaseOnExit = false
return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err) return fmt.Errorf("failed to persist terminal restore status %q: %w", value, err)
} }
...@@ -485,10 +491,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -485,10 +491,7 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
return nil return nil
} }
if err := annotatePod(ctx, w.clientset, log, pod, map[string]string{ if err := setRestoreStatus(snapshotprotocol.RestoreStatusInProgress); err != nil {
snapshotprotocol.RestoreStatusAnnotation: snapshotprotocol.RestoreStatusInProgress,
snapshotprotocol.RestoreContainerIDAnnotation: containerID,
}); err != nil {
return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err) return fmt.Errorf("failed to annotate pod with restore in_progress: %w", err)
} }
...@@ -503,17 +506,18 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -503,17 +506,18 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
ContainerName: containerName, ContainerName: containerName,
Clientset: w.clientset, Clientset: w.clientset,
} }
restoredPID, err := executor.Restore(ctx, w.containerd, log, req) restoredPID, err := executor.Restore(restoreCtx, w.containerd, log, req)
if err != nil { if err != nil {
log.Error(err, "External restore failed") log.Error(err, "External restore failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName) placeholderHostPID, _, pidErr := snapshotruntime.ResolveContainerByPod(ctx, w.containerd, pod.Name, pod.Namespace, containerName)
if pidErr != nil { if pidErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr) return fmt.Errorf("restore failed and placeholder PID could not be resolved: %w", pidErr)
} }
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil { if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore failed"); killErr != nil {
releaseOnExit = false
return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr) return fmt.Errorf("restore failed and placeholder could not be killed: %w", killErr)
} }
return nil return nil
...@@ -524,33 +528,33 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai ...@@ -524,33 +528,33 @@ func (w *NodeController) runRestore(ctx context.Context, pod *corev1.Pod, contai
if err != nil { if err != nil {
log.Error(err, "Failed to resolve placeholder host PID for signaling") log.Error(err, "Failed to resolve placeholder host PID for signaling")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
releaseOnExit = false if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err) return fmt.Errorf("failed to resolve placeholder host PID for signaling: %w", err)
} }
if err := snapshotruntime.SendSignalViaPIDNamespace(ctx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil { if err := snapshotruntime.SendSignalViaPIDNamespace(restoreCtx, log, placeholderHostPID, restoredPID, syscall.SIGCONT, "restore complete"); err != nil {
log.Error(err, "Failed to signal restored runtime process") log.Error(err, "Failed to signal restored runtime process")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil { if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore signaling failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore signaling failure") log.Error(killErr, "Failed to kill placeholder after restore signaling failure")
} }
releaseOnExit = false
return fmt.Errorf("failed to signal restored runtime process: %w", err) return fmt.Errorf("failed to signal restored runtime process: %w", err)
} }
// Step 3: Wait for the pod to become Ready // Step 3: Wait for the pod to become Ready
readyCtx := ctx if err := waitForPodReady(restoreCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
if timeout := w.config.Restore.RestoreReadyTimeout(); timeout > 0 {
var cancel context.CancelFunc
readyCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
if err := waitForPodReady(readyCtx, w.clientset, pod.Namespace, pod.Name, containerName); err != nil {
log.Error(err, "Restore post-signal readiness check failed") log.Error(err, "Restore post-signal readiness check failed")
emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error()) emitPodEvent(ctx, w.clientset, log, pod, "snapshot", corev1.EventTypeWarning, "RestoreFailed", err.Error())
if statusErr := setRestoreStatus(snapshotprotocol.RestoreStatusFailed); statusErr != nil {
return statusErr
}
if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil { if killErr := snapshotruntime.SendSignalToPID(log, placeholderHostPID, syscall.SIGKILL, "restore readiness failed"); killErr != nil {
log.Error(killErr, "Failed to kill placeholder after restore readiness failure") log.Error(killErr, "Failed to kill placeholder after restore readiness failure")
} }
releaseOnExit = false
return fmt.Errorf("restore post-signal readiness check failed: %w", err) return fmt.Errorf("restore post-signal readiness check failed: %w", err)
} }
......
...@@ -324,7 +324,7 @@ func TestReconcileRestorePod(t *testing.T) { ...@@ -324,7 +324,7 @@ func TestReconcileRestorePod(t *testing.T) {
want: false, want: false,
}, },
{ {
name: "already in progress for same container", name: "in progress for same container retries after restart",
nodeName: testNodeName, nodeName: testNodeName,
phase: corev1.PodRunning, phase: corev1.PodRunning,
ready: false, ready: false,
...@@ -332,6 +332,17 @@ func TestReconcileRestorePod(t *testing.T) { ...@@ -332,6 +332,17 @@ func TestReconcileRestorePod(t *testing.T) {
annotationStatus: "in_progress", annotationStatus: "in_progress",
annotationContainerID: testContainerID, annotationContainerID: testContainerID,
createDir: true, createDir: true,
want: true,
},
{
name: "already failed for same container",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
hash: "abc123",
annotationStatus: "failed",
annotationContainerID: testContainerID,
createDir: true,
want: false, want: false,
}, },
{ {
...@@ -345,6 +356,17 @@ func TestReconcileRestorePod(t *testing.T) { ...@@ -345,6 +356,17 @@ func TestReconcileRestorePod(t *testing.T) {
createDir: true, createDir: true,
want: true, want: true,
}, },
{
name: "failed for previous container retries",
nodeName: testNodeName,
phase: corev1.PodRunning,
ready: false,
hash: "abc123",
annotationStatus: "failed",
annotationContainerID: "old-container",
createDir: true,
want: true,
},
{ {
name: "in progress for previous container retries", name: "in progress for previous container retries",
nodeName: testNodeName, nodeName: testNodeName,
......
...@@ -56,21 +56,24 @@ type StorageSpec struct { ...@@ -56,21 +56,24 @@ type StorageSpec struct {
// RestoreSpec holds settings for the CRIU restore process. // RestoreSpec holds settings for the CRIU restore process.
type RestoreSpec struct { type RestoreSpec struct {
NSRestorePath string `yaml:"nsRestorePath"` NSRestorePath string `yaml:"nsRestorePath"`
RestoreReadyTimeoutSeconds int `yaml:"restoreReadyTimeoutSeconds"` RestoreTimeoutSeconds int `yaml:"restoreTimeoutSeconds"`
} }
func (c *RestoreSpec) RestoreReadyTimeout() time.Duration { func (c *RestoreSpec) RestoreTimeout() time.Duration {
if c.RestoreReadyTimeoutSeconds <= 0 { if c.RestoreTimeoutSeconds <= 0 {
return 0 return 0
} }
return time.Duration(c.RestoreReadyTimeoutSeconds) * time.Second return time.Duration(c.RestoreTimeoutSeconds) * time.Second
} }
func (c *RestoreSpec) Validate() error { func (c *RestoreSpec) Validate() error {
if c.NSRestorePath == "" { if c.NSRestorePath == "" {
return &ConfigError{Field: "nsRestorePath", Message: "nsRestorePath is required"} return &ConfigError{Field: "nsRestorePath", Message: "nsRestorePath is required"}
} }
if c.RestoreTimeoutSeconds <= 0 {
return &ConfigError{Field: "restoreTimeoutSeconds", Message: "restoreTimeoutSeconds must be greater than zero"}
}
return nil return nil
} }
......
...@@ -6,6 +6,7 @@ package protocol ...@@ -6,6 +6,7 @@ package protocol
import ( import (
"context" "context"
"fmt" "fmt"
"math"
"strings" "strings"
appsv1 "k8s.io/api/apps/v1" appsv1 "k8s.io/api/apps/v1"
...@@ -61,12 +62,28 @@ func PrepareRestorePodSpec( ...@@ -61,12 +62,28 @@ func PrepareRestorePodSpec(
if isCheckpointReady { if isCheckpointReady {
container.Command = []string{"sleep", "infinity"} container.Command = []string{"sleep", "infinity"}
container.Args = nil container.Args = nil
container.StartupProbe = nil ensureRestoreStartupProbe(container)
container.LivenessProbe = nil
container.ReadinessProbe = nil
} }
} }
// ensureRestoreStartupProbe installs a startup probe on container that is
// derived from one of the probes the container already declares.
//
// The source probe is chosen in this order: the existing startup probe, then
// the liveness probe, then the readiness probe. When the container declares
// none of them it is left untouched. The chosen probe is deep-copied, so the
// liveness/readiness probe it came from is not modified.
func ensureRestoreStartupProbe(container *corev1.Container) {
	var template *corev1.Probe
	switch {
	case container.StartupProbe != nil:
		template = container.StartupProbe
	case container.LivenessProbe != nil:
		template = container.LivenessProbe
	case container.ReadinessProbe != nil:
		template = container.ReadinessProbe
	default:
		// No probe to derive from; nothing to install.
		return
	}

	probe := template.DeepCopy()
	// Raise the failure threshold to the int32 maximum.
	// NOTE(review): presumably so startup-probe failures never trigger a
	// restart while the restore is still in progress - confirm.
	probe.FailureThreshold = math.MaxInt32
	// A copied readiness probe may carry a success threshold other than 1;
	// the Kubernetes API only accepts 1 for startup probes.
	probe.SuccessThreshold = 1
	container.StartupProbe = probe
}
func ValidateRestorePodSpec( func ValidateRestorePodSpec(
podSpec *corev1.PodSpec, podSpec *corev1.PodSpec,
storage Storage, storage Storage,
......
package protocol package protocol
import ( import (
"math"
"testing" "testing"
appsv1 "k8s.io/api/apps/v1" appsv1 "k8s.io/api/apps/v1"
...@@ -9,6 +10,9 @@ import ( ...@@ -9,6 +10,9 @@ import (
) )
func TestNewRestorePod(t *testing.T) { func TestNewRestorePod(t *testing.T) {
readinessProbe := &corev1.Probe{PeriodSeconds: 7, TimeoutSeconds: 3}
livenessProbe := &corev1.Probe{InitialDelaySeconds: 11}
startupProbe := &corev1.Probe{FailureThreshold: 120}
restorePod := NewRestorePod(&corev1.Pod{ restorePod := NewRestorePod(&corev1.Pod{
ObjectMeta: metav1.ObjectMeta{ ObjectMeta: metav1.ObjectMeta{
Name: "worker", Name: "worker",
...@@ -22,9 +26,9 @@ func TestNewRestorePod(t *testing.T) { ...@@ -22,9 +26,9 @@ func TestNewRestorePod(t *testing.T) {
Image: "test:latest", Image: "test:latest",
Command: []string{"python3", "-m", "dynamo.vllm"}, Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"}, Args: []string{"--model", "Qwen"},
ReadinessProbe: &corev1.Probe{}, ReadinessProbe: readinessProbe.DeepCopy(),
LivenessProbe: &corev1.Probe{}, LivenessProbe: livenessProbe.DeepCopy(),
StartupProbe: &corev1.Probe{}, StartupProbe: startupProbe.DeepCopy(),
}}, }},
}, },
}, PodOptions{ }, PodOptions{
...@@ -60,14 +64,23 @@ func TestNewRestorePod(t *testing.T) { ...@@ -60,14 +64,23 @@ func TestNewRestorePod(t *testing.T) {
if restorePod.Spec.Containers[0].Args != nil { if restorePod.Spec.Containers[0].Args != nil {
t.Fatalf("expected restore args to be cleared: %#v", restorePod.Spec.Containers[0].Args) t.Fatalf("expected restore args to be cleared: %#v", restorePod.Spec.Containers[0].Args)
} }
if restorePod.Spec.Containers[0].ReadinessProbe != nil { if restorePod.Spec.Containers[0].ReadinessProbe == nil {
t.Fatalf("expected readiness probe to be cleared: %#v", restorePod.Spec.Containers[0].ReadinessProbe) t.Fatalf("expected readiness probe to be preserved")
} }
if restorePod.Spec.Containers[0].LivenessProbe != nil { if got := restorePod.Spec.Containers[0].ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
t.Fatalf("expected liveness probe to be cleared: %#v", restorePod.Spec.Containers[0].LivenessProbe) t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
} }
if restorePod.Spec.Containers[0].StartupProbe != nil { if restorePod.Spec.Containers[0].LivenessProbe == nil {
t.Fatalf("expected startup probe to be cleared: %#v", restorePod.Spec.Containers[0].StartupProbe) t.Fatalf("expected liveness probe to be preserved")
}
if got := restorePod.Spec.Containers[0].LivenessProbe.InitialDelaySeconds; got != livenessProbe.InitialDelaySeconds {
t.Fatalf("expected liveness initial delay %d, got %d", livenessProbe.InitialDelaySeconds, got)
}
if restorePod.Spec.Containers[0].StartupProbe == nil {
t.Fatalf("expected startup probe to be preserved")
}
if got := restorePod.Spec.Containers[0].StartupProbe.FailureThreshold; got != math.MaxInt32 {
t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
} }
if restorePod.Spec.SecurityContext == nil || restorePod.Spec.SecurityContext.SeccompProfile == nil { if restorePod.Spec.SecurityContext == nil || restorePod.Spec.SecurityContext.SeccompProfile == nil {
t.Fatalf("expected seccomp profile to be injected: %#v", restorePod.Spec.SecurityContext) t.Fatalf("expected seccomp profile to be injected: %#v", restorePod.Spec.SecurityContext)
...@@ -82,12 +95,15 @@ func TestNewRestorePod(t *testing.T) { ...@@ -82,12 +95,15 @@ func TestNewRestorePod(t *testing.T) {
func TestPrepareRestorePodSpec(t *testing.T) { func TestPrepareRestorePodSpec(t *testing.T) {
podSpec := corev1.PodSpec{} podSpec := corev1.PodSpec{}
readinessProbe := &corev1.Probe{PeriodSeconds: 13, SuccessThreshold: 1}
livenessProbe := &corev1.Probe{TimeoutSeconds: 5}
startupProbe := &corev1.Probe{FailureThreshold: 60}
container := corev1.Container{ container := corev1.Container{
Command: []string{"python3", "-m", "dynamo.vllm"}, Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"}, Args: []string{"--model", "Qwen"},
ReadinessProbe: &corev1.Probe{}, ReadinessProbe: readinessProbe.DeepCopy(),
LivenessProbe: &corev1.Probe{}, LivenessProbe: livenessProbe.DeepCopy(),
StartupProbe: &corev1.Probe{}, StartupProbe: startupProbe.DeepCopy(),
} }
storage := Storage{ storage := Storage{
...@@ -113,8 +129,99 @@ func TestPrepareRestorePodSpec(t *testing.T) { ...@@ -113,8 +129,99 @@ func TestPrepareRestorePodSpec(t *testing.T) {
if container.Args != nil { if container.Args != nil {
t.Fatalf("expected restore args to be cleared: %#v", container.Args) t.Fatalf("expected restore args to be cleared: %#v", container.Args)
} }
if container.ReadinessProbe != nil || container.LivenessProbe != nil || container.StartupProbe != nil { if container.ReadinessProbe == nil {
t.Fatalf("expected probes to be cleared: %#v %#v %#v", container.ReadinessProbe, container.LivenessProbe, container.StartupProbe) t.Fatalf("expected readiness probe to be preserved")
}
if got := container.ReadinessProbe.PeriodSeconds; got != readinessProbe.PeriodSeconds {
t.Fatalf("expected readiness probe period %d, got %d", readinessProbe.PeriodSeconds, got)
}
if container.LivenessProbe == nil {
t.Fatalf("expected liveness probe to be preserved")
}
if got := container.LivenessProbe.TimeoutSeconds; got != livenessProbe.TimeoutSeconds {
t.Fatalf("expected liveness timeout %d, got %d", livenessProbe.TimeoutSeconds, got)
}
if container.StartupProbe == nil {
t.Fatalf("expected startup probe to be preserved")
}
if got := container.StartupProbe.FailureThreshold; got != math.MaxInt32 {
t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
}
if got := container.StartupProbe.SuccessThreshold; got != 1 {
t.Fatalf("expected startup success threshold 1, got %d", got)
}
}
// TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness checks that,
// for a checkpoint-ready container declaring only a liveness probe,
// PrepareRestorePodSpec keeps the liveness probe and creates a startup probe
// with the same HTTP handler, a failure threshold of math.MaxInt32 and a
// success threshold of 1.
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromLiveness(t *testing.T) {
	// Liveness is the only probe configured, so it is the only possible
	// source for the synthesized startup probe.
	container := corev1.Container{
		Command: []string{"python3", "-m", "dynamo.vllm"},
		Args:    []string{"--model", "Qwen"},
		LivenessProbe: &corev1.Probe{
			ProbeHandler: corev1.ProbeHandler{
				HTTPGet: &corev1.HTTPGetAction{Path: "/livez"},
			},
			PeriodSeconds:    5,
			TimeoutSeconds:   4,
			FailureThreshold: 2,
		},
	}
	podSpec := corev1.PodSpec{}

	PrepareRestorePodSpec(&podSpec, &container, Storage{}, "", true)

	if container.LivenessProbe == nil {
		t.Fatalf("expected liveness probe to be preserved")
	}

	startup := container.StartupProbe
	if startup == nil {
		t.Fatalf("expected startup probe to be synthesized")
	}
	if httpGet := startup.HTTPGet; httpGet == nil || httpGet.Path != "/livez" {
		t.Fatalf("expected startup probe HTTP path /livez, got %#v", httpGet)
	}
	if startup.FailureThreshold != math.MaxInt32 {
		t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, startup.FailureThreshold)
	}
	if startup.SuccessThreshold != 1 {
		t.Fatalf("expected startup success threshold 1, got %d", startup.SuccessThreshold)
	}
}
func TestPrepareRestorePodSpecSynthesizesStartupProbeFromReadiness(t *testing.T) {
podSpec := corev1.PodSpec{}
readinessProbe := &corev1.Probe{
ProbeHandler: corev1.ProbeHandler{
Exec: &corev1.ExecAction{Command: []string{"cat", "/tmp/ready"}},
},
PeriodSeconds: 13,
SuccessThreshold: 3,
FailureThreshold: 4,
}
container := corev1.Container{
Command: []string{"python3", "-m", "dynamo.vllm"},
Args: []string{"--model", "Qwen"},
ReadinessProbe: readinessProbe.DeepCopy(),
}
PrepareRestorePodSpec(&podSpec, &container, Storage{}, "", true)
if container.ReadinessProbe == nil {
t.Fatalf("expected readiness probe to be preserved")
}
if got := container.ReadinessProbe.SuccessThreshold; got != readinessProbe.SuccessThreshold {
t.Fatalf("expected readiness success threshold %d, got %d", readinessProbe.SuccessThreshold, got)
}
if container.StartupProbe == nil {
t.Fatalf("expected startup probe to be synthesized")
}
if container.StartupProbe.Exec == nil || len(container.StartupProbe.Exec.Command) != 2 || container.StartupProbe.Exec.Command[0] != "cat" || container.StartupProbe.Exec.Command[1] != "/tmp/ready" {
t.Fatalf("expected startup probe exec command to match readiness probe: %#v", container.StartupProbe.Exec)
}
if got := container.StartupProbe.FailureThreshold; got != math.MaxInt32 {
t.Fatalf("expected startup failure threshold %d, got %d", math.MaxInt32, got)
}
if got := container.StartupProbe.SuccessThreshold; got != 1 {
t.Fatalf("expected startup success threshold 1, got %d", got)
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment