Unverified Commit 6e2b22ea authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: ETCD high availability client failover - lease watch resilience (#3950)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
Co-authored-by: default avatarRyan McCormick <rmccormick@nvidia.com>
parent 25fc7325
......@@ -479,7 +479,7 @@ impl InnerClient {
/// Wait for a new scaling decision. Use `get` when this returns to fetch the values.
async fn wait(&self) -> anyhow::Result<()> {
let watcher = self.etcd_client.kv_watch_prefix(&self.key).await?;
let (_prefix, _watcher, mut receiver) = watcher.dissolve();
let (_prefix, mut receiver) = watcher.dissolve();
tokio::select! {
_ = receiver.recv() => {
Ok(())
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use tokio::sync::watch;
use tracing;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::transports::etcd::WatchEvent;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DisaggRouterConf {
pub max_local_prefill_length: i32,
}
impl Default for DisaggRouterConf {
fn default() -> Self {
Self {
max_local_prefill_length: 1000,
}
}
}
impl DisaggRouterConf {
pub async fn from_etcd_with_watcher(
drt: Arc<DistributedRuntime>,
model_name: &str,
) -> anyhow::Result<(Self, watch::Receiver<Self>)> {
let etcd_key = format!("public/components/disagg_router/models/chat/{}", model_name);
// Get the initial value if it exists
let Some(etcd_client) = drt.etcd_client() else {
anyhow::bail!("Static components don't have an etcd client");
};
let initial_config = match etcd_client.kv_get_prefix(&etcd_key).await {
Ok(kvs) => {
if let Some(kv) = kvs.first() {
match serde_json::from_slice::<DisaggRouterConf>(kv.value()) {
Ok(config) => {
tracing::debug!(
"Found initial config for key {}: {:?}",
etcd_key,
config
);
config
}
Err(e) => {
tracing::warn!(
"Failed to parse initial config for key {}: {}",
etcd_key,
e
);
DisaggRouterConf::default()
}
}
} else {
tracing::debug!(
"No initial config found for key {}, using default",
etcd_key
);
DisaggRouterConf::default()
}
}
Err(e) => {
tracing::warn!("Error fetching initial config for key {}: {}", etcd_key, e);
DisaggRouterConf::default()
}
};
// Create watch channel for config updates
let (watch_tx, watch_rx) = watch::channel(initial_config.clone());
// Set up the watcher after getting the initial value
let prefix_watcher = etcd_client.kv_get_and_watch_prefix(&etcd_key).await?;
let (key, mut kv_event_rx) = prefix_watcher.dissolve();
// Spawn background task to watch for config changes
drt.runtime().secondary().spawn(async move {
tracing::info!("Starting config watcher for disagg router key: {}", key);
loop {
let kv_event = tokio::select! {
_ = watch_tx.closed() => {
tracing::debug!("All watchers have closed; shutting down config watcher for key: {}", key);
break;
}
kv_event = kv_event_rx.recv() => {
match kv_event {
Some(kv_event) => kv_event,
None => {
tracing::debug!("Watch stream has closed; shutting down config watcher for key: {}", key);
break;
}
}
}
};
tracing::debug!("Received watch event for key {}", key);
match kv_event {
WatchEvent::Put(kv) => {
let val = serde_json::from_slice::<DisaggRouterConf>(kv.value());
if let Ok(config) = val {
tracing::info!("Config updated for key {}: {:?}", key, config);
// Broadcast the update
if watch_tx.send(config).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
} else {
tracing::error!("Unable to parse router config for key {}", key);
break;
}
}
WatchEvent::Delete(_) => {
tracing::warn!("Config key was deleted: {}", key);
// Reset to default values
if watch_tx.send(DisaggRouterConf::default()).is_err() {
tracing::debug!("Unable to send watch updates; shutting down config watcher for key: {}", key);
break;
}
}
}
}
tracing::debug!("Completed config watcher for key: {}", key);
});
Ok((initial_config, watch_rx))
}
}
#[derive(Clone)]
pub struct DisaggregatedRouter {
max_local_prefill_length: Arc<Mutex<i32>>,
model_name: String,
config_watcher: Option<watch::Receiver<DisaggRouterConf>>,
}
impl DisaggregatedRouter {
pub fn new(max_local_prefill_length: i32, model_name: String) -> Self {
DisaggregatedRouter {
max_local_prefill_length: Arc::new(Mutex::new(max_local_prefill_length)),
model_name,
config_watcher: None,
}
}
pub async fn new_with_etcd_and_default(
drt: Arc<DistributedRuntime>,
model_name: String,
default_max_local_prefill_length: i32,
) -> anyhow::Result<Self> {
let (mut config, watcher) =
DisaggRouterConf::from_etcd_with_watcher(drt, &model_name).await?;
// Use the provided default if no etcd value was found (when config is the default value)
if config.max_local_prefill_length == DisaggRouterConf::default().max_local_prefill_length {
config.max_local_prefill_length = default_max_local_prefill_length;
}
let router = Self {
max_local_prefill_length: Arc::new(Mutex::new(config.max_local_prefill_length)),
model_name: model_name.clone(),
config_watcher: Some(watcher),
};
// Start background task to watch for config updates
router.start_config_watcher();
Ok(router)
}
fn start_config_watcher(&self) {
if let Some(watcher) = self.config_watcher.clone() {
let mut watcher = watcher;
// Create a clone for the task
let model_name = self.model_name.clone();
let max_local_prefill_length = self.max_local_prefill_length.clone();
tokio::spawn(async move {
tracing::info!("Starting config update watcher for model: {}", model_name);
while watcher.changed().await.is_ok() {
let config = watcher.borrow().clone();
let new_value = config.max_local_prefill_length;
// Update the value using the mutex
let mut current_value = max_local_prefill_length.lock().unwrap();
let old_value = *current_value;
if old_value != new_value {
*current_value = new_value;
tracing::info!(
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
model_name,
old_value,
new_value
);
}
}
tracing::debug!("Config watcher closed for model: {}", model_name);
});
}
}
pub fn check_for_updates(&self) {
if let Some(watcher) = &self.config_watcher
&& watcher.has_changed().unwrap_or(false)
{
let config = watcher.borrow().clone();
let new_value = config.max_local_prefill_length;
// Update the value using the mutex
let mut current_value = self.max_local_prefill_length.lock().unwrap();
let old_value = *current_value;
if old_value != new_value {
*current_value = new_value;
tracing::info!(
"Applied config update for model {}: max_local_prefill_length changed from {} to {}",
self.model_name,
old_value,
new_value
);
}
}
}
pub fn prefill_remote(&self, prefill_length: i32, prefix_hit_length: i32) -> bool {
// Check for updates before making the decision
self.check_for_updates();
// Get the current value from the mutex
let max_local_prefill_length = *self.max_local_prefill_length.lock().unwrap();
// schedule the request purely based on the prefill length
// TODO: apply math models and compare local vs remote prefill TTFT
prefill_length - prefix_hit_length > max_local_prefill_length
}
pub fn update_value(&self, max_local_prefill_length: i32) {
let mut current = self.max_local_prefill_length.lock().unwrap();
*current = max_local_prefill_length;
}
pub fn get_model_name(&self) -> &str {
&self.model_name
}
}
......@@ -274,14 +274,14 @@ pub async fn start_kv_router_background(
cleanup_orphaned_consumers(&mut nats_queue, &etcd_client, &component, &consumer_uuid).await;
// Watch for router deletions to clean up orphaned consumers
let (_prefix_str, _watcher, mut router_replicas_rx) = etcd_client
let (_prefix_str, mut router_replicas_rx) = etcd_client
.kv_get_and_watch_prefix(&format!("{}/", KV_ROUTERS_ROOT_PATH))
.await?
.dissolve();
// Get the generate endpoint and watch for instance deletions
let generate_endpoint = component.endpoint("generate");
let (_instance_prefix, _instance_watcher, mut instance_event_rx) = etcd_client
let (_instance_prefix, mut instance_event_rx) = etcd_client
.kv_get_and_watch_prefix(generate_endpoint.etcd_root())
.await?
.dissolve();
......
......@@ -213,7 +213,7 @@ impl Client {
.kv_get_and_watch_prefix(endpoint.etcd_root())
.await?;
let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
let (prefix, mut kv_event_rx) = prefix_watcher.dissolve();
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
......
......@@ -15,7 +15,7 @@ use validator::Validate;
use etcd_client::{
Certificate, Compare, CompareOp, DeleteOptions, GetOptions, Identity, LockClient, LockOptions,
LockResponse, PutOptions, PutResponse, TlsOptions, Txn, TxnOp, TxnOpResponse, WatchOptions,
Watcher,
WatchStream, Watcher,
};
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
use tokio::time::{Duration, interval};
......@@ -38,6 +38,9 @@ pub struct Client {
connector: Arc<Connector>,
primary_lease: u64,
runtime: Runtime,
// Exclusive runtime for etcd lease keep-alive and watch tasks
// Avoid those tasks from being starved when the main runtime is busy
// WARNING: Do not await on main runtime from this runtime or deadlocks may occur
rt: Arc<tokio::runtime::Runtime>,
}
......@@ -306,110 +309,216 @@ impl Client {
self.watch_internal(prefix, true).await
}
/// Core watch implementation that sets up a resilient watcher for a key prefix.
///
/// Creates a background task that maintains a watch stream with automatic reconnection
/// on recoverable errors. If `include_existing` is true, existing keys are included
/// in the initial watch events.
async fn watch_internal(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
include_existing: bool,
) -> Result<PrefixWatcher> {
let client = self.connector.get_client();
let mut kv_client = client.kv_client();
let mut watch_client = client.watch_client();
let (tx, rx) = mpsc::channel(32);
// Get start revision and send existing KVs
let mut start_revision = self
.get_start_revision(
prefix.as_ref(),
if include_existing { Some(&tx) } else { None },
)
.await?;
// Resilience watch stream in background
let connector = self.connector.clone();
let prefix_str = prefix.as_ref().to_string();
self.rt.spawn(async move {
let mut reconnect = true;
while reconnect {
// Start a new watch stream
let watch_stream =
match Self::new_watch_stream(&connector, &prefix_str, start_revision).await {
Ok(stream) => stream,
Err(_) => return,
};
// Watch the stream
reconnect =
Self::monitor_watch_stream(watch_stream, &prefix_str, &mut start_revision, &tx)
.await;
}
});
Ok(PrefixWatcher {
prefix: prefix.as_ref().to_string(),
rx,
})
}
/// Fetch the initial revision for watching and optionally send existing key-values.
///
/// Returns the next revision to watch from. If `existing_kvs_tx` is provided,
/// all existing keys with the prefix are sent through the channel first.
async fn get_start_revision(
&self,
prefix: impl AsRef<str> + std::fmt::Display,
existing_kvs_tx: Option<&mpsc::Sender<WatchEvent>>,
) -> Result<i64> {
let mut kv_client = self.connector.get_client().kv_client();
let mut get_response = kv_client
.get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
.await?;
let start_revision = get_response
// Get the start revision
let mut start_revision = get_response
.header()
.ok_or(error!("missing header; unable to get revision"))?
.revision();
tracing::trace!("{prefix}: start_revision: {start_revision}");
let start_revision = start_revision + 1;
let (watcher, mut watch_stream) = watch_client
.watch(
prefix.as_ref(),
Some(
WatchOptions::new()
.with_prefix()
.with_start_revision(start_revision)
.with_prev_key(),
),
)
.await?;
start_revision += 1;
let kvs = if include_existing {
// Send existing KVs from response if requested
if let Some(tx) = existing_kvs_tx {
let kvs = get_response.take_kvs();
tracing::trace!("initial kv count: {:?}", kvs.len());
kvs
} else {
vec![]
};
for kv in kvs.into_iter() {
tx.send(WatchEvent::Put(kv)).await?;
}
}
let (tx, rx) = mpsc::channel(32);
Ok(start_revision)
}
self.rt.spawn(async move {
if include_existing {
for kv in kvs {
if tx.send(WatchEvent::Put(kv)).await.is_err() {
// receiver is already closed
return;
/// Establish a new watch stream with automatic retry and reconnection.
///
/// Attempts to create a watch stream, reconnecting to ETCD if necessary.
/// Uses a 10-second timeout for reconnection attempts before giving up.
async fn new_watch_stream(
connector: &Arc<Connector>,
prefix: &String,
start_revision: i64,
) -> Result<WatchStream> {
loop {
match connector
.get_client()
.watch_client()
.watch(
prefix.as_str(),
Some(
WatchOptions::new()
.with_prefix()
.with_start_revision(start_revision)
.with_prev_key(),
),
)
.await
{
Ok((_, watch_stream)) => {
tracing::debug!("Watch stream established for prefix '{}'", prefix);
return Ok(watch_stream);
}
Err(err) => {
tracing::debug!(error = %err, "Failed to establish watch stream for prefix '{}'", prefix);
let deadline = std::time::Instant::now() + Duration::from_secs(10);
if let Err(err) = connector.reconnect(deadline).await {
tracing::error!(
"Failed to reconnect to ETCD within 10 secs for watching prefix '{}': {}",
prefix,
err
);
return Err(err);
}
// continue - retry establishing the watch stream
}
}
}
}
loop {
tokio::select! {
maybe_resp = watch_stream.next() => {
// Early return for None or Err cases
let Some(Ok(response)) = maybe_resp else {
tracing::info!("kv watch stream closed");
return;
};
// Process events
for event in response.events() {
// Extract the KeyValue if it exists
let Some(kv) = event.kv() else {
continue; // Skip events with no KV
};
// Handle based on event type
match event.event_type() {
etcd_client::EventType::Put => {
if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!("kv watcher error forwarding WatchEvent::Put: {err}");
return;
}
}
etcd_client::EventType::Delete => {
if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
return;
}
}
}
/// Monitor a watch stream and forward events to receivers.
///
/// Returns `true` for recoverable errors (network issues, stream closure) that warrant
/// reconnection attempts. Returns `false` for permanent failures (protocol violations,
/// channel errors, no receivers) where watching should stop.
async fn monitor_watch_stream(
mut watch_stream: WatchStream,
prefix: &String,
start_revision: &mut i64,
tx: &mpsc::Sender<WatchEvent>,
) -> bool {
loop {
tokio::select! {
maybe_resp = watch_stream.next() => {
// Handle the watch response
let response = match maybe_resp {
Some(Ok(res)) => res,
Some(Err(err)) => {
tracing::warn!(error = %err, "Error watching stream for prefix '{}'", prefix);
return true; // Exit to reconnect
}
None => {
tracing::warn!("Watch stream unexpectedly closed for prefix '{}'", prefix);
return true; // Exit to reconnect
}
};
// Update revision for reconnect
*start_revision = match response.header() {
Some(header) => header.revision() + 1,
None => {
tracing::error!("Missing header in watch response for prefix '{}'", prefix);
return false;
}
};
// Process events
if Self::process_watch_events(response.events(), tx).await.is_err() {
return false;
};
}
_ = tx.closed() => {
tracing::debug!("no more receivers, stopping watcher");
return false;
}
}
}
}
/// Process etcd events and forward them as Put/Delete watch events.
///
/// Filters out events without key-values and transforms etcd events into
/// appropriate WatchEvent types for channel transmission.
async fn process_watch_events(
events: &[etcd_client::Event],
tx: &mpsc::Sender<WatchEvent>,
) -> Result<()> {
for event in events {
// Extract the KeyValue if it exists
let Some(kv) = event.kv() else {
continue; // Skip events with no KV
};
// Handle based on event type
match event.event_type() {
etcd_client::EventType::Put => {
if let Err(err) = tx.send(WatchEvent::Put(kv.clone())).await {
tracing::error!("kv watcher error forwarding WatchEvent::Put: {err}");
return Err(err.into());
}
_ = tx.closed() => {
tracing::debug!("no more receivers, stopping watcher");
return;
}
etcd_client::EventType::Delete => {
if tx.send(WatchEvent::Delete(kv.clone())).await.is_err() {
return Err(anyhow::anyhow!("failed to send WatchEvent::Delete"));
}
}
}
});
Ok(PrefixWatcher {
prefix: prefix.as_ref().to_string(),
watcher,
rx,
})
}
Ok(())
}
}
#[derive(Dissolve)]
pub struct PrefixWatcher {
prefix: String,
watcher: Watcher,
rx: mpsc::Receiver<WatchEvent>,
}
......
......@@ -27,7 +27,7 @@ async fn wait_for_key_count<T: DeserializeOwned>(
expected_count: usize,
timeout: Option<Duration>,
) -> Result<HashMap<String, T>, LeaderWorkerBarrierError> {
let (_key, _watcher, mut rx) = client
let (_key, mut rx) = client
.kv_get_and_watch_prefix(&key)
.await
.map_err(LeaderWorkerBarrierError::EtcdError)?
......
......@@ -92,7 +92,7 @@ where
let prefix = prefix.into();
let prefix_watcher = client.kv_get_and_watch_prefix(&prefix).await?;
let (prefix_str, _watcher, mut events_rx) = prefix_watcher.dissolve();
let (prefix_str, mut events_rx) = prefix_watcher.dissolve();
tokio::spawn(async move {
let mut state: HashMap<K, V> = HashMap::new();
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
import pytest
from tests.conftest import NatsServer
from tests.fault_tolerance.etcd_ha.utils import (
DynamoFrontendProcess,
EtcdCluster,
send_inference_request,
wait_for_processes_to_terminate,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with SGLang backend and ETCD HA support"""
def __init__(self, request, etcd_endpoints: list, mode: str = "agg"):
"""
Initialize SGLang worker process with ETCD HA support.
Args:
request: pytest request object
etcd_endpoints: List of ETCD endpoints
mode: One of "agg", "prefill", "decode"
"""
command = [
"python3",
"-m",
"dynamo.sglang",
"--model-path",
FAULT_TOLERANCE_MODEL_NAME,
"--served-model-name",
FAULT_TOLERANCE_MODEL_NAME,
"--page-size",
"16",
"--tp",
"1",
"--trust-remote-code",
]
# Add mode-specific arguments
if mode == "agg":
# Aggregated mode - add skip-tokenizer-init
command.append("--skip-tokenizer-init")
else:
# Disaggregated mode - add disaggregation arguments
command.extend(
[
"--disaggregation-mode",
mode,
"--disaggregation-bootstrap-port",
"12345",
"--host",
"0.0.0.0",
"--disaggregation-transfer-backend",
"nixl",
]
)
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set port based on worker type
if mode == "prefill":
port = "8082"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
elif mode == "decode":
port = "8081"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
else: # agg (aggregated mode)
port = "8081"
# Set debug logging and ETCD endpoints
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = port
# Set GPU assignment for disaggregated mode
if mode == "decode":
env["CUDA_VISIBLE_DEVICES"] = "1" # Use GPU 1 for decode worker
elif mode == "prefill":
env["CUDA_VISIBLE_DEVICES"] = "0" # Use GPU 0 for prefill worker
# For agg (aggregated) mode, use default GPU assignment
# Set log directory based on worker type
log_dir = f"{request.node.name}_{mode}_worker"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
pass
super().__init__(
command=command,
env=env,
health_check_urls=health_check_urls,
timeout=300,
display_output=True,
terminate_existing=False,
# Ensure any orphaned SGLang engine cores or child helpers are cleaned up
stragglers=[
"SGLANG:EngineCore",
],
straggler_commands=[
"-m dynamo.sglang",
],
log_dir=log_dir,
)
self.mode = mode
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
logger.info(f"{self.mode.capitalize()} worker status is ready")
return True
logger.warning(
f"{self.mode.capitalize()} worker status is not ready: {data.get('status')}"
)
except ValueError:
logger.warning(
f"{self.mode.capitalize()} worker health response is not valid JSON"
)
return False
@pytest.mark.sglang
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover using SGLang.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an SGLang worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start an SGLang worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="agg"):
logger.info("SGLang worker started successfully")
# Small wait to ensure worker is fully ready
time.sleep(2)
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.sglang
@pytest.mark.gpu_2
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_ha_failover_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover in disaggregated mode using SGLang.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode SGLang workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs.
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start the decode worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="decode"):
logger.info("Decode worker started successfully")
# Step 5: Start the prefill worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="prefill"):
logger.info("Prefill worker started successfully")
# Small wait to ensure workers are fully ready
time.sleep(2)
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.sglang
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated using SGLang.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and an SGLang worker
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and worker shut down gracefully
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start an SGLang worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="agg") as worker:
logger.info("SGLang worker started successfully")
# Small wait to ensure worker is fully ready
time.sleep(2)
# Step 5: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info("System is working correctly with single ETCD node")
# Step 6: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 7: Wait and verify frontend and worker detect the loss
wait_for_processes_to_terminate(
{"Worker": worker, "Frontend": frontend}
)
@pytest.mark.sglang
@pytest.mark.gpu_2
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test that frontend and workers shut down when single ETCD node is terminated in disaggregated mode using SGLang.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and both prefill and decode SGLang workers
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and both workers shut down gracefully
Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs.
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start the decode worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="decode"
) as decode_worker:
logger.info("Decode worker started successfully")
# Step 5: Start the prefill worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="prefill"
) as prefill_worker:
logger.info("Prefill worker started successfully")
# Small wait to ensure workers are fully ready
time.sleep(2)
# Step 6: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(
"System is working correctly with single ETCD node in disaggregated mode"
)
# Step 7: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 8: Wait and verify frontend and both workers detect the loss
wait_for_processes_to_terminate(
{
"Decode Worker": decode_worker,
"Prefill Worker": prefill_worker,
"Frontend": frontend,
}
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
import pytest
from tests.conftest import NatsServer
from tests.fault_tolerance.etcd_ha.utils import (
DynamoFrontendProcess,
EtcdCluster,
send_inference_request,
wait_for_processes_to_terminate,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with TRT-LLM backend and ETCD HA support"""
def __init__(
self,
request,
etcd_endpoints: list,
mode: str = "prefill_and_decode",
):
"""
Initialize TRT-LLM worker process with ETCD HA support.
Args:
request: pytest request object
etcd_endpoints: List of ETCD endpoints for HA
mode: One of "prefill_and_decode", "prefill", "decode"
"""
# Prefill workers require migration_limit=0 (no KV cache migration support)
migration_limit = "0" if mode == "prefill" else "3"
command = [
"python3",
"-m",
"dynamo.trtllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--disaggregation-mode",
mode,
"--free-gpu-memory-fraction",
"0.45",
"--max-seq-len",
"8192",
"--migration-limit",
migration_limit,
]
# Add disaggregation-specific configuration
if mode != "prefill_and_decode":
with open("test_etcd_ha_trtllm_config.yaml", "w") as f:
f.write("cache_transceiver_config:\n backend: DEFAULT\n")
f.write("disable_overlap_scheduler: true\n")
command += [
"--extra-engine-args",
"test_etcd_ha_trtllm_config.yaml",
]
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set port based on worker type
if mode == "prefill":
port = "8082"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
elif mode == "decode":
port = "8081"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
else: # prefill_and_decode
port = "8081"
# Set debug logging and ETCD endpoints
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = port
# Set log directory based on worker type
log_dir = f"{request.node.name}_{mode}_worker"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
pass
super().__init__(
command=command,
env=env,
health_check_urls=health_check_urls,
timeout=300,
display_output=True,
terminate_existing=False,
log_dir=log_dir,
)
self.mode = mode
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
logger.info(f"{self.mode.capitalize()} worker status is ready")
return True
logger.warning(
f"{self.mode.capitalize()} worker status is not ready: {data.get('status')}"
)
except ValueError:
logger.warning(
f"{self.mode.capitalize()} worker health response is not valid JSON"
)
return False
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover for TRT-LLM in aggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an aggregated TRT-LLM worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start an aggregated TRT-LLM worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="prefill_and_decode"
):
logger.info("Aggregated TRT-LLM worker started successfully")
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_ha_failover_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover for TRT-LLM in disaggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start the prefill worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="prefill"):
logger.info("Prefill worker started successfully")
# Step 5: Start the decode worker
with DynamoWorkerProcess(request, etcd_endpoints, mode="decode"):
logger.info("Decode worker started successfully")
# TODO: Fix disagg health checks
time.sleep(2)
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and an aggregated TRT-LLM worker
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and worker shut down gracefully
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start an aggregated TRT-LLM worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="prefill_and_decode"
) as worker:
logger.info("Aggregated TRT-LLM worker started successfully")
# TODO: Fix disagg health checks
time.sleep(2)
# Step 5: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info("System is working correctly with single ETCD node")
# Step 6: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 7: Wait and verify frontend and worker detect the loss
wait_for_processes_to_terminate(
{"Worker": worker, "Frontend": frontend}
)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test that frontend and workers shut down when single ETCD node is terminated for TRT-LLM in disaggregated mode.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and both workers shut down gracefully
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start the prefill worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="prefill"
) as prefill_worker:
logger.info("Prefill worker started successfully")
# Step 5: Start the decode worker
with DynamoWorkerProcess(
request, etcd_endpoints, mode="decode"
) as decode_worker:
logger.info("Decode worker started successfully")
# TODO: Fix disagg health checks
time.sleep(2)
# Step 6: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(
"System is working correctly with single ETCD node in disaggregated mode"
)
# Step 7: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 8: Wait and verify frontend and both workers detect the loss
wait_for_processes_to_terminate(
{
"Decode Worker": decode_worker,
"Prefill Worker": prefill_worker,
"Frontend": frontend,
}
)
......@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with vLLM backend and ETCD HA support"""
def __init__(self, request, etcd_endpoints: list):
def __init__(self, request, etcd_endpoints: list, is_prefill: bool = False):
command = [
"python3",
"-m",
......@@ -39,18 +39,34 @@ class DynamoWorkerProcess(ManagedProcess):
"8192",
]
# Health checks - frontend model registration
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set port based on worker type
port = "8082" if is_prefill else "8081"
# Configure health check based on worker type
if is_prefill:
# Prefill workers check their own status endpoint
command.append("--is-prefill-worker")
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
else:
# Decode workers should also check their own status endpoint first,
# then verify the frontend sees the model
health_check_urls = [
(f"http://localhost:{port}/health", self.is_ready),
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set debug logging and ETCD endpoints
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = port
log_dir = f"{request.node.name}_worker"
# Set log directory based on worker type
worker_type = "prefill_worker" if is_prefill else "worker"
log_dir = f"{request.node.name}_{worker_type}"
# Clean up any existing log directory from previous runs
try:
......@@ -75,15 +91,28 @@ class DynamoWorkerProcess(ManagedProcess):
log_dir=log_dir,
)
self.is_prefill = is_prefill
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
worker_type = "Prefill worker" if self.is_prefill else "Worker"
logger.info(f"{worker_type} status is ready")
return True
worker_type = "Prefill worker" if self.is_prefill else "Worker"
logger.warning(f"{worker_type} status is not ready: {data.get('status')}")
except ValueError:
worker_type = "Prefill worker" if self.is_prefill else "Worker"
logger.warning(f"{worker_type} health response is not valid JSON")
return False
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.xfail(
strict=False,
reason="Lease watch failover not yet implemented, only lease keep alive failover is implemented",
)
def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover.
......@@ -138,6 +167,70 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_ha_failover_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover in disaggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode vLLM workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start the prefill worker
with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=True):
logger.info("Prefill worker started successfully")
# Step 5: Start the decode worker
with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False):
logger.info("Decode worker started successfully")
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
......@@ -190,3 +283,73 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
wait_for_processes_to_terminate(
{"Worker": worker, "Frontend": frontend}
)
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test that frontend and workers shut down when single ETCD node is terminated in disaggregated mode.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and both prefill and decode vLLM workers
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and both workers shut down gracefully
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start the prefill worker
with DynamoWorkerProcess(
request, etcd_endpoints, is_prefill=True
) as prefill_worker:
logger.info("Prefill worker started successfully")
# Step 5: Start the decode worker
with DynamoWorkerProcess(
request, etcd_endpoints, is_prefill=False
) as decode_worker:
logger.info("Decode worker started successfully")
# Step 6: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(
"System is working correctly with single ETCD node in disaggregated mode"
)
# Step 7: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 8: Wait and verify frontend and both workers detect the loss
wait_for_processes_to_terminate(
{
"Decode Worker": decode_worker,
"Prefill Worker": prefill_worker,
"Frontend": frontend,
}
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment