"lib/llm/src/vscode:/vscode.git/clone" did not exist on "3e015595ccb8c85e35dba6b836be8f721b71da6f"
Unverified Commit 1ffa489e authored by Jacky's avatar Jacky Committed by GitHub
Browse files

refactor: Move --migration-limit flag from backend to frontend (#5918)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 3842b244
......@@ -58,6 +58,7 @@ pub struct ModelWatcher {
manager: Arc<ModelManager>,
drt: DistributedRuntime,
router_config: RouterConfig,
migration_limit: u32,
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
engine_factory: Option<EngineFactoryCallback>,
......@@ -78,6 +79,7 @@ impl ModelWatcher {
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
engine_factory: Option<EngineFactoryCallback>,
metrics: Arc<Metrics>,
) -> ModelWatcher {
......@@ -85,6 +87,7 @@ impl ModelWatcher {
manager: model_manager,
drt: runtime,
router_config,
migration_limit,
notify_on_model: Notify::new(),
model_update_tx: None,
engine_factory,
......@@ -494,6 +497,7 @@ impl ModelWatcher {
tokenizer_hf.clone(),
prefill_chooser.clone(),
self.router_config.enforce_disagg,
self.migration_limit,
self.metrics.clone(),
)
.await
......@@ -529,6 +533,7 @@ impl ModelWatcher {
tokenizer_hf,
prefill_chooser,
self.router_config.enforce_disagg,
self.migration_limit,
self.metrics.clone(),
)
.await
......
......@@ -70,6 +70,7 @@ pub async fn prepare_engine(
distributed_runtime.clone(),
model_manager.clone(),
RouterConfig::default(),
local_model.migration_limit(),
None,
metrics,
));
......@@ -180,6 +181,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool,
migration_limit: u32,
metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
......@@ -208,6 +210,7 @@ where
hf_tokenizer,
prefill_chooser,
enforce_disagg,
migration_limit,
metrics,
)
.await
......@@ -225,6 +228,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool,
migration_limit: u32,
metrics: Arc<Metrics>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
......@@ -240,7 +244,7 @@ where
let frontend = SegmentSource::<SingleIn<Req>, ManyOut<Annotated<Resp>>>::new();
let preprocessor_op = preprocessor.into_operator();
let backend = Backend::from_tokenizer(hf_tokenizer).into_operator();
let migration = Migration::from_mdc(card, metrics).into_operator();
let migration = Migration::from_mdc(card, migration_limit, metrics).into_operator();
// For KV routing, use the client from the chooser to ensure shared state
let router_client = if router_mode == RouterMode::KV {
......
......@@ -35,6 +35,7 @@ pub async fn run(
EngineConfig::Dynamic { ref model, .. } => {
let grpc_service = grpc_service_builder.build()?;
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
......@@ -46,6 +47,7 @@ pub async fn run(
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
)
.await?;
......@@ -109,11 +111,19 @@ async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new());
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config, None, metrics);
let watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_config,
migration_limit,
None,
metrics,
);
tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery();
let discovery_stream = discovery
......
......@@ -61,6 +61,7 @@ pub async fn run(
let http_service = http_service_builder.build()?;
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
......@@ -74,6 +75,7 @@ pub async fn run(
distributed_runtime.clone(),
http_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
......@@ -146,10 +148,12 @@ pub async fn run(
/// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
......@@ -159,6 +163,7 @@ async fn run_watcher(
runtime.clone(),
model_manager,
router_config,
migration_limit,
engine_factory,
metrics.clone(),
);
......
......@@ -290,6 +290,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
migration_limit: self.migration_limit,
});
}
......@@ -341,6 +342,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
migration_limit: self.migration_limit,
})
}
}
......@@ -359,6 +361,7 @@ pub struct LocalModel {
router_config: RouterConfig,
runtime_config: ModelRuntimeConfig,
namespace: Option<String>,
migration_limit: u32,
}
impl LocalModel {
......@@ -422,6 +425,10 @@ impl LocalModel {
&self.runtime_config
}
pub fn migration_limit(&self) -> u32 {
self.migration_limit
}
pub fn namespace(&self) -> Option<&str> {
self.namespace.as_deref()
}
......
......@@ -28,18 +28,22 @@ pub struct Migration {
}
impl Migration {
pub fn from_mdc(mdc: &ModelDeploymentCard, metrics: Arc<Metrics>) -> Arc<Self> {
tracing::debug!(
"model {} migration limit {}",
mdc.display_name,
mdc.migration_limit
);
pub fn new(migration_limit: u32, model_name: String, metrics: Arc<Metrics>) -> Arc<Self> {
tracing::debug!("model {} migration limit {}", model_name, migration_limit);
Arc::new(Self {
migration_limit: mdc.migration_limit,
model_name: Arc::new(mdc.display_name.clone()),
migration_limit,
model_name: Arc::new(model_name),
metrics,
})
}
pub fn from_mdc(
mdc: &ModelDeploymentCard,
migration_limit: u32,
metrics: Arc<Metrics>,
) -> Arc<Self> {
Self::new(migration_limit, mdc.display_name.clone(), metrics)
}
}
#[async_trait]
......
......@@ -339,6 +339,7 @@ mod integration_tests {
distributed_runtime.clone(),
service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit
None,
service.state().metrics_clone(),
);
......@@ -512,6 +513,7 @@ mod integration_tests {
distributed_runtime.clone(),
service.state().manager_clone(),
dynamo_llm::entrypoint::RouterConfig::default(),
0, // migration_limit
None,
service.state().metrics_clone(),
);
......
......@@ -62,8 +62,6 @@ class DynamoWorkerProcess(ManagedProcess):
system_port = allocate_port(9100)
self.system_port = system_port
self.frontend_port = frontend_port
# Prefill workers require migration_limit=0 (no KV cache migration support)
migration_limit = "0" if mode == "prefill" else "3"
command = [
"python3",
......@@ -77,8 +75,6 @@ class DynamoWorkerProcess(ManagedProcess):
"16384",
"--max-num-tokens",
"16384",
"--migration-limit",
migration_limit,
]
if mode != "prefill_and_decode":
with open("test_request_cancellation_trtllm_config.yaml", "w") as f:
......
......@@ -63,8 +63,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.45",
"--max-model-len",
"16384",
"--migration-limit",
"3",
]
# Configure health check based on worker type
......
......@@ -40,9 +40,6 @@ class DynamoWorkerProcess(ManagedProcess):
etcd_endpoints: List of ETCD endpoints for HA
mode: One of "prefill_and_decode", "prefill", "decode"
"""
# Prefill workers require migration_limit=0 (no KV cache migration support)
migration_limit = "0" if mode == "prefill" else "3"
command = [
"python3",
"-m",
......@@ -55,8 +52,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.45",
"--max-seq-len",
"8192",
"--migration-limit",
migration_limit,
]
# Add disaggregation-specific configuration
......
......@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "worker2")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
disagg_mode: None for aggregated, "prefill" or "decode" for disaggregated
"""
......@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess):
request,
worker_id: str,
frontend_port: int,
migration_limit: int = 3,
disagg_mode: str | None = None,
):
self.worker_id = worker_id
self.system_port = allocate_port(9100)
self.disagg_mode = disagg_mode
# Prefill workers require migration_limit=0 (no KV cache migration support)
if disagg_mode == "prefill":
logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [
"python3",
"-m",
......@@ -120,8 +113,6 @@ class DynamoWorkerProcess(ManagedProcess):
"0.3",
"--context-length",
"8192",
"--migration-limit",
str(migration_limit),
]
if disagg_mode is None:
# Aggregated
......@@ -237,20 +228,17 @@ def test_request_migration_sglang_aggregated(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start 2 workers
with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}")
......@@ -293,7 +281,9 @@ def test_request_migration_sglang_prefill(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect)
......@@ -301,7 +291,6 @@ def test_request_migration_sglang_prefill(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode",
) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
......@@ -311,7 +300,6 @@ def test_request_migration_sglang_prefill(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill",
) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
......@@ -320,7 +308,6 @@ def test_request_migration_sglang_prefill(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill",
) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
......@@ -364,7 +351,9 @@ def test_request_migration_sglang_kv_transfer(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -372,7 +361,6 @@ def test_request_migration_sglang_kv_transfer(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill",
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -382,7 +370,6 @@ def test_request_migration_sglang_kv_transfer(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode",
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -391,7 +378,6 @@ def test_request_migration_sglang_kv_transfer(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode",
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......@@ -438,7 +424,9 @@ def test_request_migration_sglang_decode(
)
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -446,7 +434,6 @@ def test_request_migration_sglang_decode(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="prefill",
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -456,7 +443,6 @@ def test_request_migration_sglang_decode(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode",
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -465,7 +451,6 @@ def test_request_migration_sglang_decode(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
disagg_mode="decode",
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
......@@ -82,7 +82,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
mode: "prefill_and_decode" for aggregated, "prefill" or "decode" for disaggregated
"""
......@@ -91,18 +90,12 @@ class DynamoWorkerProcess(ManagedProcess):
request,
worker_id: str,
frontend_port: int,
migration_limit: int = 3,
mode: str = "prefill_and_decode",
):
self.worker_id = worker_id
self.system_port = allocate_port(9100)
self.mode = mode
# Prefill workers require migration_limit=0 (no KV cache migration support)
if mode == "prefill":
logging.info("Prefill worker - setting migration_limit to 0")
migration_limit = 0
command = [
"python3",
"-m",
......@@ -117,8 +110,6 @@ class DynamoWorkerProcess(ManagedProcess):
"8192",
"--free-gpu-memory-fraction",
"0.15", # avoid validation error on TRT-LLM available memory checks
"--migration-limit",
str(migration_limit),
]
if mode != "prefill_and_decode":
config_file = (
......@@ -225,20 +216,17 @@ def test_request_migration_trtllm_aggregated(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start 2 workers
with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}")
......@@ -280,7 +268,9 @@ def test_request_migration_trtllm_prefill(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect)
......@@ -288,7 +278,6 @@ def test_request_migration_trtllm_prefill(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
......@@ -298,7 +287,6 @@ def test_request_migration_trtllm_prefill(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill",
) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
......@@ -307,7 +295,6 @@ def test_request_migration_trtllm_prefill(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill",
) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
......@@ -351,7 +338,9 @@ def test_request_migration_trtllm_kv_transfer(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -359,7 +348,6 @@ def test_request_migration_trtllm_kv_transfer(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill",
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -369,7 +357,6 @@ def test_request_migration_trtllm_kv_transfer(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -378,7 +365,6 @@ def test_request_migration_trtllm_kv_transfer(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......@@ -425,7 +411,9 @@ def test_request_migration_trtllm_decode(
)
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -433,7 +421,6 @@ def test_request_migration_trtllm_decode(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
mode="prefill",
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -443,7 +430,6 @@ def test_request_migration_trtllm_decode(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -452,7 +438,6 @@ def test_request_migration_trtllm_decode(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
mode="decode",
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
......@@ -72,7 +72,6 @@ class DynamoWorkerProcess(ManagedProcess):
request: pytest request fixture
worker_id: Unique identifier for the worker (e.g., "worker1", "prefill1")
frontend_port: Port where the frontend is running
migration_limit: Maximum number of migration attempts (default: 3)
is_prefill: None for aggregated mode, True for prefill worker, False for decode worker
"""
......@@ -81,7 +80,6 @@ class DynamoWorkerProcess(ManagedProcess):
request,
worker_id: str,
frontend_port: int,
migration_limit: int = 3,
is_prefill: bool | None = None,
):
self.worker_id = worker_id
......@@ -102,8 +100,6 @@ class DynamoWorkerProcess(ManagedProcess):
"512", # 8192 tokens x 1 context / 16 tokens per block = 512 blocks
"--gpu-memory-utilization",
"0.15", # avoid assertion error on vLLM available memory checks
"--migration-limit",
str(migration_limit),
]
if is_prefill is True:
command.append("--is-prefill-worker")
......@@ -221,20 +217,17 @@ def test_request_migration_vllm_aggregated(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
with DynamoFrontendProcess(request, migration_limit=migration_limit) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start 2 workers
with DynamoWorkerProcess(
request, "worker1", frontend.frontend_port, migration_limit=migration_limit
) as worker1:
with DynamoWorkerProcess(request, "worker1", frontend.frontend_port) as worker1:
logger.info(f"Worker 1 PID: {worker1.get_pid()}")
with DynamoWorkerProcess(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
) as worker2:
logger.info(f"Worker 2 PID: {worker2.get_pid()}")
......@@ -276,7 +269,9 @@ def test_request_migration_vllm_prefill(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start decode worker first (required for prefill workers to connect)
......@@ -284,7 +279,6 @@ def test_request_migration_vllm_prefill(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
......@@ -294,7 +288,6 @@ def test_request_migration_vllm_prefill(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill1:
logger.info(f"Prefill Worker 1 PID: {prefill1.get_pid()}")
......@@ -303,7 +296,6 @@ def test_request_migration_vllm_prefill(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill2:
logger.info(f"Prefill Worker 2 PID: {prefill2.get_pid()}")
......@@ -356,7 +348,9 @@ def test_request_migration_vllm_kv_transfer(
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -364,7 +358,6 @@ def test_request_migration_vllm_kv_transfer(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -374,7 +367,6 @@ def test_request_migration_vllm_kv_transfer(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -383,7 +375,6 @@ def test_request_migration_vllm_kv_transfer(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......@@ -440,7 +431,9 @@ def test_request_migration_vllm_decode(
)
# Step 1: Start the frontend
with DynamoFrontendProcess(request, enforce_disagg=True) as frontend:
with DynamoFrontendProcess(
request, migration_limit=migration_limit, enforce_disagg=True
) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start prefill worker first
......@@ -448,7 +441,6 @@ def test_request_migration_vllm_decode(
request,
"worker0",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=True,
) as prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
......@@ -458,7 +450,6 @@ def test_request_migration_vllm_decode(
request,
"worker1",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode1:
logger.info(f"Decode Worker 1 PID: {decode1.get_pid()}")
......@@ -467,7 +458,6 @@ def test_request_migration_vllm_decode(
request,
"worker2",
frontend.frontend_port,
migration_limit=migration_limit,
is_prefill=False,
) as decode2:
logger.info(f"Decode Worker 2 PID: {decode2.get_pid()}")
......
......@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class DynamoFrontendProcess(BaseDynamoFrontendProcess):
"""Fault-tolerance frontend wrapper (keeps env settings from the historical helper)."""
def __init__(self, request, enforce_disagg: bool = False):
def __init__(self, request, migration_limit: int, enforce_disagg: bool = False):
extra_env = {
"DYN_REQUEST_PLANE": request.getfixturevalue("request_plane"),
# These tests expect full control over requests sent to workers. The canary
......@@ -35,6 +35,7 @@ class DynamoFrontendProcess(BaseDynamoFrontendProcess):
request,
frontend_port=0, # allocate a free port (xdist-safe)
router_mode="round-robin",
migration_limit=migration_limit,
extra_args=extra_args if extra_args else None,
extra_env=extra_env,
terminate_all_matching_process_names=False,
......
......@@ -32,8 +32,6 @@ class DynamoWorkerProcess(ManagedProcess):
"--enforce-eager",
"--max-model-len",
"8192",
"--migration-limit",
"3",
]
# Set debug logging environment
......
......@@ -721,6 +721,7 @@ class DynamoFrontendProcess(ManagedProcess):
*,
frontend_port: Optional[int] = None,
router_mode: str = "round-robin",
migration_limit: int = 0,
extra_args: Optional[list[str]] = None,
extra_env: Optional[dict[str, str]] = None,
# Default to false so pytest-xdist workers don't kill each other's frontends.
......@@ -750,6 +751,8 @@ class DynamoFrontendProcess(ManagedProcess):
# dynamo.frontend defaults to 8000 when neither env nor flag is provided.
if frontend_port is not None:
command.extend(["--http-port", str(frontend_port)])
# Migration limit is configured at the frontend level
command.extend(["--migration-limit", str(migration_limit)])
if extra_args:
command.extend(extra_args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment