Unverified commit ed8cd590, authored by Jacky, committed by GitHub
Browse files

test: Deterministic ETCD client failover tests (#4363)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 1e120ed0
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use super::connector::Connector;
use etcd_client::{LeaseKeepAliveStream, LeaseKeeper};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
......@@ -45,21 +46,53 @@ pub async fn create_lease(
async fn keep_alive(
connector: Arc<Connector>,
lease_id: u64,
mut ttl: u64,
ttl: u64,
token: CancellationToken,
) -> anyhow::Result<()> {
// Deadline when the lease expires
let mut deadline = Instant::now() + Duration::from_secs(ttl);
loop {
let mut reconnect = true;
while reconnect {
// Try to establish or re-establish the keep-alive stream
let (sender, receiver) =
match new_keep_alive_stream(&connector, lease_id, &deadline, &token).await? {
Some(stream) => stream,
None => break, // cancelled
};
// Keep-alive loop with the established stream
reconnect = keep_alive_with_stream(
&connector,
sender,
receiver,
lease_id,
&mut deadline,
&token,
)
.await?;
}
Ok(())
}
/// Establish a new keep-alive stream with automatic retry and reconnection.
///
/// Returns:
/// `Ok(Some((LeaseKeeper, LeaseKeepAliveStream)))` on success.
/// `Ok(None)` if cancelled.
/// `Err` for unrecoverable errors such as deadline exceeded.
async fn new_keep_alive_stream(
connector: &Arc<Connector>,
lease_id: u64,
deadline: &Instant,
token: &CancellationToken,
) -> anyhow::Result<Option<(LeaseKeeper, LeaseKeepAliveStream)>> {
loop {
let mut lease_client = connector.get_client().lease_client();
let (mut heartbeat_sender, mut heartbeat_receiver) = match lease_client
.keep_alive(lease_id as i64)
.await
{
match lease_client.keep_alive(lease_id as i64).await {
Ok((sender, receiver)) => {
tracing::debug!(lease_id, "Established keep-alive stream");
(sender, receiver)
return Ok(Some((sender, receiver))); // success
}
Err(e) => {
tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream");
......@@ -68,58 +101,71 @@ async fn keep_alive(
tokio::select! {
biased;
reconnect_result = connector.reconnect(deadline) => {
reconnect_result = connector.reconnect(*deadline) => {
match reconnect_result {
Err(e) => return Err(e),
_ => continue,
Err(e) => return Err(e), // cannot reconnect
_ => continue, // retry
}
}
_ = token.cancelled() => {
tracing::debug!(lease_id, "Cancellation token triggered during reconnection");
return Ok(());
return Ok(None); // cancelled
}
}
}
};
}
}
// Keep-alive loop with the established stream
/// Keep-alive loop that maintains the lease using the provided sender and receiver.
///
/// Returns:
/// `Ok(true)` for recoverable errors such as stream closure that warrant reconnection attempts.
/// `Ok(false)` if cancelled.
/// `Err` for unrecoverable errors such as lease already expired.
async fn keep_alive_with_stream(
connector: &Arc<Connector>,
mut sender: LeaseKeeper,
mut receiver: LeaseKeepAliveStream,
lease_id: u64,
deadline: &mut Instant,
token: &CancellationToken,
) -> anyhow::Result<bool> {
loop {
if deadline < std::time::Instant::now() {
anyhow::bail!(
"Unable to refresh lease - deadline exceeded. Check etcd server status"
);
}
let next_renewal = deadline
.saturating_duration_since(Instant::now())
.div_f64(2.0);
tokio::select! {
biased;
status = heartbeat_receiver.message() => {
status = receiver.message() => {
match status {
Ok(Some(resp)) => {
tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
// Update ttl and deadline from response
ttl = resp.ttl() as u64;
deadline = Instant::now() + Duration::from_secs(ttl);
if resp.ttl() == 0 {
// Update deadline from response
let ttl = resp.ttl();
if ttl <= 0 {
tracing::error!(lease_id, "Keep-alive lease expired");
anyhow::bail!("Unable to maintain lease - expired or revoked. Check etcd server status");
}
*deadline = Instant::now() + Duration::from_secs(ttl as u64);
}
Ok(None) => {
tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended");
break;
return Ok(true); // Exit to reconnect
}
Err(e) => {
tracing::warn!(lease_id, error = %e, "Keep-alive stream error");
break;
return Ok(true); // Exit to reconnect
}
}
}
_ = token.cancelled() => {
tracing::debug!(lease_id, "cancellation token triggered; revoking lease");
let mut lease_client = connector.get_client().lease_client();
if let Err(e) = lease_client.revoke(lease_id as i64).await {
tracing::warn!(
lease_id,
......@@ -127,25 +173,17 @@ async fn keep_alive(
"Failed to revoke lease during cancellation. Cleanup may be incomplete."
);
}
return Ok(());
return Ok(false);
}
_ = tokio::time::sleep(Duration::from_secs(ttl / 2)) => {
_ = tokio::time::sleep(next_renewal) => {
tracing::trace!(lease_id, "sending keep alive");
// if we get a error issuing the heartbeat, set the ttl to 0
// this will allow us to poll the response stream once and the cancellation
// token once, then immediately try to tick the heartbeat
// this will repeat until either the heartbeat is reestablished or the deadline
// is exceeded
if let Err(e) = heartbeat_sender.keep_alive().await {
if let Err(e) = sender.keep_alive().await {
tracing::warn!(
lease_id,
error = %e,
"Unable to send lease heartbeat. Check etcd server status"
);
ttl = 0;
}
}
}
}
......
......@@ -148,24 +148,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover using SGLang.
Test ETCD High Availability with repeated node failures and recoveries using SGLang.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an SGLang worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -182,46 +187,56 @@ def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
# Small wait to ensure worker is fully ready
time.sleep(2)
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 5: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.sglang
@pytest.mark.gpu_2
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover in disaggregated mode using SGLang.
Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode using SGLang.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode SGLang workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs.
"""
......@@ -230,7 +245,8 @@ def test_etcd_ha_failover_sglang_disaggregated(
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -251,34 +267,39 @@ def test_etcd_ha_failover_sglang_disaggregated(
# Small wait to ensure workers are fully ready
time.sleep(2)
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 6: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.sglang
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated using SGLang.
......@@ -335,7 +356,6 @@ def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
@pytest.mark.gpu_2
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
......
......@@ -134,24 +134,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover for TRT-LLM in aggregated mode.
Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in aggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an aggregated TRT-LLM worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -168,53 +173,64 @@ def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
):
logger.info("Aggregated TRT-LLM worker started successfully")
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 5: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover for TRT-LLM in disaggregated mode.
Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in disaggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -236,34 +252,39 @@ def test_etcd_ha_failover_trtllm_disaggregated(
# TODO: Fix disagg health checks
time.sleep(2)
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 6: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode.
......@@ -323,7 +344,6 @@ def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
......
......@@ -116,24 +116,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover.
Test ETCD High Availability with repeated node failures and recoveries.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and a vLLM worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -148,53 +153,64 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
with DynamoWorkerProcess(request, etcd_endpoints):
logger.info("Worker started successfully")
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 5: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
"""
Test ETCD High Availability with leader failover in disaggregated mode.
Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode vLLM workers
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
3. Cycles through each of the 3 replicas:
- Terminate the replica by index
- Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
......@@ -213,34 +229,39 @@ def test_etcd_ha_failover_vllm_disaggregated(
with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False):
logger.info("Decode worker started successfully")
# Step 6: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
# Step 6: Send initial inference request to verify system is working
logger.info("Sending initial inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
# Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated.
......@@ -295,7 +316,6 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm
):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import json
import logging
import os
import shutil
import subprocess
import tempfile
import time
from typing import List, Optional
......@@ -62,6 +64,7 @@ class EtcdReplicaServer(ManagedProcess):
data_dir: str,
log_dir: str,
timeout: int = 30,
cluster_state: str = "new",
):
self.name = name
self.client_port = client_port
......@@ -81,15 +84,15 @@ class EtcdReplicaServer(ManagedProcess):
"--listen-client-urls",
f"http://0.0.0.0:{client_port}",
"--advertise-client-urls",
f"http://localhost:{client_port}",
f"http://127.0.0.1:{client_port}",
"--listen-peer-urls",
f"http://0.0.0.0:{peer_port}",
"--initial-advertise-peer-urls",
f"http://localhost:{peer_port}",
f"http://127.0.0.1:{peer_port}",
"--initial-cluster",
initial_cluster,
"--initial-cluster-state",
"new",
cluster_state,
"--initial-cluster-token",
"etcd-cluster",
]
......@@ -108,7 +111,7 @@ class EtcdReplicaServer(ManagedProcess):
"""Get the status of this ETCD node"""
try:
response = requests.post(
f"http://localhost:{self.client_port}/v3/maintenance/status",
f"http://127.0.0.1:{self.client_port}/v3/maintenance/status",
json={},
timeout=2,
)
......@@ -118,15 +121,19 @@ class EtcdReplicaServer(ManagedProcess):
logger.warning(f"Failed to get status for {self.name}: {e}")
return {}
def is_leader(self) -> bool:
"""Check if this node is the current leader"""
def is_leader(self) -> Optional[bool]:
"""
Check if this node is the current leader.
Returns: True/False on is leader or None if status cannot be retrieved.
"""
status = self.get_status()
# In etcd v3 API, we check if this member ID matches the leader ID
if status:
member_id = status.get("header", {}).get("member_id", "")
leader_id = status.get("leader", "")
return member_id == leader_id
return False
return None
class EtcdCluster:
......@@ -136,13 +143,11 @@ class EtcdCluster:
self,
request,
num_replicas: int = 3,
base_client_port: int = 2379,
base_peer_port: int = 12380,
base_port: int = 2379,
):
self.request = request
self.num_replicas = num_replicas
self.base_client_port = base_client_port
self.base_peer_port = base_peer_port
self.base_port = base_port
self.replicas: List[Optional[EtcdReplicaServer]] = []
self.data_dirs: List[str] = []
self.log_base_dir = f"{request.node.name}_etcd_cluster"
......@@ -156,28 +161,31 @@ class EtcdCluster:
os.makedirs(self.log_base_dir, exist_ok=True)
def start(self):
"""Start ETCD cluster with configured number of replicas"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
# Build initial cluster configuration
def _get_initial_cluster(self) -> str:
"""Build the initial cluster configuration string"""
initial_cluster_parts = []
for i in range(self.num_replicas):
name = f"etcd-{i}"
peer_port = self.base_peer_port + i
initial_cluster_parts.append(f"{name}=http://localhost:{peer_port}")
initial_cluster = ",".join(initial_cluster_parts)
peer_port = self.base_port + (2 * i) + 1
initial_cluster_parts.append(f"{name}=http://127.0.0.1:{peer_port}")
return ",".join(initial_cluster_parts)
def _start_replica(self, idx: int, cluster_state: str = "new") -> EtcdReplicaServer:
"""Start a single ETCD replica"""
name = f"etcd-{idx}"
# e.g. base_port = 2379 -> client_port = 2379, 2381, 2383
# e.g. base_port = 2379 -> peer_port = 2380, 2382, 2384
client_port = self.base_port + (2 * idx)
peer_port = self.base_port + (2 * idx) + 1
# Create data dir for the node
data_dir = tempfile.mkdtemp(prefix=f"etcd_{idx}_")
if idx < len(self.data_dirs):
self.data_dirs[idx] = data_dir
else:
self.data_dirs.append(data_dir)
# Start each replica
for i in range(self.num_replicas):
name = f"etcd-{i}"
client_port = self.base_client_port + i
peer_port = self.base_peer_port + i
data_dir = tempfile.mkdtemp(prefix=f"etcd_{i}_")
log_dir = os.path.join(self.log_base_dir, name)
self.data_dirs.append(data_dir)
os.makedirs(log_dir, exist_ok=True)
logger.info(
......@@ -189,90 +197,181 @@ class EtcdCluster:
name=name,
client_port=client_port,
peer_port=peer_port,
initial_cluster=initial_cluster,
initial_cluster=self._get_initial_cluster(),
data_dir=data_dir,
log_dir=log_dir,
cluster_state=cluster_state,
)
replica.__enter__()
self.replicas.append(replica)
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize and elect a leader
self._wait_for_healthy_cluster(timeout=30)
leader_idx = self.find_leader()
if leader_idx is not None:
logger.info(f"Initial leader elected: etcd-{leader_idx}")
else:
logger.warning("No leader elected yet")
return replica
def _wait_for_healthy_cluster(self, timeout: int = 30):
"""Wait for all replicas to be healthy and responsive.
Args:
timeout: Maximum time to wait in seconds
Raises:
RuntimeError: If cluster doesn't become healthy within timeout
"""
logger.info("Waiting for all replicas to be healthy...")
"""Wait for cluster to be healthy and elected leader."""
logger.info("Waiting for cluster to become healthy...")
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(1)
# Check if all replicas are responding
all_healthy = True
# Check if a leader is elected indicating cluster health
is_healthy = True
leader_id = None
for i, replica in enumerate(self.replicas):
if replica:
status = replica.get_status()
if not status:
logger.debug(f"etcd-{i} not yet responsive")
all_healthy = False
is_leader = replica.is_leader()
if is_leader is None:
is_healthy = False
break
if is_leader is True:
if leader_id is not None:
raise RuntimeError(
f"Multiple leaders detected in ETCD cluster etcd-{leader_id} and etcd-{i}"
)
leader_id = i
if all_healthy:
logger.info("All replicas are healthy")
if is_healthy and leader_id is not None:
logger.info(f"Cluster is healthy with leader at etcd-{leader_id}")
return
time.sleep(1)
raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s")
def find_leader(self) -> Optional[int]:
    """Find which replica is currently the leader.

    Returns:
        The index of the first running replica whose ``is_leader()`` is
        truthy, or ``None`` when no running replica reports leadership.
    """
    return next(
        (
            position
            for position, node in enumerate(self.replicas)
            # Terminated replicas are stored as None and are skipped.
            if node and node.is_leader()
        ),
        None,
    )
def _replace_member(self, idx: int):
"""Remove old member and add new member to the cluster using etcdctl"""
# Find a healthy replica to perform member operations
healthy_replica = None
for i, r in enumerate(self.replicas):
if r and i != idx:
healthy_replica = r
break
def terminate_leader(self) -> Optional[int]:
"""Terminate the current leader and return its index"""
leader_idx = self.find_leader()
if not healthy_replica:
raise RuntimeError("No healthy replica found to perform member operations")
if leader_idx is None:
logger.warning("No leader found to terminate")
return None
name = f"etcd-{idx}"
peer_port = self.base_port + (2 * idx) + 1
peer_url = f"http://127.0.0.1:{peer_port}"
logger.info(f"Terminating current leader: etcd-{leader_idx}")
replica = self.replicas[leader_idx]
# Set ETCDCTL_ENDPOINTS for etcdctl commands
etcdctl_env = os.environ.copy()
etcdctl_env[
"ETCDCTL_ENDPOINTS"
] = f"http://127.0.0.1:{healthy_replica.client_port}"
etcdctl_env["ETCDCTL_API"] = "3"
if replica:
replica.__exit__(None, None, None)
self.replicas[leader_idx] = None
logger.info(f"Leader etcd-{leader_idx} has been terminated")
# First, get member list to find the old member's ID
logger.info(f"Getting member list to find {name}")
try:
result = subprocess.run(
["etcdctl", "member", "list", "--write-out=json"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
members = json.loads(result.stdout).get("members", [])
old_member_id = None
for member in members:
if member.get("name") == name:
old_member_id = member.get("ID")
break
if old_member_id:
# Convert member ID to hex format (etcdctl expects hex)
hex_member_id = format(int(old_member_id), "x")
logger.info(
f"Removing member with ID {old_member_id} (hex: {hex_member_id})"
)
remove_result = subprocess.run(
["etcdctl", "member", "remove", hex_member_id],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if remove_result.returncode != 0:
raise RuntimeError(
f"Failed to remove old member: {remove_result.stderr}"
)
logger.info(f"Successfully removed old member {name}")
except Exception as e:
raise RuntimeError(f"Error during member removal: {e}")
# Add the new member to the cluster
logger.info(f"Adding new member {name} to cluster with peer URL {peer_url}")
try:
add_result = subprocess.run(
["etcdctl", "member", "add", name, f"--peer-urls={peer_url}"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if add_result.returncode != 0:
raise RuntimeError(f"Failed to add new member: {add_result.stderr}")
logger.info(f"Successfully added new member {name}")
except Exception as e:
raise RuntimeError(f"Error adding new member: {e}")
def start(self):
"""Start ETCD cluster with configured number of replicas"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
return leader_idx
# Start each replica
for i in range(self.num_replicas):
replica = self._start_replica(i, cluster_state="new")
self.replicas.append(replica)
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize
self._wait_for_healthy_cluster()
def get_client_endpoints(self) -> List[str]:
    """Get list of active client endpoints.

    Returns:
        One ``http://127.0.0.1:{client_port}`` URL per running replica, in
        replica order, where the client port is ``self.base_port + 2 * i``.
        Terminated replicas (``None`` slots) are left out.
    """
    # Only the base_port-derived 127.0.0.1 endpoint is emitted. The block
    # previously also carried a base_client_port/localhost pair, which would
    # have produced two endpoints per replica.
    return [
        f"http://127.0.0.1:{self.base_port + (2 * i)}"
        for i, replica in enumerate(self.replicas)
        if replica  # Only include active replicas
    ]
def terminate_replica(self, idx: int):
    """Terminate a specific replica by index.

    Args:
        idx: Replica index; must satisfy ``0 <= idx < self.num_replicas``.

    Raises:
        RuntimeError: If ``idx`` is out of range or that replica has
            already been terminated.
    """
    if not 0 <= idx < self.num_replicas:
        raise RuntimeError(f"Invalid replica index: {idx}")

    target = self.replicas[idx]
    if not target:
        raise RuntimeError(f"Replica etcd-{idx} is already terminated")

    # Leave the replica's context manually, then mark the slot as terminated.
    target.__exit__(None, None, None)
    self.replicas[idx] = None
    logger.info(f"Terminated replica etcd-{idx}")
def restart_replica(self, idx: int):
    """Restart a terminated replica.

    Sequence: wait for the cluster to be healthy, replace the replica's
    cluster membership, start it with cluster state ``"existing"``, store
    it back into ``self.replicas[idx]``, then wait for health again.

    Args:
        idx: Replica index; must satisfy ``0 <= idx < self.num_replicas``.

    Raises:
        RuntimeError: If ``idx`` is out of range or the replica is still
            running.
    """
    if not 0 <= idx < self.num_replicas:
        raise RuntimeError(f"Invalid replica index: {idx}")
    if self.replicas[idx] is not None:
        raise RuntimeError(f"Replica etcd-{idx} is already running")

    # Make sure the cluster is healthy before restarting
    self._wait_for_healthy_cluster()

    # Remove old member and add new member
    self._replace_member(idx)

    # Start the replica with existing cluster state
    self.replicas[idx] = self._start_replica(idx, cluster_state="existing")

    # Wait for cluster to stabilize
    self._wait_for_healthy_cluster()
def stop(self):
"""Clean up all replicas and temporary directories"""
logger.info("Cleaning up ETCD cluster")
......@@ -329,10 +428,10 @@ def send_inference_request(prompt: str, max_tokens: int = 50) -> str:
return text
else:
pytest.fail(
f"Inference request failed with code {response.status_code}: {response.text}"
f"[ETCD HA regression?] Inference request failed with code {response.status_code}: {response.text}"
)
except Exception as e:
pytest.fail(f"Inference request failed: {e}")
pytest.fail(f"[ETCD HA regression?] Inference request failed: {e}")
def wait_for_processes_to_terminate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment