Unverified Commit 1ffa489e authored by Jacky's avatar Jacky Committed by GitHub
Browse files

refactor: Move --migration-limit flag from backend to frontend (#5918)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 3842b244
...@@ -58,6 +58,7 @@ pub struct ModelWatcher { ...@@ -58,6 +58,7 @@ pub struct ModelWatcher {
manager: Arc<ModelManager>, manager: Arc<ModelManager>,
drt: DistributedRuntime, drt: DistributedRuntime,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32,
notify_on_model: Notify, notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>, engine_factory: Option<EngineFactoryCallback>,
...@@ -78,6 +79,7 @@ impl ModelWatcher { ...@@ -78,6 +79,7 @@ impl ModelWatcher {
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32,
engine_factory: Option<EngineFactoryCallback>, engine_factory: Option<EngineFactoryCallback>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> ModelWatcher { ) -> ModelWatcher {
...@@ -85,6 +87,7 @@ impl ModelWatcher { ...@@ -85,6 +87,7 @@ impl ModelWatcher {
manager: model_manager, manager: model_manager,
drt: runtime, drt: runtime,
router_config, router_config,
migration_limit,
notify_on_model: Notify::new(), notify_on_model: Notify::new(),
model_update_tx: None, model_update_tx: None,
engine_factory, engine_factory,
...@@ -494,6 +497,7 @@ impl ModelWatcher { ...@@ -494,6 +497,7 @@ impl ModelWatcher {
tokenizer_hf.clone(), tokenizer_hf.clone(),
prefill_chooser.clone(), prefill_chooser.clone(),
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
self.migration_limit,
self.metrics.clone(), self.metrics.clone(),
) )
.await .await
...@@ -529,6 +533,7 @@ impl ModelWatcher { ...@@ -529,6 +533,7 @@ impl ModelWatcher {
tokenizer_hf, tokenizer_hf,
prefill_chooser, prefill_chooser,
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
self.migration_limit,
self.metrics.clone(), self.metrics.clone(),
) )
.await .await
......
...@@ -70,6 +70,7 @@ pub async fn prepare_engine( ...@@ -70,6 +70,7 @@ pub async fn prepare_engine(
distributed_runtime.clone(), distributed_runtime.clone(),
model_manager.clone(), model_manager.clone(),
RouterConfig::default(), RouterConfig::default(),
local_model.migration_limit(),
None, None,
metrics, metrics,
)); ));
...@@ -180,6 +181,7 @@ pub async fn build_routed_pipeline<Req, Resp>( ...@@ -180,6 +181,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>, prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool, enforce_disagg: bool,
migration_limit: u32,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
...@@ -208,6 +210,7 @@ where ...@@ -208,6 +210,7 @@ where
hf_tokenizer, hf_tokenizer,
prefill_chooser, prefill_chooser,
enforce_disagg, enforce_disagg,
migration_limit,
metrics, metrics,
) )
.await .await
...@@ -225,6 +228,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>( ...@@ -225,6 +228,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>, prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool, enforce_disagg: bool,
migration_limit: u32,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>> ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where where
...@@ -240,7 +244,7 @@ where ...@@ -240,7 +244,7 @@ where
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new(); let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator(); let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator(); let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card, metrics).into_operator(); let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
// For KV routing, use the client from the chooser to ensure shared state // For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV { let router_client = if router_mode == RouterMode::KV {
......
...@@ -35,6 +35,7 @@ pub async fn run( ...@@ -35,6 +35,7 @@ pub async fn run(
EngineConfig::Dynamic { ref model, .. } => { EngineConfig::Dynamic { ref model, .. } => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or(""); let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) { let target_namespace = if is_global_namespace(namespace) {
...@@ -46,6 +47,7 @@ pub async fn run( ...@@ -46,6 +47,7 @@ pub async fn run(
distributed_runtime.clone(), distributed_runtime.clone(),
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit,
target_namespace, target_namespace,
) )
.await?; .await?;
...@@ -109,11 +111,19 @@ async fn run_watcher( ...@@ -109,11 +111,19 @@ async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>, target_namespace: Option<String>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode) // Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new()); let metrics = Arc::new(Metrics::new());
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config, None, metrics); let watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_config,
migration_limit,
None,
metrics,
);
tracing::debug!("Waiting for remote model"); tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery(); let discovery = runtime.discovery();
let discovery_stream = discovery let discovery_stream = discovery
......
...@@ -61,6 +61,7 @@ pub async fn run( ...@@ -61,6 +61,7 @@ pub async fn run(
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace) // Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set // Get namespace from the model, fallback to endpoint_id namespace if not set
...@@ -74,6 +75,7 @@ pub async fn run( ...@@ -74,6 +75,7 @@ pub async fn run(
distributed_runtime.clone(), distributed_runtime.clone(),
http_service.state().manager_clone(), http_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit,
target_namespace, target_namespace,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
...@@ -146,10 +148,12 @@ pub async fn run( ...@@ -146,10 +148,12 @@ pub async fn run(
/// Spawns a task that watches for new models in store, /// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them. /// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher( async fn run_watcher(
runtime: DistributedRuntime, runtime: DistributedRuntime,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>, target_namespace: Option<String>,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
...@@ -159,6 +163,7 @@ async fn run_watcher( ...@@ -159,6 +163,7 @@ async fn run_watcher(
runtime.clone(), runtime.clone(),
model_manager, model_manager,
router_config, router_config,
migration_limit,
engine_factory, engine_factory,
metrics.clone(), metrics.clone(),
); );
......
...@@ -290,6 +290,7 @@ impl LocalModelBuilder { ...@@ -290,6 +290,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
migration_limit: self.migration_limit,
}); });
} }
...@@ -341,6 +342,7 @@ impl LocalModelBuilder { ...@@ -341,6 +342,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
migration_limit: self.migration_limit,
}) })
} }
} }
...@@ -359,6 +361,7 @@ pub struct LocalModel { ...@@ -359,6 +361,7 @@ pub struct LocalModel {
router_config: RouterConfig, router_config: RouterConfig,
runtime_config: ModelRuntimeConfig, runtime_config: ModelRuntimeConfig,
namespace: Option<String>, namespace: Option<String>,
migration_limit: u32,
} }
impl LocalModel { impl LocalModel {
...@@ -422,6 +425,10 @@ impl LocalModel { ...@@ -422,6 +425,10 @@ impl LocalModel {
&self.runtime_config &self.runtime_config
} }
pub fn migration_limit(&self) -> u32 {
self.migration_limit
}
pub fn namespace(&self) -> Option<&str> { pub fn namespace(&self) -> Option<&str> {
self.namespace.as_deref() self.namespace.as_deref()
} }
......
...@@ -28,18 +28,22 @@ pub struct Migration { ...@@ -28,18 +28,22 @@ pub struct Migration {
} }
impl Migration { impl Migration {
pub fn from_mdc(mdc: &ModelDeploymentCard, metrics: Arc<Metrics>) -> Arc<Self> { pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> {
tracing::debug!( tracing::debug!("model {} migration limit {}", model_name, migration_limit);
"model {} migration limit {}",
mdc.display_name,
mdc.migration_limit
);
Arc::new(Self { Arc::new(Self {
migration_limit: mdc.migration_limit, migration_limit,
model_name: Arc::new(mdc.display_name.clone()), model_name: Arc::new(model_name),
metrics, metrics,
}) })
} }
pub fn from_mdc(
mdc: &ModelDeploymentCard,
migration_limit: u32,
metrics: Arc<Metrics>,
) -> Arc<Self> {
Self::new(migration_limit, mdc.display_name.clone(), metrics)
}
} }
#[async_trait] #[async_trait]
......
...@@ -339,6 +339,7 @@ mod integration_tests { ...@@ -339,6 +339,7 @@ mod integration_tests {
distributed_runtime.clone(), distributed_runtime.clone(),
service.state().manager_clone(), service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit
None, None,
service.state().metrics_clone(), service.state().metrics_clone(),
); );
...@@ -512,6 +513,7 @@ mod integration_tests { ...@@ -512,6 +513,7 @@ mod integration_tests {
distributed_runtime.clone(), distributed_runtime.clone(),
service.state().manager_clone(), service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(), dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit
None, None,
service.state().metrics_clone(), service.state().metrics_clone(),
); );
......
...@@ -62,8 +62,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -62,8 +62,6 @@ class DynamoWorkerProcess(ManagedProcess):
system_port = allocate_port(9100) system_port = allocate_port(9100)
self.system_port = system_port self.system_port = system_port
self.frontend_port = frontend_port self.frontend_port = frontend_port
# Prefill workers require migration_limit=0 (no KV cache migration support)
migration_limit = "0" if mode == "prefill" else "3"
command = [ command = [
"python3", "python3",
...@@ -77,8 +75,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -77,8 +75,6 @@ class DynamoWorkerProcess(ManagedProcess):
"16384", "16384",
"--max-num-tokens", "--max-num-tokens",
"16384", "16384",
"--migration-limit",
migration_limit,
] ]
if mode != "prefill_and_decode": if mode != "prefill_and_decode":
with open("test_request_cancellation_trtllm_config.yaml", "w") as f: with open("test_request_cancellation_trtllm_config.yaml", "w") as f:
......
...@@ -63,8 +63,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -63,8 +63,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.45", "0.45",
"--max-model-len", "--max-model-len",
"16384", "16384",
"--migration-limit",
"3",
] ]
# Configure health check based on worker type # Configure health check based on worker type
......
...@@ -40,9 +40,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -40,9 +40,6 @@ class DynamoWorkerProcess(ManagedProcess):
etcd_endpoints: List of ETCD endpoints for HA etcd_endpoints: List of ETCD endpoints for HA
mode: One of "prefill_and_decode", "prefill", "decode" mode: One of "prefill_and_decode", "prefill", "decode"
""" """
# Prefill workers require migration_limit=0 (no KV cache migration support)
migration_limit = "0" if mode == "prefill" else "3"
command = [ command = [
"python3", "python3",
"-m", "-m",
...@@ -55,8 +52,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -55,8 +52,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.45", "0.45",
"--max-seq-len", "--max-seq-len",
"8192", "8192",
"--migration-limit",
migration_limit,
] ]
# Add disaggregation-specific configuration # Add disaggregation-specific configuration
......
...@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "worker2") worker_id: Unique identifier for the worker (e.g., "worker1", "worker2")
frontend_port: Port where the frontend is running frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated
""" """
...@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess):
request, request,
worker_id: str, worker_id: str,
frontend_port: int, frontend_port: int,
migration_limit: int = 3,
disagg_mode: str | None = None, disagg_mode: str | None = None,
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.system_port = allocate_port(9100) self.system_port = allocate_port(9100)
self.disagg_mode = disagg_mode self.disagg_mode = disagg_mode
# Prefill workers require migration_limit=0 (no KV cache migration support)
if disagg_mode == "prefill":
logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [ command = [
"python3", "python3",
"-m", "-m",
...@@ -120,8 +113,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -120,8 +113,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.3", "0.3",
"--context-length", "--context-length",
"8192", "8192",
"--migration-limit",
str(migration_limit),
] ]
if disagg_mode is None: if disagg_mode is None:
# Aggregated # Aggregated
...@@ -237,20 +228,17 @@ def test_request_migration_sglang_aggregated( ...@@ -237,20 +228,17 @@ def test_request_migration_sglang_aggregated(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
with DynamoWorkerProcess( with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
...@@ -293,7 +281,9 @@ def test_request_migration_sglang_prefill( ...@@ -293,7 +281,9 @@ def test_request_migration_sglang_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -301,7 +291,6 @@ def test_request_migration_sglang_prefill( ...@@ -301,7 +291,6 @@ def test_request_migration_sglang_prefill(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode", disagg_mode="decode",
) as decode_worker: ) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
...@@ -311,7 +300,6 @@ def test_request_migration_sglang_prefill( ...@@ -311,7 +300,6 @@ def test_request_migration_sglang_prefill(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill", disagg_mode="prefill",
) as prefill1: ) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}") logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
...@@ -320,7 +308,6 @@ def test_request_migration_sglang_prefill( ...@@ -320,7 +308,6 @@ def test_request_migration_sglang_prefill(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill", disagg_mode="prefill",
) as prefill2: ) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}") logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
...@@ -364,7 +351,9 @@ def test_request_migration_sglang_kv_transfer( ...@@ -364,7 +351,9 @@ def test_request_migration_sglang_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -372,7 +361,6 @@ def test_request_migration_sglang_kv_transfer( ...@@ -372,7 +361,6 @@ def test_request_migration_sglang_kv_transfer(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill", disagg_mode="prefill",
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -382,7 +370,6 @@ def test_request_migration_sglang_kv_transfer( ...@@ -382,7 +370,6 @@ def test_request_migration_sglang_kv_transfer(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode", disagg_mode="decode",
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -391,7 +378,6 @@ def test_request_migration_sglang_kv_transfer( ...@@ -391,7 +378,6 @@ def test_request_migration_sglang_kv_transfer(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode", disagg_mode="decode",
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
...@@ -438,7 +424,9 @@ def test_request_migration_sglang_decode( ...@@ -438,7 +424,9 @@ def test_request_migration_sglang_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -446,7 +434,6 @@ def test_request_migration_sglang_decode( ...@@ -446,7 +434,6 @@ def test_request_migration_sglang_decode(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill", disagg_mode="prefill",
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -456,7 +443,6 @@ def test_request_migration_sglang_decode( ...@@ -456,7 +443,6 @@ def test_request_migration_sglang_decode(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode", disagg_mode="decode",
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -465,7 +451,6 @@ def test_request_migration_sglang_decode( ...@@ -465,7 +451,6 @@ def test_request_migration_sglang_decode(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode", disagg_mode="decode",
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
...@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1") worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
mode: "prefill_and_decode" for aggregated, "prefill" or "decode" for disaggregated mode: "prefill_and_decode" for aggregated, "prefill" or "decode" for disaggregated
""" """
...@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess):
request, request,
worker_id: str, worker_id: str,
frontend_port: int, frontend_port: int,
migration_limit: int = 3,
mode: str = "prefill_and_decode", mode: str = "prefill_and_decode",
): ):
self.worker_id = worker_id self.worker_id = worker_id
self.system_port = allocate_port(9100) self.system_port = allocate_port(9100)
self.mode = mode self.mode = mode
# Prefill workers require migration_limit=0 (no KV cache migration support)
if mode == "prefill":
logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [ command = [
"python3", "python3",
"-m", "-m",
...@@ -117,8 +110,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -117,8 +110,6 @@ class DynamoWorkerProcess(ManagedProcess):
"8192", "8192",
"--free-gpu-memory-fraction", "--free-gpu-memory-fraction",
"0.15", # avoid validation error on TRT-LLM available memory checks "0.15", # avoid validation error on TRT-LLM available memory checks
"--migration-limit",
str(migration_limit),
] ]
if mode != "prefill_and_decode": if mode != "prefill_and_decode":
config_file = ( config_file = (
...@@ -225,20 +216,17 @@ def test_request_migration_trtllm_aggregated( ...@@ -225,20 +216,17 @@ def test_request_migration_trtllm_aggregated(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
with DynamoWorkerProcess( with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
...@@ -280,7 +268,9 @@ def test_request_migration_trtllm_prefill( ...@@ -280,7 +268,9 @@ def test_request_migration_trtllm_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -288,7 +278,6 @@ def test_request_migration_trtllm_prefill( ...@@ -288,7 +278,6 @@ def test_request_migration_trtllm_prefill(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="decode", mode="decode",
) as decode_worker: ) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
...@@ -298,7 +287,6 @@ def test_request_migration_trtllm_prefill( ...@@ -298,7 +287,6 @@ def test_request_migration_trtllm_prefill(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill", mode="prefill",
) as prefill1: ) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}") logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
...@@ -307,7 +295,6 @@ def test_request_migration_trtllm_prefill( ...@@ -307,7 +295,6 @@ def test_request_migration_trtllm_prefill(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill", mode="prefill",
) as prefill2: ) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}") logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
...@@ -351,7 +338,9 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -351,7 +338,9 @@ def test_request_migration_trtllm_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -359,7 +348,6 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -359,7 +348,6 @@ def test_request_migration_trtllm_kv_transfer(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill", mode="prefill",
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -369,7 +357,6 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -369,7 +357,6 @@ def test_request_migration_trtllm_kv_transfer(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="decode", mode="decode",
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -378,7 +365,6 @@ def test_request_migration_trtllm_kv_transfer( ...@@ -378,7 +365,6 @@ def test_request_migration_trtllm_kv_transfer(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="decode", mode="decode",
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
...@@ -425,7 +411,9 @@ def test_request_migration_trtllm_decode( ...@@ -425,7 +411,9 @@ def test_request_migration_trtllm_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -433,7 +421,6 @@ def test_request_migration_trtllm_decode( ...@@ -433,7 +421,6 @@ def test_request_migration_trtllm_decode(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill", mode="prefill",
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -443,7 +430,6 @@ def test_request_migration_trtllm_decode( ...@@ -443,7 +430,6 @@ def test_request_migration_trtllm_decode(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="decode", mode="decode",
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -452,7 +438,6 @@ def test_request_migration_trtllm_decode( ...@@ -452,7 +438,6 @@ def test_request_migration_trtllm_decode(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
mode="decode", mode="decode",
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
...@@ -72,7 +72,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -72,7 +72,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1") worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
is_prefill: None for aggregated mode, True for prefill worker, False for decode worker is_prefill: None for aggregated mode, True for prefill worker, False for decode worker
""" """
...@@ -81,7 +80,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -81,7 +80,6 @@ class DynamoWorkerProcess(ManagedProcess):
request, request,
worker_id: str, worker_id: str,
frontend_port: int, frontend_port: int,
migration_limit: int = 3,
is_prefill: bool | None = None, is_prefill: bool | None = None,
): ):
self.worker_id = worker_id self.worker_id = worker_id
...@@ -102,8 +100,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -102,8 +100,6 @@ class DynamoWorkerProcess(ManagedProcess):
"512", # 8192 tokens x 1 context / 16 tokens per block = 512 blocks "512", # 8192 tokens x 1 context / 16 tokens per block = 512 blocks
"--gpu-memory-utilization", "--gpu-memory-utilization",
"0.15", # avoid assertion error on vLLM available memory checks "0.15", # avoid assertion error on vLLM available memory checks
"--migration-limit",
str(migration_limit),
] ]
if is_prefill is True: if is_prefill is True:
command.append("--is-prefill-worker") command.append("--is-prefill-worker")
...@@ -221,20 +217,17 @@ def test_request_migration_vllm_aggregated( ...@@ -221,20 +217,17 @@ def test_request_migration_vllm_aggregated(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend: with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start 2 workers # Step 2: Start 2 workers
with DynamoWorkerProcess( with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}") logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess( with DynamoWorkerProcess(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
) as worker2: ) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}") logger.info(f"Worker 2 PID: {worker2.get_pid()}")
...@@ -276,7 +269,9 @@ def test_request_migration_vllm_prefill( ...@@ -276,7 +269,9 @@ def test_request_migration_vllm_prefill(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect) # Step 2: Start decode worker first (required for prefill workers to connect)
...@@ -284,7 +279,6 @@ def test_request_migration_vllm_prefill( ...@@ -284,7 +279,6 @@ def test_request_migration_vllm_prefill(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False, is_prefill=False,
) as decode_worker: ) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}") logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
...@@ -294,7 +288,6 @@ def test_request_migration_vllm_prefill( ...@@ -294,7 +288,6 @@ def test_request_migration_vllm_prefill(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True, is_prefill=True,
) as prefill1: ) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}") logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
...@@ -303,7 +296,6 @@ def test_request_migration_vllm_prefill( ...@@ -303,7 +296,6 @@ def test_request_migration_vllm_prefill(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True, is_prefill=True,
) as prefill2: ) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}") logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
...@@ -356,7 +348,9 @@ def test_request_migration_vllm_kv_transfer( ...@@ -356,7 +348,9 @@ def test_request_migration_vllm_kv_transfer(
""" """
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -364,7 +358,6 @@ def test_request_migration_vllm_kv_transfer( ...@@ -364,7 +358,6 @@ def test_request_migration_vllm_kv_transfer(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True, is_prefill=True,
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -374,7 +367,6 @@ def test_request_migration_vllm_kv_transfer( ...@@ -374,7 +367,6 @@ def test_request_migration_vllm_kv_transfer(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False, is_prefill=False,
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -383,7 +375,6 @@ def test_request_migration_vllm_kv_transfer( ...@@ -383,7 +375,6 @@ def test_request_migration_vllm_kv_transfer(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False, is_prefill=False,
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
...@@ -440,7 +431,9 @@ def test_request_migration_vllm_decode( ...@@ -440,7 +431,9 @@ def test_request_migration_vllm_decode(
) )
# Step 1: Start the frontend # Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend: with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully") logger.info("Frontend started successfully")
# Step 2: Start prefill worker first # Step 2: Start prefill worker first
...@@ -448,7 +441,6 @@ def test_request_migration_vllm_decode( ...@@ -448,7 +441,6 @@ def test_request_migration_vllm_decode(
request, request,
"worker0", "worker0",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True, is_prefill=True,
) as prefill_worker: ) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}") logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
...@@ -458,7 +450,6 @@ def test_request_migration_vllm_decode( ...@@ -458,7 +450,6 @@ def test_request_migration_vllm_decode(
request, request,
"worker1", "worker1",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False, is_prefill=False,
) as decode1: ) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}") logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
...@@ -467,7 +458,6 @@ def test_request_migration_vllm_decode( ...@@ -467,7 +458,6 @@ def test_request_migration_vllm_decode(
request, request,
"worker2", "worker2",
frontend.frontend_port, frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False, is_prefill=False,
) as decode2: ) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}") logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
...@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) ...@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class DynamoFrontendProcess(BaseDynamoFrontendProcess): class DynamoFrontendProcess(BaseDynamoFrontendProcess):
"""Fault-tolerance frontend wrapper (keeps env settings from the historical helper).""" """Fault-tolerance frontend wrapper (keeps env settings from the historical helper)."""
def __init__(self, request, enforce_disagg: bool = False): def __init__(self, request, migration_limit: int, enforce_disagg: bool = False):
extra_env = { extra_env = {
"DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"), "DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"),
# These tests expect full control over requests sent to workers. The canary # These tests expect full control over requests sent to workers. The canary
...@@ -35,6 +35,7 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess): ...@@ -35,6 +35,7 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess):
request, request,
frontend_port=0, # allocate a free port (xdist-safe) frontend_port=0, # allocate a free port (xdist-safe)
router_mode="round-robin", router_mode="round-robin",
migration_limit=migration_limit,
extra_args=extra_args if extra_args else None, extra_args=extra_args if extra_args else None,
extra_env=extra_env, extra_env=extra_env,
terminate_all_matching_process_names=False, terminate_all_matching_process_names=False,
......
...@@ -32,8 +32,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -32,8 +32,6 @@ class DynamoWorkerProcess(ManagedProcess):
"--enforce-eager", "--enforce-eager",
"--max-model-len", "--max-model-len",
"8192", "8192",
"--migration-limit",
"3",
] ]
# Set debug logging environment # Set debug logging environment
......
...@@ -721,6 +721,7 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -721,6 +721,7 @@ class DynamoFrontendProcess(ManagedProcess):
*, *,
frontend_port: Optional[int] = None, frontend_port: Optional[int] = None,
router_mode: str = "round-robin", router_mode: str = "round-robin",
migration_limit: int = 0,
extra_args: Optional[list[str]] = None, extra_args: Optional[list[str]] = None,
extra_env: Optional[dict[str, str]] = None, extra_env: Optional[dict[str, str]] = None,
# Default to false so pytest-xdist workers don't kill each other's frontends. # Default to false so pytest-xdist workers don't kill each other's frontends.
...@@ -750,6 +751,8 @@ class DynamoFrontendProcess(ManagedProcess): ...@@ -750,6 +751,8 @@ class DynamoFrontendProcess(ManagedProcess):
# dynamo.frontend defaults to 8000 when neither env nor flag is provided. # dynamo.frontend defaults to 8000 when neither env nor flag is provided.
if frontend_port is not None: if frontend_port is not None:
command.extend(["--http-port", str(frontend_port)]) command.extend(["--http-port", str(frontend_port)])
# Migration limit is configured at the frontend level
command.extend(["--migration-limit", str(migration_limit)])
if extra_args: if extra_args:
command.extend(extra_args) command.extend(extra_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment