Unverified Commit f6976e7f authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Request Migration disable on max seq len exceeded (#8020)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent eb772381
...@@ -24,6 +24,8 @@ from dynamo.common.configuration.utils import ( ...@@ -24,6 +24,8 @@ from dynamo.common.configuration.utils import (
from . import __version__ from . import __version__
_U32_MAX = 2**32 - 1
def validate_model_name(value: str) -> str: def validate_model_name(value: str) -> str:
"""Validate that model-name is a non-empty string.""" """Validate that model-name is a non-empty string."""
...@@ -60,6 +62,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase): ...@@ -60,6 +62,7 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase):
enforce_disagg: bool enforce_disagg: bool
migration_limit: int migration_limit: int
migration_max_seq_len: Optional[int]
active_decode_blocks_threshold: Optional[float] active_decode_blocks_threshold: Optional[float]
active_prefill_tokens_threshold: Optional[int] active_prefill_tokens_threshold: Optional[int]
active_prefill_tokens_threshold_frac: Optional[float] active_prefill_tokens_threshold_frac: Optional[float]
...@@ -91,9 +94,15 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase): ...@@ -91,9 +94,15 @@ class FrontendConfig(KvRouterConfigBase, AicPerfConfigBase):
raise ValueError( raise ValueError(
"--tls-cert-path and --tls-key-path must be provided together" "--tls-cert-path and --tls-key-path must be provided together"
) )
if self.migration_limit < 0 or self.migration_limit > 4294967295: if self.migration_limit < 0 or self.migration_limit > _U32_MAX:
raise ValueError(
f"--migration-limit must be between 0 and {_U32_MAX} (0=disabled)"
)
if self.migration_max_seq_len is not None and (
self.migration_max_seq_len < 1 or self.migration_max_seq_len > _U32_MAX
):
raise ValueError( raise ValueError(
"--migration-limit must be between 0 and 4294967295 (0=disabled)" f"--migration-max-seq-len must be between 1 and {_U32_MAX}"
) )
if self.min_initial_workers < 0: if self.min_initial_workers < 0:
raise ValueError("--router-min-initial-workers must be >= 0") raise ValueError("--router-min-initial-workers must be >= 0")
...@@ -301,6 +310,20 @@ class FrontendArgGroup(ArgGroup): ...@@ -301,6 +310,20 @@ class FrontendArgGroup(ArgGroup):
arg_type=int, arg_type=int,
) )
add_argument(
g,
flag_name="--migration-max-seq-len",
env_var="DYN_MIGRATION_MAX_SEQ_LEN",
default=None,
help=(
"Maximum sequence length (prompt + generated tokens) for migration state tracking. "
"Once the accumulated token count exceeds this limit, the request becomes "
"non-migratable. Prevents unbounded memory growth from caching long sequences. "
"Default: no limit."
),
arg_type=int,
)
add_argument( add_argument(
g, g,
flag_name="--active-decode-blocks-threshold", flag_name="--active-decode-blocks-threshold",
......
...@@ -184,9 +184,14 @@ async def async_main(): ...@@ -184,9 +184,14 @@ async def async_main():
os.environ["DYN_TOKENIZER"] = "fastokens" os.environ["DYN_TOKENIZER"] = "fastokens"
else: else:
os.environ.pop("DYN_TOKENIZER", None) os.environ.pop("DYN_TOKENIZER", None)
max_seq_info = (
f", max_seq_len: {config.migration_max_seq_len}"
if config.migration_max_seq_len is not None
else ""
)
logger.info( logger.info(
f"Request migration {'enabled' if config.migration_limit > 0 else 'disabled'} " f"Request migration {'enabled' if config.migration_limit > 0 else 'disabled'} "
f"(limit: {config.migration_limit})" f"(limit: {config.migration_limit}{max_seq_info})"
) )
# Warn if DYN_SYSTEM_PORT is set (frontend doesn't use system metrics server) # Warn if DYN_SYSTEM_PORT is set (frontend doesn't use system metrics server)
if os.environ.get("DYN_SYSTEM_PORT"): if os.environ.get("DYN_SYSTEM_PORT"):
...@@ -266,6 +271,8 @@ async def async_main(): ...@@ -266,6 +271,8 @@ async def async_main():
"router_config": router_config, "router_config": router_config,
"migration_limit": config.migration_limit, "migration_limit": config.migration_limit,
} }
if config.migration_max_seq_len is not None:
kwargs["migration_max_seq_len"] = config.migration_max_seq_len
if config.model_name: if config.model_name:
kwargs["model_name"] = config.model_name kwargs["model_name"] = config.model_name
......
...@@ -30,6 +30,16 @@ The migration limit is configured at the **frontend** level and applies globally ...@@ -30,6 +30,16 @@ The migration limit is configured at the **frontend** level and applies globally
- Set via `--migration-limit` flag on the frontend - Set via `--migration-limit` flag on the frontend
- Applies to all models served by the frontend - Applies to all models served by the frontend
### Max Sequence Length Configuration
The max sequence length setting controls how long the migration system will cache token state for a request. Once the total sequence length (prompt + generated tokens) exceeds this limit, migration is disabled for that request and token tracking stops:
- Default behavior: no limit (`--migration-max-seq-len` unset)
- Set via `--migration-max-seq-len` flag or `DYN_MIGRATION_MAX_SEQ_LEN` environment variable on the frontend
- Prevents unbounded memory growth from caching long sequences
- Boundary: exactly at the limit is still migratable; only strictly exceeding it disables migration
- The check runs both at request initialization (prompt length) and during generation (prompt + output tokens)
## Token State Tracking and Request Migration ## Token State Tracking and Request Migration
The core of the migration system is the ability to preserve and continue partial generations through token state management. This ensures that when a worker fails mid-generation, the new worker can seamlessly continue from the exact point of failure. The core of the migration system is the ability to preserve and continue partial generations through token state management. This ensures that when a worker fails mid-generation, the new worker can seamlessly continue from the exact point of failure.
...@@ -118,17 +128,22 @@ The migration system exposes Prometheus metrics to monitor migration activity. T ...@@ -118,17 +128,22 @@ The migration system exposes Prometheus metrics to monitor migration activity. T
- Labels: - Labels:
- `model`: The model name being served - `model`: The model name being served
- `migration_type`: Either `new_request` (initial connection failure) or `ongoing_request` (mid-stream disconnection) - `migration_type`: Either `new_request` (initial connection failure) or `ongoing_request` (mid-stream disconnection)
- `dynamo_frontend_model_migration_max_seq_len_exceeded_total`: Counter tracking the number of times migration was disabled because the sequence length exceeded the configured `--migration-max-seq-len`
- Labels:
- `model`: The model name being served
**Example metrics output:** **Example metrics output:**
``` ```text
dynamo_frontend_model_migration_total{migration_type="ongoing_request",model="Qwen/Qwen3-0.6B"} 3 dynamo_frontend_model_migration_total{migration_type="ongoing_request",model="Qwen/Qwen3-0.6B"} 3
dynamo_frontend_model_migration_total{migration_type="new_request",model="Qwen/Qwen3-0.6B"} 1 dynamo_frontend_model_migration_total{migration_type="new_request",model="Qwen/Qwen3-0.6B"} 1
dynamo_frontend_model_migration_max_seq_len_exceeded_total{model="Qwen/Qwen3-0.6B"} 2
``` ```
These metrics can be used to: These metrics can be used to:
- Monitor worker reliability and failure patterns - Monitor worker reliability and failure patterns
- Alert on excessive migration rates indicating infrastructure issues - Alert on excessive migration rates indicating infrastructure issues
- Track the effectiveness of fault tolerance mechanisms - Track the effectiveness of fault tolerance mechanisms
- Monitor how often `--migration-max-seq-len` is being reached, which may indicate the limit needs adjustment
For more information on Dynamo metrics, see the [Metrics documentation](../observability/metrics.md). For more information on Dynamo metrics, see the [Metrics documentation](../observability/metrics.md).
......
...@@ -324,6 +324,7 @@ pub(crate) struct EntrypointArgs { ...@@ -324,6 +324,7 @@ pub(crate) struct EntrypointArgs {
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
chat_engine_factory: Option<PyEngineFactory>, chat_engine_factory: Option<PyEngineFactory>,
aic_perf_config: Option<AicPerfConfig>, aic_perf_config: Option<AicPerfConfig>,
} }
...@@ -332,7 +333,7 @@ pub(crate) struct EntrypointArgs { ...@@ -332,7 +333,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None, aic_perf_config=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, mocker_engine_args=None, runtime_config=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, migration_max_seq_len=None, chat_engine_factory=None, aic_perf_config=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -355,6 +356,7 @@ impl EntrypointArgs { ...@@ -355,6 +356,7 @@ impl EntrypointArgs {
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
chat_engine_factory: Option<PyObject>, chat_engine_factory: Option<PyObject>,
aic_perf_config: Option<AicPerfConfig>, aic_perf_config: Option<AicPerfConfig>,
) -> PyResult<Self> { ) -> PyResult<Self> {
...@@ -404,6 +406,7 @@ impl EntrypointArgs { ...@@ -404,6 +406,7 @@ impl EntrypointArgs {
namespace_prefix, namespace_prefix,
is_prefill, is_prefill,
migration_limit, migration_limit,
migration_max_seq_len,
chat_engine_factory, chat_engine_factory,
aic_perf_config, aic_perf_config,
}) })
...@@ -438,6 +441,7 @@ pub fn make_engine<'p>( ...@@ -438,6 +441,7 @@ pub fn make_engine<'p>(
.kv_cache_block_size(args.kv_cache_block_size) .kv_cache_block_size(args.kv_cache_block_size)
.router_config(args.router_config.clone().map(|rc| rc.into())) .router_config(args.router_config.clone().map(|rc| rc.into()))
.migration_limit(Some(args.migration_limit)) .migration_limit(Some(args.migration_limit))
.migration_max_seq_len(args.migration_max_seq_len)
.http_host(args.http_host.clone()) .http_host(args.http_host.clone())
.http_port(args.http_port) .http_port(args.http_port)
.http_metrics_port(args.http_metrics_port) .http_metrics_port(args.http_metrics_port)
......
...@@ -72,6 +72,7 @@ pub struct ModelWatcher { ...@@ -72,6 +72,7 @@ pub struct ModelWatcher {
drt: DistributedRuntime, drt: DistributedRuntime,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
...@@ -114,11 +115,13 @@ fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bo ...@@ -114,11 +115,13 @@ fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bo
} }
impl ModelWatcher { impl ModelWatcher {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>, prefill_load_estimator: Option<Arc<dyn PrefillLoadEstimator>>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
...@@ -128,6 +131,7 @@ impl ModelWatcher { ...@@ -128,6 +131,7 @@ impl ModelWatcher {
drt: runtime, drt: runtime,
router_config, router_config,
migration_limit, migration_limit,
migration_max_seq_len,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
chat_engine_factory, chat_engine_factory,
...@@ -577,6 +581,7 @@ impl ModelWatcher { ...@@ -577,6 +581,7 @@ impl ModelWatcher {
prefill_chooser.clone(), prefill_chooser.clone(),
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
self.migration_limit, self.migration_limit,
self.migration_max_seq_len,
self.metrics.clone(), self.metrics.clone(),
) )
.await .await
...@@ -610,6 +615,7 @@ impl ModelWatcher { ...@@ -610,6 +615,7 @@ impl ModelWatcher {
prefill_chooser, prefill_chooser,
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
self.migration_limit, self.migration_limit,
self.migration_max_seq_len,
self.metrics.clone(), self.metrics.clone(),
) )
.await .await
......
...@@ -106,6 +106,7 @@ pub async fn prepare_engine( ...@@ -106,6 +106,7 @@ pub async fn prepare_engine(
model_manager.clone(), model_manager.clone(),
RouterConfig::default(), RouterConfig::default(),
local_model.migration_limit(), local_model.migration_limit(),
local_model.migration_max_seq_len(),
None, None,
prefill_load_estimator, prefill_load_estimator,
metrics, metrics,
...@@ -221,6 +222,7 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -221,6 +222,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
prefill_chooser: Option<Arc<PrefillRouter>>, prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool, enforce_disagg: bool,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
...@@ -250,6 +252,7 @@ where ...@@ -250,6 +252,7 @@ where
prefill_chooser, prefill_chooser,
enforce_disagg, enforce_disagg,
migration_limit, migration_limit,
migration_max_seq_len,
metrics, metrics,
) )
.await .await
...@@ -268,6 +271,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>( ...@@ -268,6 +271,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
prefill_chooser: Option<Arc<PrefillRouter>>, prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool, enforce_disagg: bool,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
...@@ -283,7 +287,8 @@ where ...@@ -283,7 +287,8 @@ where
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new(); let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(tokenizer).into_operator(); let backend = Backend::from_tokenizer(tokenizer).into_operator();
let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator(); let migration =
Migration::from_mdc(card, migration_limit, migration_max_seq_len, metrics).into_operator();
let min_initial_workers = min_initial_workers_from_env()?; let min_initial_workers = min_initial_workers_from_env()?;
// For KV routing, use the client from the chooser to ensure shared state // For KV routing, use the client from the chooser to ensure shared state
......
...@@ -41,6 +41,7 @@ pub async fn run( ...@@ -41,6 +41,7 @@ pub async fn run(
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
let migration_max_seq_len = model.migration_max_seq_len();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace_filter = NamespaceFilter::from_namespace_and_prefix( let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(), model.namespace(),
...@@ -51,6 +52,7 @@ pub async fn run( ...@@ -51,6 +52,7 @@ pub async fn run(
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
migration_max_seq_len,
namespace_filter, namespace_filter,
prefill_load_estimator.clone(), prefill_load_estimator.clone(),
) )
...@@ -115,6 +117,7 @@ async fn run_watcher( ...@@ -115,6 +117,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
namespace_filter: NamespaceFilter, namespace_filter: NamespaceFilter,
prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>, prefill_load_estimator: Option<Arc<dyn dynamo_kv_router::PrefillLoadEstimator>>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
...@@ -125,6 +128,7 @@ async fn run_watcher( ...@@ -125,6 +128,7 @@ async fn run_watcher(
model_manager, model_manager,
router_config, router_config,
migration_limit, migration_limit,
migration_max_seq_len,
None, None,
prefill_load_estimator, prefill_load_estimator,
metrics, metrics,
......
...@@ -76,6 +76,7 @@ pub async fn run( ...@@ -76,6 +76,7 @@ pub async fn run(
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
let migration_max_seq_len = model.migration_max_seq_len();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Create namespace filter from model configuration // Create namespace filter from model configuration
let namespace_filter = NamespaceFilter::from_namespace_and_prefix( let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
...@@ -87,6 +88,7 @@ pub async fn run( ...@@ -87,6 +88,7 @@ pub async fn run(
http_service.state().manager_clone(), http_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
migration_max_seq_len,
namespace_filter, namespace_filter,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
...@@ -165,6 +167,7 @@ async fn run_watcher( ...@@ -165,6 +167,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
namespace_filter: NamespaceFilter, namespace_filter: NamespaceFilter,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
...@@ -176,6 +179,7 @@ async fn run_watcher( ...@@ -176,6 +179,7 @@ async fn run_watcher(
model_manager, model_manager,
router_config, router_config,
migration_limit, migration_limit,
migration_max_seq_len,
chat_engine_factory, chat_engine_factory,
prefill_load_estimator, prefill_load_estimator,
metrics.clone(), metrics.clone(),
......
...@@ -266,6 +266,7 @@ pub struct Metrics { ...@@ -266,6 +266,7 @@ pub struct Metrics {
model_kv_cache_block_size: IntGaugeVec, model_kv_cache_block_size: IntGaugeVec,
model_migration_limit: IntGaugeVec, model_migration_limit: IntGaugeVec,
model_migration_total: IntCounterVec, model_migration_total: IntCounterVec,
model_migration_max_seq_len_exceeded_total: IntCounterVec,
model_cancellation_total: IntCounterVec, model_cancellation_total: IntCounterVec,
model_rejection_total: IntCounterVec, model_rejection_total: IntCounterVec,
} }
...@@ -684,6 +685,15 @@ impl Metrics { ...@@ -684,6 +685,15 @@ impl Metrics {
) )
.unwrap(); .unwrap();
let model_migration_max_seq_len_exceeded_total = IntCounterVec::new(
Opts::new(
frontend_metric_name(frontend_service::MODEL_MIGRATION_MAX_SEQ_LEN_EXCEEDED_TOTAL),
"Total number of times migration was disabled by max_seq_len limit",
),
&["model"],
)
.unwrap();
let model_cancellation_total = IntCounterVec::new( let model_cancellation_total = IntCounterVec::new(
Opts::new( Opts::new(
frontend_metric_name(frontend_service::MODEL_CANCELLATION_TOTAL), frontend_metric_name(frontend_service::MODEL_CANCELLATION_TOTAL),
...@@ -722,6 +732,7 @@ impl Metrics { ...@@ -722,6 +732,7 @@ impl Metrics {
model_kv_cache_block_size, model_kv_cache_block_size,
model_migration_limit, model_migration_limit,
model_migration_total, model_migration_total,
model_migration_max_seq_len_exceeded_total,
model_cancellation_total, model_cancellation_total,
model_rejection_total, model_rejection_total,
} }
...@@ -828,6 +839,9 @@ impl Metrics { ...@@ -828,6 +839,9 @@ impl Metrics {
registry.register(Box::new(self.model_kv_cache_block_size.clone()))?; registry.register(Box::new(self.model_kv_cache_block_size.clone()))?;
registry.register(Box::new(self.model_migration_limit.clone()))?; registry.register(Box::new(self.model_migration_limit.clone()))?;
registry.register(Box::new(self.model_migration_total.clone()))?; registry.register(Box::new(self.model_migration_total.clone()))?;
registry.register(Box::new(
self.model_migration_max_seq_len_exceeded_total.clone(),
))?;
registry.register(Box::new(self.model_cancellation_total.clone()))?; registry.register(Box::new(self.model_cancellation_total.clone()))?;
registry.register(Box::new(self.model_rejection_total.clone()))?; registry.register(Box::new(self.model_rejection_total.clone()))?;
...@@ -913,6 +927,20 @@ impl Metrics { ...@@ -913,6 +927,20 @@ impl Metrics {
.get() .get()
} }
/// Increment the counter for migrations disabled by max_seq_len being exceeded
pub fn inc_migration_max_seq_len_exceeded(&self, model: &str) {
self.model_migration_max_seq_len_exceeded_total
.with_label_values(&[model])
.inc();
}
/// Get the current count of migrations disabled by max_seq_len being exceeded
pub fn get_migration_max_seq_len_exceeded_count(&self, model: &str) -> u64 {
self.model_migration_max_seq_len_exceeded_total
.with_label_values(&[model])
.get()
}
/// Increment the cancellation counter /// Increment the cancellation counter
pub fn inc_cancellation(&self, labels: &CancellationLabels) { pub fn inc_cancellation(&self, labels: &CancellationLabels) {
self.model_cancellation_total self.model_cancellation_total
......
...@@ -48,6 +48,7 @@ pub struct LocalModelBuilder { ...@@ -48,6 +48,7 @@ pub struct LocalModelBuilder {
tls_cert_path: Option<PathBuf>, tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
is_mocker: bool, is_mocker: bool,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
runtime_config: ModelRuntimeConfig, runtime_config: ModelRuntimeConfig,
...@@ -76,6 +77,7 @@ impl Default for LocalModelBuilder { ...@@ -76,6 +77,7 @@ impl Default for LocalModelBuilder {
template_file: Default::default(), template_file: Default::default(),
router_config: Default::default(), router_config: Default::default(),
migration_limit: Default::default(), migration_limit: Default::default(),
migration_max_seq_len: Default::default(),
is_mocker: Default::default(), is_mocker: Default::default(),
extra_engine_args: Default::default(), extra_engine_args: Default::default(),
runtime_config: Default::default(), runtime_config: Default::default(),
...@@ -180,6 +182,11 @@ impl LocalModelBuilder { ...@@ -180,6 +182,11 @@ impl LocalModelBuilder {
self self
} }
pub fn migration_max_seq_len(&mut self, max_seq_len: Option<u32>) -> &mut Self {
self.migration_max_seq_len = max_seq_len;
self
}
pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self { pub fn is_mocker(&mut self, is_mocker: bool) -> &mut Self {
self.is_mocker = is_mocker; self.is_mocker = is_mocker;
self self
...@@ -260,6 +267,7 @@ impl LocalModelBuilder { ...@@ -260,6 +267,7 @@ impl LocalModelBuilder {
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(), namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit, migration_limit: self.migration_limit,
migration_max_seq_len: self.migration_max_seq_len,
}); });
} }
...@@ -313,6 +321,7 @@ impl LocalModelBuilder { ...@@ -313,6 +321,7 @@ impl LocalModelBuilder {
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(), namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit, migration_limit: self.migration_limit,
migration_max_seq_len: self.migration_max_seq_len,
}) })
} }
} }
...@@ -333,6 +342,7 @@ pub struct LocalModel { ...@@ -333,6 +342,7 @@ pub struct LocalModel {
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>, namespace_prefix: Option<String>,
migration_limit: u32, migration_limit: u32,
migration_max_seq_len: Option<u32>,
} }
impl LocalModel { impl LocalModel {
...@@ -400,6 +410,10 @@ impl LocalModel { ...@@ -400,6 +410,10 @@ impl LocalModel {
self.migration_limit self.migration_limit
} }
pub fn migration_max_seq_len(&self) -> Option<u32> {
self.migration_max_seq_len
}
pub fn namespace(&self) -> Option<&str> { pub fn namespace(&self) -> Option<&str> {
self.namespace.as_deref() self.namespace.as_deref()
} }
......
...@@ -33,15 +33,27 @@ fn is_migratable(err: &(dyn StdError + 'static)) -> bool { ...@@ -33,15 +33,27 @@ fn is_migratable(err: &(dyn StdError + 'static)) -> bool {
pub struct Migration { pub struct Migration {
migration_limit: u32, migration_limit: u32,
max_seq_len: Option<u32>,
model_name: Arc<String>, model_name: Arc<String>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
} }
impl Migration { impl Migration {
pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> { pub fn new(
tracing::debug!("model {} migration limit {}", model_name, migration_limit); migration_limit: u32,
max_seq_len: Option<u32>,
model_name: String,
metrics: Arc<Metrics>,
) -> Arc<Self> {
tracing::debug!(
"model {} migration limit {} max_seq_len {:?}",
model_name,
migration_limit,
max_seq_len
);
Arc::new(Self { Arc::new(Self {
migration_limit, migration_limit,
max_seq_len,
model_name: Arc::new(model_name), model_name: Arc::new(model_name),
metrics, metrics,
}) })
...@@ -50,9 +62,15 @@ impl Migration { ...@@ -50,9 +62,15 @@ impl Migration {
pub fn from_mdc( pub fn from_mdc(
mdc: &ModelDeploymentCard, mdc: &ModelDeploymentCard,
migration_limit: u32, migration_limit: u32,
max_seq_len: Option<u32>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> Arc<Self> { ) -> Arc<Self> {
Self::new(migration_limit, mdc.display_name.clone(), metrics) Self::new(
migration_limit,
max_seq_len,
mdc.display_name.clone(),
metrics,
)
} }
} }
...@@ -78,6 +96,7 @@ impl ...@@ -78,6 +96,7 @@ impl
preprocessed_request, preprocessed_request,
next, next,
self.migration_limit, self.migration_limit,
self.max_seq_len,
self.model_name.clone(), self.model_name.clone(),
self.metrics.clone(), self.metrics.clone(),
) )
...@@ -99,6 +118,7 @@ struct RetryManager { ...@@ -99,6 +118,7 @@ struct RetryManager {
next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>, next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
next_stream: Option<ManyOut<Annotated<BackendOutput>>>, next_stream: Option<ManyOut<Annotated<BackendOutput>>>,
retries_left: u32, retries_left: u32,
max_seq_len: Option<u32>,
model_name: Arc<String>, model_name: Arc<String>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
} }
...@@ -109,6 +129,7 @@ impl RetryManager { ...@@ -109,6 +129,7 @@ impl RetryManager {
preprocessed_request: PreprocessedRequest, preprocessed_request: PreprocessedRequest,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>, next: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>>,
mut retries_left: u32, mut retries_left: u32,
max_seq_len: Option<u32>,
model_name: Arc<String>, model_name: Arc<String>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> Result<Self> { ) -> Result<Self> {
...@@ -137,10 +158,12 @@ impl RetryManager { ...@@ -137,10 +158,12 @@ impl RetryManager {
next_generate: next, next_generate: next,
next_stream: None, next_stream: None,
retries_left: retries_left + 1, // +1 to account for the initial attempt retries_left: retries_left + 1, // +1 to account for the initial attempt
max_seq_len,
model_name, model_name,
metrics, metrics,
}; };
slf.new_stream().await?; slf.new_stream().await?;
slf.exceed_max_seq_len(0); // disable migration if prompt len > max_seq_len
Ok(slf) Ok(slf)
} }
...@@ -216,6 +239,9 @@ impl RetryManager { ...@@ -216,6 +239,9 @@ impl RetryManager {
Some(output) => output, Some(output) => output,
None => return, None => return,
}; };
if self.exceed_max_seq_len(llm_engine_output.token_ids.len() as u32) {
return;
}
if let Some(max_tokens) = self.request.stop_conditions.max_tokens { if let Some(max_tokens) = self.request.stop_conditions.max_tokens {
self.request.stop_conditions.max_tokens = self.request.stop_conditions.max_tokens =
Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32)); Some(max_tokens.saturating_sub(llm_engine_output.token_ids.len() as u32));
...@@ -224,6 +250,27 @@ impl RetryManager { ...@@ -224,6 +250,27 @@ impl RetryManager {
self.request.token_ids.push(*token_id); self.request.token_ids.push(*token_id);
} }
} }
/// Returns `true` if the tracked request token length plus `new_output_len`
/// exceeds the configured max_seq_len, in which case migration is disabled.
fn exceed_max_seq_len(&mut self, new_output_len: u32) -> bool {
if let Some(max_seq_len) = self.max_seq_len {
let total_len = self.request.token_ids.len() as u32 + new_output_len;
if total_len > max_seq_len {
tracing::warn!(
"Sequence length {} exceeds migration max_seq_len {}, \
disabling migration",
total_len,
max_seq_len
);
self.metrics
.inc_migration_max_seq_len_exceeded(&self.model_name);
self.retries_left = 0; // disable migration
return true;
}
}
false
}
} }
#[cfg(test)] #[cfg(test)]
...@@ -581,6 +628,7 @@ mod tests { ...@@ -581,6 +628,7 @@ mod tests {
request, request,
next_generate, next_generate,
0, 0,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -631,6 +679,7 @@ mod tests { ...@@ -631,6 +679,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -682,6 +731,7 @@ mod tests { ...@@ -682,6 +731,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -734,6 +784,7 @@ mod tests { ...@@ -734,6 +784,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -773,6 +824,7 @@ mod tests { ...@@ -773,6 +824,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) // 3 retries ) // 3 retries
...@@ -831,6 +883,7 @@ mod tests { ...@@ -831,6 +883,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) // 3 retries ) // 3 retries
...@@ -894,6 +947,7 @@ mod tests { ...@@ -894,6 +947,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, 3,
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -962,6 +1016,7 @@ mod tests { ...@@ -962,6 +1016,7 @@ mod tests {
request, request,
next_generate, next_generate,
3, // migration_limit=3 — should be ignored for guided-decoding requests 3, // migration_limit=3 — should be ignored for guided-decoding requests
None,
Arc::new(TEST_MODEL.to_string()), Arc::new(TEST_MODEL.to_string()),
metrics.clone(), metrics.clone(),
) )
...@@ -1003,4 +1058,211 @@ mod tests { ...@@ -1003,4 +1058,211 @@ mod tests {
"Error type should be Disconnected" "Error type should be Disconnected"
); );
} }
/// Test case 9: max_seq_len exceeded limit + 1 disables migration
///
/// Boundary test: prompt has 3 tokens, max_seq_len = 5. After 2 generated tokens the
/// total is 5 (== max_seq_len) — still migratable. The 3rd generated token would push
/// the total to 6 (> max_seq_len), which disables migration and stops caching.
/// The failure is placed right at that point (fail_after: 3) so we see the error
/// propagated instead of retried.
#[tokio::test]
async fn test_retry_manager_max_seq_len_exceeded() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
// Prompt = [1, 2, 3] (len 3). max_seq_len = 5.
// Token 101 → total 4 ≤ 5: tracked.
// Token 102 → total 5 ≤ 5: tracked.
// Token 103 → would-be 6 > 5: NOT tracked, migration disabled.
// Error follows immediately (fail_after: 3) → not retried.
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Some(5), // prompt(3) + 3 generated = 6 > 5 → disables migration
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// 3 successful tokens + 1 Disconnected error (migration disabled at token 103).
assert_eq!(
responses.len(),
4,
"Expected 3 successful + 1 error (migration disabled by max_seq_len)"
);
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none(), "Response {} should be OK", i);
}
let err = responses[3]
.err()
.expect("Last response should be Disconnected error");
assert_eq!(err.error_type(), ErrorType::Disconnected);
// Migration was attempted but blocked because max_seq_len set retries_left to 0.
// The ongoing metric is still incremented (it counts attempts, not successes).
assert_eq!(metrics.get_migration_new_request_count(TEST_MODEL), 0);
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
// max_seq_len limit was triggered once (at token 103).
assert_eq!(
metrics.get_migration_max_seq_len_exceeded_count(TEST_MODEL),
1
);
}
/// Test case 10: Migration succeeds when sequence length is at max_seq_len
///
/// Boundary test: prompt has 3 tokens, max_seq_len = 5. After 2 generated tokens
/// the total is exactly 5 (== max_seq_len). The failure occurs at that point
/// (fail_after: 2). Because we use strict inequality (> not >=), the request is
/// still migratable and the retry succeeds.
#[tokio::test]
async fn test_retry_manager_max_seq_len_at_limit() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
// Prompt = [1, 2, 3] (len 3). max_seq_len = 5.
// Token 101 → total 4 ≤ 5: tracked.
// Token 102 → total 5 == 5: tracked (still migratable — strict >).
// Error (fail_after: 2) → migration succeeds, retry delivers remaining tokens.
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 2 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Some(5), // prompt(3) + 2 generated = 5 == max_seq_len → still migratable
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// Migration succeeds — all 10 responses delivered
assert_eq!(responses.len(), 10);
for response in &responses {
assert!(response.err().is_none());
}
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
// Tracked token_ids must equal exactly max_seq_len (5).
// The 2 tokens from the first stream were tracked (prompt 3 + gen 2 = 5).
// After migration the retry stream delivers remaining tokens, but the first
// new token would push to 6 > 5, so tracking stops and no more are appended.
assert_eq!(
retry_manager.request.token_ids.len(),
5,
"tracked token_ids should be exactly max_seq_len"
);
// The limit was triggered once (first token of the retry stream exceeded 5).
assert_eq!(
metrics.get_migration_max_seq_len_exceeded_count(TEST_MODEL),
1
);
}
/// Test case 11: Prompt length alone exceeds max_seq_len
///
/// When the prompt tokens already exceed max_seq_len, migration is disabled
/// in RetryManager::build before any tokens are generated. A mid-stream
/// failure should propagate the error without attempting migration.
#[tokio::test]
async fn test_retry_manager_max_seq_len_exceeded_by_prompt() {
dynamo_runtime::logging::init();
let context_id = uuid::Uuid::new_v4().to_string();
// Prompt = [1, 2, 3] (len 3). max_seq_len = 2, so prompt alone exceeds the limit.
let request = create_mock_request(10);
let mock_engine = Arc::new(MockEngine::new(
MockBehavior::MidStreamFail { fail_after: 3 },
10,
100,
context_id.clone(),
));
let next_generate: ServerStreamingEngine<PreprocessedRequest, Annotated<BackendOutput>> =
mock_engine;
let ctx = Arc::new(Controller::new(context_id.clone()));
let metrics = Arc::new(Metrics::new());
let mut retry_manager = RetryManager::build(
ctx,
request,
next_generate,
3,
Some(2), // prompt(3) > max_seq_len(2) → migration disabled at build time
Arc::new(TEST_MODEL.to_string()),
metrics.clone(),
)
.await
.expect("Failed to build RetryManager");
let mut responses = Vec::new();
while let Some(response) = retry_manager.next().await {
responses.push(response);
}
// 3 successful tokens + 1 Disconnected error (migration disabled from the start).
assert_eq!(
responses.len(),
4,
"Expected 3 successful + 1 error (migration disabled by prompt exceeding max_seq_len)"
);
for (i, response) in responses[0..3].iter().enumerate() {
assert!(response.err().is_none(), "Response {} should be OK", i);
}
let err = responses[3]
.err()
.expect("Last response should be Disconnected error");
assert_eq!(err.error_type(), ErrorType::Disconnected);
// max_seq_len was exceeded at build time (prompt len 3 > 2).
assert_eq!(
metrics.get_migration_max_seq_len_exceeded_count(TEST_MODEL),
1
);
// Migration was attempted but blocked (retries_left was already 0).
assert_eq!(metrics.get_migration_ongoing_request_count(TEST_MODEL), 1);
}
} }
...@@ -343,6 +343,7 @@ mod integration_tests { ...@@ -343,6 +343,7 @@ mod integration_tests {
service.state().manager_clone(), service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit 0, // migration_limit
None, // migration_max_seq_len
None, None,
None, None,
service.state().metrics_clone(), service.state().metrics_clone(),
......
...@@ -232,6 +232,11 @@ pub mod frontend_service { ...@@ -232,6 +232,11 @@ pub mod frontend_service {
/// Total number of request migrations due to worker unavailability /// Total number of request migrations due to worker unavailability
pub const MODEL_MIGRATION_TOTAL: &str = "model_migration_total"; pub const MODEL_MIGRATION_TOTAL: &str = "model_migration_total";
/// Total number of times migration was disabled because the sequence length
/// exceeded the configured max_seq_len limit
pub const MODEL_MIGRATION_MAX_SEQ_LEN_EXCEEDED_TOTAL: &str =
"model_migration_max_seq_len_exceeded_total";
/// Total number of request cancellations /// Total number of request cancellations
pub const MODEL_CANCELLATION_TOTAL: &str = "model_cancellation_total"; pub const MODEL_CANCELLATION_TOTAL: &str = "model_cancellation_total";
......
...@@ -34,6 +34,14 @@ pytestmark = [ ...@@ -34,6 +34,14 @@ pytestmark = [
pytest.mark.parametrize( pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"] "migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
), ),
pytest.mark.parametrize(
"migration_max_seq_len",
[
pytest.param(None, id="max_seq_len_disabled"),
pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
pytest.param(1, id="max_seq_len_exceeded"),
],
),
pytest.mark.parametrize( pytest.mark.parametrize(
"immediate_kill", "immediate_kill",
[ [
...@@ -217,6 +225,7 @@ def test_request_migration_sglang_aggregated( ...@@ -217,6 +225,7 @@ def test_request_migration_sglang_aggregated(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -227,12 +236,17 @@ def test_request_migration_sglang_aggregated( ...@@ -227,12 +236,17 @@ def test_request_migration_sglang_aggregated(
Parameters: Parameters:
immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
migration_limit: > 0 to verify migration succeeds, 0 to verify request fails migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
migration_max_seq_len: Max sequence length for migration state tracking
request_api: "chat" for chat completion API, "completion" for completion API request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
...@@ -253,6 +267,7 @@ def test_request_migration_sglang_aggregated( ...@@ -253,6 +267,7 @@ def test_request_migration_sglang_aggregated(
worker2, worker2,
receiving_pattern="New Request ID: ", receiving_pattern="New Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -269,6 +284,7 @@ def test_request_migration_sglang_prefill( ...@@ -269,6 +284,7 @@ def test_request_migration_sglang_prefill(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -286,7 +302,11 @@ def test_request_migration_sglang_prefill( ...@@ -286,7 +302,11 @@ def test_request_migration_sglang_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -322,6 +342,7 @@ def test_request_migration_sglang_prefill( ...@@ -322,6 +342,7 @@ def test_request_migration_sglang_prefill(
prefill2, prefill2,
receiving_pattern="New Request ID: ", receiving_pattern="New Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -338,6 +359,7 @@ def test_request_migration_sglang_kv_transfer( ...@@ -338,6 +359,7 @@ def test_request_migration_sglang_kv_transfer(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -355,7 +377,11 @@ def test_request_migration_sglang_kv_transfer( ...@@ -355,7 +377,11 @@ def test_request_migration_sglang_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -391,6 +417,7 @@ def test_request_migration_sglang_kv_transfer( ...@@ -391,6 +417,7 @@ def test_request_migration_sglang_kv_transfer(
decode2, decode2,
receiving_pattern="New Request ID: ", receiving_pattern="New Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -406,6 +433,7 @@ def test_request_migration_sglang_decode( ...@@ -406,6 +433,7 @@ def test_request_migration_sglang_decode(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -427,7 +455,11 @@ def test_request_migration_sglang_decode( ...@@ -427,7 +455,11 @@ def test_request_migration_sglang_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -463,6 +495,7 @@ def test_request_migration_sglang_decode( ...@@ -463,6 +495,7 @@ def test_request_migration_sglang_decode(
decode2, decode2,
receiving_pattern="New Request ID: ", receiving_pattern="New Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
......
...@@ -34,6 +34,14 @@ pytestmark = [ ...@@ -34,6 +34,14 @@ pytestmark = [
pytest.mark.parametrize( pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"] "migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
), ),
pytest.mark.parametrize(
"migration_max_seq_len",
[
pytest.param(None, id="max_seq_len_disabled"),
pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
pytest.param(1, id="max_seq_len_exceeded"),
],
),
pytest.mark.parametrize( pytest.mark.parametrize(
"immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"] "immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"]
), ),
...@@ -195,6 +203,7 @@ def test_request_migration_trtllm_aggregated( ...@@ -195,6 +203,7 @@ def test_request_migration_trtllm_aggregated(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -205,12 +214,17 @@ def test_request_migration_trtllm_aggregated( ...@@ -205,12 +214,17 @@ def test_request_migration_trtllm_aggregated(
Parameters: Parameters:
immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
migration_limit: > 0 to verify migration succeeds, 0 to verify request fails migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
migration_max_seq_len: Max sequence length for migration state tracking
request_api: "chat" for chat completion API, "completion" for completion API request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
...@@ -231,6 +245,7 @@ def test_request_migration_trtllm_aggregated( ...@@ -231,6 +245,7 @@ def test_request_migration_trtllm_aggregated(
worker2, worker2,
receiving_pattern="AggregatedHandler Request ID: ", receiving_pattern="AggregatedHandler Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -248,6 +263,7 @@ def test_request_migration_trtllm_prefill( ...@@ -248,6 +263,7 @@ def test_request_migration_trtllm_prefill(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -265,7 +281,11 @@ def test_request_migration_trtllm_prefill( ...@@ -265,7 +281,11 @@ def test_request_migration_trtllm_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -301,6 +321,7 @@ def test_request_migration_trtllm_prefill( ...@@ -301,6 +321,7 @@ def test_request_migration_trtllm_prefill(
prefill2, prefill2,
receiving_pattern="Prefill Request ID: ", receiving_pattern="Prefill Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -317,6 +338,7 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -317,6 +338,7 @@ def test_request_migration_trtllm_kv_transfer(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -334,7 +356,11 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -334,7 +356,11 @@ def test_request_migration_trtllm_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -370,6 +396,7 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -370,6 +396,7 @@ def test_request_migration_trtllm_kv_transfer(
decode2, decode2,
receiving_pattern="Decode Request ID: ", receiving_pattern="Decode Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -385,6 +412,7 @@ def test_request_migration_trtllm_decode( ...@@ -385,6 +412,7 @@ def test_request_migration_trtllm_decode(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -406,7 +434,11 @@ def test_request_migration_trtllm_decode( ...@@ -406,7 +434,11 @@ def test_request_migration_trtllm_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -442,6 +474,7 @@ def test_request_migration_trtllm_decode( ...@@ -442,6 +474,7 @@ def test_request_migration_trtllm_decode(
decode2, decode2,
receiving_pattern="Decode Request ID: ", receiving_pattern="Decode Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
......
...@@ -35,6 +35,14 @@ pytestmark = [ ...@@ -35,6 +35,14 @@ pytestmark = [
pytest.mark.parametrize( pytest.mark.parametrize(
"migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"] "migration_limit", [3, 0], ids=["migration_enabled", "migration_disabled"]
), ),
pytest.mark.parametrize(
"migration_max_seq_len",
[
pytest.param(None, id="max_seq_len_disabled"),
pytest.param(1_000_000, id="max_seq_len_not_exceeded"),
pytest.param(1, id="max_seq_len_exceeded"),
],
),
pytest.mark.parametrize( pytest.mark.parametrize(
"immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"] "immediate_kill", [True, False], ids=["worker_failure", "graceful_shutdown"]
), ),
...@@ -214,6 +222,7 @@ def test_request_migration_vllm_aggregated( ...@@ -214,6 +222,7 @@ def test_request_migration_vllm_aggregated(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -224,12 +233,17 @@ def test_request_migration_vllm_aggregated( ...@@ -224,12 +233,17 @@ def test_request_migration_vllm_aggregated(
Parameters: Parameters:
immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM) immediate_kill: True for abrupt kill (SIGKILL), False for graceful shutdown (SIGTERM)
migration_limit: > 0 to verify migration succeeds, 0 to verify request fails migration_limit: > 0 to verify migration succeeds, 0 to verify request fails
migration_max_seq_len: Max sequence length for migration state tracking
request_api: "chat" for chat completion API, "completion" for completion API request_api: "chat" for chat completion API, "completion" for completion API
stream: True for streaming, False for non-streaming stream: True for streaming, False for non-streaming
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
...@@ -250,6 +264,7 @@ def test_request_migration_vllm_aggregated( ...@@ -250,6 +264,7 @@ def test_request_migration_vllm_aggregated(
worker2, worker2,
receiving_pattern="Decode Request ID: ", receiving_pattern="Decode Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -265,6 +280,7 @@ def test_request_migration_vllm_prefill( ...@@ -265,6 +280,7 @@ def test_request_migration_vllm_prefill(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -282,7 +298,11 @@ def test_request_migration_vllm_prefill( ...@@ -282,7 +298,11 @@ def test_request_migration_vllm_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -318,6 +338,7 @@ def test_request_migration_vllm_prefill( ...@@ -318,6 +338,7 @@ def test_request_migration_vllm_prefill(
prefill2, prefill2,
receiving_pattern="Prefill Request ID: ", receiving_pattern="Prefill Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -343,6 +364,7 @@ def test_request_migration_vllm_kv_transfer( ...@@ -343,6 +364,7 @@ def test_request_migration_vllm_kv_transfer(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -360,7 +382,11 @@ def test_request_migration_vllm_kv_transfer( ...@@ -360,7 +382,11 @@ def test_request_migration_vllm_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -396,6 +422,7 @@ def test_request_migration_vllm_kv_transfer( ...@@ -396,6 +422,7 @@ def test_request_migration_vllm_kv_transfer(
decode2, decode2,
receiving_pattern="Decode Request ID: ", receiving_pattern="Decode Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
...@@ -421,6 +448,7 @@ def test_request_migration_vllm_decode( ...@@ -421,6 +448,7 @@ def test_request_migration_vllm_decode(
set_ucx_tls_no_mm, set_ucx_tls_no_mm,
predownload_models, predownload_models,
migration_limit, migration_limit,
migration_max_seq_len,
immediate_kill, immediate_kill,
request_api, request_api,
stream, stream,
...@@ -442,7 +470,11 @@ def test_request_migration_vllm_decode( ...@@ -442,7 +470,11 @@ def test_request_migration_vllm_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend: with DynamoFrontendProcess(
request,
migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -478,6 +510,7 @@ def test_request_migration_vllm_decode( ...@@ -478,6 +510,7 @@ def test_request_migration_vllm_decode(
decode2, decode2,
receiving_pattern="Decode Request ID: ", receiving_pattern="Decode Request ID: ",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
immediate_kill=immediate_kill, immediate_kill=immediate_kill,
use_chat_completion=(request_api == "chat"), use_chat_completion=(request_api == "chat"),
stream=stream, stream=stream,
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import json import json
import logging import logging
import re
import threading import threading
import time import time
...@@ -21,7 +22,12 @@ logger = logging.getLogger(__name__) ...@@ -21,7 +22,12 @@ logger = logging.getLogger(__name__)
class DynamoFrontendProcess(BaseDynamoFrontendProcess): class DynamoFrontendProcess(BaseDynamoFrontendProcess):
"""Fault-tolerance frontend wrapper (keeps env settings from the historical helper).""" """Fault-tolerance frontend wrapper (keeps env settings from the historical helper)."""
def __init__(self, request, migration_limit: int): def __init__(
self,
request,
migration_limit: int,
migration_max_seq_len: int | None,
):
extra_env = { extra_env = {
"DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"), "DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"),
# These tests expect full control over requests sent to workers. The canary # These tests expect full control over requests sent to workers. The canary
...@@ -33,6 +39,7 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess): ...@@ -33,6 +39,7 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess):
frontend_port=0, # allocate a free port (xdist-safe) frontend_port=0, # allocate a free port (xdist-safe)
router_mode="round-robin", router_mode="round-robin",
migration_limit=migration_limit, migration_limit=migration_limit,
migration_max_seq_len=migration_max_seq_len,
extra_env=extra_env, extra_env=extra_env,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
display_name="frontend", display_name="frontend",
...@@ -464,8 +471,6 @@ def _parse_migration_metric( ...@@ -464,8 +471,6 @@ def _parse_migration_metric(
Returns: Returns:
The metric count, or 0 if not found The metric count, or 0 if not found
""" """
import re
# Match pattern like: # Match pattern like:
# dynamo_frontend_model_migration_total{migration_type="ongoing_request",model="Qwen/Qwen3-0.6B"} 1 # dynamo_frontend_model_migration_total{migration_type="ongoing_request",model="Qwen/Qwen3-0.6B"} 1
# Labels can be in any order # Labels can be in any order
...@@ -485,10 +490,25 @@ def _parse_migration_metric( ...@@ -485,10 +490,25 @@ def _parse_migration_metric(
return 0 return 0
def _parse_migration_max_seq_len_exceeded_metric(
metrics_text: str, model_name: str
) -> int:
"""
Parse the migration max_seq_len exceeded counter from Prometheus metrics text.
Returns:
The metric count, or 0 if not found
"""
pattern = rf'dynamo_frontend_model_migration_max_seq_len_exceeded_total\{{[^}}]*model="{re.escape(model_name)}"[^}}]*\}}\s+(\d+)'
match = re.search(pattern, metrics_text)
return int(match.group(1)) if match else 0
def verify_migration_metrics( def verify_migration_metrics(
frontend_port: int, frontend_port: int,
expected_ongoing_request_count: int = 0, expected_ongoing_request_count: int = 0,
expected_new_request_count: int = 0, expected_new_request_count: int = 0,
expected_max_seq_len_exceeded_count: int = 0,
) -> None: ) -> None:
""" """
Verify migration metrics by querying the frontend's /metrics endpoint. Verify migration metrics by querying the frontend's /metrics endpoint.
...@@ -497,6 +517,7 @@ def verify_migration_metrics( ...@@ -497,6 +517,7 @@ def verify_migration_metrics(
frontend_port: Port where the frontend is running frontend_port: Port where the frontend is running
expected_ongoing_request_count: Expected count of ongoing_request migrations expected_ongoing_request_count: Expected count of ongoing_request migrations
expected_new_request_count: Expected count of new_request migrations expected_new_request_count: Expected count of new_request migrations
expected_max_seq_len_exceeded_count: Expected count of max_seq_len exceeded events
""" """
metrics_url = f"http://localhost:{frontend_port}/metrics" metrics_url = f"http://localhost:{frontend_port}/metrics"
...@@ -516,9 +537,14 @@ def verify_migration_metrics( ...@@ -516,9 +537,14 @@ def verify_migration_metrics(
new_request_count = _parse_migration_metric( new_request_count = _parse_migration_metric(
metrics_text, FAULT_TOLERANCE_MODEL_NAME, "new_request" metrics_text, FAULT_TOLERANCE_MODEL_NAME, "new_request"
) )
max_seq_len_exceeded_count = _parse_migration_max_seq_len_exceeded_metric(
metrics_text, FAULT_TOLERANCE_MODEL_NAME
)
logger.info( logger.info(
f"Migration metrics - ongoing_request: {ongoing_count}, new_request: {new_request_count}" f"Migration metrics - ongoing_request: {ongoing_count}, "
f"new_request: {new_request_count}, "
f"max_seq_len_exceeded: {max_seq_len_exceeded_count}"
) )
if expected_ongoing_request_count > 0: if expected_ongoing_request_count > 0:
...@@ -533,6 +559,11 @@ def verify_migration_metrics( ...@@ -533,6 +559,11 @@ def verify_migration_metrics(
f"but got {new_request_count}" f"but got {new_request_count}"
) )
assert max_seq_len_exceeded_count == expected_max_seq_len_exceeded_count, (
f"Expected {expected_max_seq_len_exceeded_count} "
f"max_seq_len_exceeded events, but got {max_seq_len_exceeded_count}"
)
def run_migration_test( def run_migration_test(
frontend: DynamoFrontendProcess, frontend: DynamoFrontendProcess,
...@@ -540,6 +571,7 @@ def run_migration_test( ...@@ -540,6 +571,7 @@ def run_migration_test(
worker2: ManagedProcess, worker2: ManagedProcess,
receiving_pattern: str, receiving_pattern: str,
migration_limit: int, migration_limit: int,
migration_max_seq_len: int | None,
immediate_kill: bool, immediate_kill: bool,
use_chat_completion: bool, use_chat_completion: bool,
stream: bool, stream: bool,
...@@ -555,6 +587,7 @@ def run_migration_test( ...@@ -555,6 +587,7 @@ def run_migration_test(
worker2: Second worker process worker2: Second worker process
receiving_pattern: Log pattern to identify which worker received the request receiving_pattern: Log pattern to identify which worker received the request
migration_limit: Migration limit setting (0 = disabled) migration_limit: Migration limit setting (0 = disabled)
migration_max_seq_len: Max sequence length for migration (None = no limit)
immediate_kill: True for immediate kill, False for graceful shutdown immediate_kill: True for immediate kill, False for graceful shutdown
use_chat_completion: Whether to use chat completion API (True) or completion API (False) use_chat_completion: Whether to use chat completion API (True) or completion API (False)
stream: Whether to use streaming responses stream: Whether to use streaming responses
...@@ -590,19 +623,18 @@ def run_migration_test( ...@@ -590,19 +623,18 @@ def run_migration_test(
) )
terminate_process_tree(worker.get_pid(), immediate_kill=False, timeout=10) terminate_process_tree(worker.get_pid(), immediate_kill=False, timeout=10)
# Step 5: Validate response based on migration setting # Step 5: Validate response and verify migration occurred.
if migration_limit > 0: # migration_enabled and not max_seq_len_exceeded -> migration should succeed
if migration_limit > 0 and migration_max_seq_len != 1:
validate_response(request_thread, response_list, validate_delay=stream) validate_response(request_thread, response_list, validate_delay=stream)
verify_migration_occurred(frontend) verify_migration_occurred(frontend)
verify_migration_metrics(
frontend.frontend_port, expected_ongoing_request_count=1
)
else: else:
try: try:
validate_response(request_thread, response_list, validate_delay=stream) validate_response(request_thread, response_list, validate_delay=stream)
pytest.fail("Request succeeded unexpectedly when migration was disabled") pytest.fail(
"Request succeeded unexpectedly when migration should have failed"
)
except Exception as e: except Exception as e:
# Request failed as expected - verify it's a known error type
error_str = str(e) error_str = str(e)
assert ( assert (
"SSE error event received:" in error_str "SSE error event received:" in error_str
...@@ -611,6 +643,13 @@ def run_migration_test( ...@@ -611,6 +643,13 @@ def run_migration_test(
try: try:
verify_migration_occurred(frontend) verify_migration_occurred(frontend)
pytest.fail("Migration unexpectedly occurred when disabled") pytest.fail("Migration unexpectedly succeeded")
except AssertionError as e: except AssertionError as e:
assert "'Cannot recreate stream: ...' error found in logs" in str(e) assert "'Cannot recreate stream: ...' error found in logs" in str(e)
# Step 6: Verify migration metrics
verify_migration_metrics(
frontend.frontend_port,
expected_ongoing_request_count=1 if migration_limit > 0 else 0,
expected_max_seq_len_exceeded_count=1 if migration_max_seq_len == 1 else 0,
)
...@@ -913,6 +913,7 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -913,6 +913,7 @@ class DynamoFrontendProcess(ManagedProcess):
frontend_port: Optional[int] = None, frontend_port: Optional[int] = None,
router_mode: str = "round-robin", router_mode: str = "round-robin",
migration_limit: int = 0, migration_limit: int = 0,
migration_max_seq_len: Optional[int] = None,
extra_args: Optional[list[str]] = None, extra_args: Optional[list[str]] = None,
extra_env: Optional[dict[str, str]] = None, extra_env: Optional[dict[str, str]] = None,
# Default to false so pytest-xdist workers don't kill each other's frontends. # Default to false so pytest-xdist workers don't kill each other's frontends.
...@@ -944,6 +945,8 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -944,6 +945,8 @@ class DynamoFrontendProcess(ManagedProcess):
command.extend(["--http-port", str(frontend_port)]) command.extend(["--http-port", str(frontend_port)])
# Migration limit is configured at the frontend level # Migration limit is configured at the frontend level
command.extend(["--migration-limit", str(migration_limit)]) command.extend(["--migration-limit", str(migration_limit)])
if migration_max_seq_len is not None:
command.extend(["--migration-max-seq-len", str(migration_max_seq_len)])
if extra_args: if extra_args:
command.extend(extra_args) command.extend(extra_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment