Unverified commit ed8cd590, authored by Jacky, committed by GitHub
Browse files

test: Deterministic ETCD client failover tests (#4363)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 1e120ed0
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::connector::Connector; use super::connector::Connector;
use etcd_client::{LeaseKeepAliveStream, LeaseKeeper};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -45,21 +46,53 @@ pub async fn create_lease( ...@@ -45,21 +46,53 @@ pub async fn create_lease(
async fn keep_alive( async fn keep_alive(
connector: Arc<Connector>, connector: Arc<Connector>,
lease_id: u64, lease_id: u64,
mut ttl: u64, ttl: u64,
token: CancellationToken, token: CancellationToken,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Deadline when the lease expires
let mut deadline = Instant::now() + Duration::from_secs(ttl); let mut deadline = Instant::now() + Duration::from_secs(ttl);
loop { let mut reconnect = true;
while reconnect {
// Try to establish or re-establish the keep-alive stream // Try to establish or re-establish the keep-alive stream
let (sender, receiver) =
match new_keep_alive_stream(&connector, lease_id, &deadline, &token).await? {
Some(stream) => stream,
None => break, // cancelled
};
// Keep-alive loop with the established stream
reconnect = keep_alive_with_stream(
&connector,
sender,
receiver,
lease_id,
&mut deadline,
&token,
)
.await?;
}
Ok(())
}
/// Establish a new keep-alive stream with automatic retry and reconnection.
///
/// Returns:
/// `Ok(Some((LeaseKeeper, LeaseKeepAliveStream)))` on success.
/// `Ok(None)` if cancelled.
/// `Err` for unrecoverable errors such as deadline exceeded.
async fn new_keep_alive_stream(
connector: &Arc<Connector>,
lease_id: u64,
deadline: &Instant,
token: &CancellationToken,
) -> anyhow::Result<Option<(LeaseKeeper, LeaseKeepAliveStream)>> {
loop {
let mut lease_client = connector.get_client().lease_client(); let mut lease_client = connector.get_client().lease_client();
let (mut heartbeat_sender, mut heartbeat_receiver) = match lease_client match lease_client.keep_alive(lease_id as i64).await {
.keep_alive(lease_id as i64)
.await
{
Ok((sender, receiver)) => { Ok((sender, receiver)) => {
tracing::debug!(lease_id, "Established keep-alive stream"); tracing::debug!(lease_id, "Established keep-alive stream");
(sender, receiver) return Ok(Some((sender, receiver))); // success
} }
Err(e) => { Err(e) => {
tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream"); tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream");
...@@ -68,84 +101,89 @@ async fn keep_alive( ...@@ -68,84 +101,89 @@ async fn keep_alive(
tokio::select! { tokio::select! {
biased; biased;
reconnect_result = connector.reconnect(deadline) => { reconnect_result = connector.reconnect(*deadline) => {
match reconnect_result { match reconnect_result {
Err(e) => return Err(e), Err(e) => return Err(e), // cannot reconnect
_ => continue, _ => continue, // retry
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
tracing::debug!(lease_id, "Cancellation token triggered during reconnection"); tracing::debug!(lease_id, "Cancellation token triggered during reconnection");
return Ok(()); return Ok(None); // cancelled
} }
} }
} }
}; };
}
}
// Keep-alive loop with the established stream /// Keep-alive loop that maintains the lease using the provided sender and receiver.
loop { ///
if deadline < std::time::Instant::now() { /// Returns:
anyhow::bail!( /// `Ok(true)` for recoverable errors such as stream closure that warrant reconnection attempts.
"Unable to refresh lease - deadline exceeded. Check etcd server status" /// `Ok(false)` if cancelled.
); /// `Err` for unrecoverable errors such as lease already expired.
} async fn keep_alive_with_stream(
connector: &Arc<Connector>,
tokio::select! { mut sender: LeaseKeeper,
biased; mut receiver: LeaseKeepAliveStream,
lease_id: u64,
status = heartbeat_receiver.message() => { deadline: &mut Instant,
match status { token: &CancellationToken,
Ok(Some(resp)) => { ) -> anyhow::Result<bool> {
tracing::trace!(lease_id, "keep alive response received: {:?}", resp); loop {
let next_renewal = deadline
// Update ttl and deadline from response .saturating_duration_since(Instant::now())
ttl = resp.ttl() as u64; .div_f64(2.0);
deadline = Instant::now() + Duration::from_secs(ttl);
tokio::select! {
if resp.ttl() == 0 { biased;
anyhow::bail!("Unable to maintain lease - expired or revoked. Check etcd server status");
} status = receiver.message() => {
} match status {
Ok(None) => { Ok(Some(resp)) => {
tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended"); tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
break; // Update deadline from response
} let ttl = resp.ttl();
Err(e) => { if ttl <= 0 {
tracing::warn!(lease_id, error = %e, "Keep-alive stream error"); tracing::error!(lease_id, "Keep-alive lease expired");
break; anyhow::bail!("Unable to maintain lease - expired or revoked. Check etcd server status");
} }
*deadline = Instant::now() + Duration::from_secs(ttl as u64);
}
Ok(None) => {
tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended");
return Ok(true); // Exit to reconnect
}
Err(e) => {
tracing::warn!(lease_id, error = %e, "Keep-alive stream error");
return Ok(true); // Exit to reconnect
} }
} }
}
_ = token.cancelled() => { _ = token.cancelled() => {
tracing::debug!(lease_id, "cancellation token triggered; revoking lease"); tracing::debug!(lease_id, "cancellation token triggered; revoking lease");
if let Err(e) = lease_client.revoke(lease_id as i64).await { let mut lease_client = connector.get_client().lease_client();
tracing::warn!( if let Err(e) = lease_client.revoke(lease_id as i64).await {
lease_id, tracing::warn!(
error = %e, lease_id,
"Failed to revoke lease during cancellation. Cleanup may be incomplete." error = %e,
); "Failed to revoke lease during cancellation. Cleanup may be incomplete."
} );
return Ok(());
} }
return Ok(false);
}
_ = tokio::time::sleep(Duration::from_secs(ttl / 2)) => { _ = tokio::time::sleep(next_renewal) => {
tracing::trace!(lease_id, "sending keep alive"); tracing::trace!(lease_id, "sending keep alive");
if let Err(e) = sender.keep_alive().await {
// if we get a error issuing the heartbeat, set the ttl to 0 tracing::warn!(
// this will allow us to poll the response stream once and the cancellation lease_id,
// token once, then immediately try to tick the heartbeat error = %e,
// this will repeat until either the heartbeat is reestablished or the deadline "Unable to send lease heartbeat. Check etcd server status"
// is exceeded );
if let Err(e) = heartbeat_sender.keep_alive().await {
tracing::warn!(
lease_id,
error = %e,
"Unable to send lease heartbeat. Check etcd server status"
);
ttl = 0;
}
} }
} }
} }
......
...@@ -148,24 +148,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -148,24 +148,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_aggregated(request, predownload_models): def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover using SGLang. Test ETCD High Availability with repeated node failures and recoveries using SGLang.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an SGLang worker 2. Starts NATS, frontend, and an SGLang worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -182,46 +187,56 @@ def test_etcd_ha_failover_sglang_aggregated(request, predownload_models): ...@@ -182,46 +187,56 @@ def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
# Small wait to ensure worker is fully ready # Small wait to ensure worker is fully ready
time.sleep(2) time.sleep(2)
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader # Step 6: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 7: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.sglang @pytest.mark.sglang
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_disaggregated( def test_etcd_ha_failover_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover in disaggregated mode using SGLang. Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode using SGLang.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode SGLang workers 2. Starts NATS, frontend, and both prefill and decode SGLang workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs. Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs.
""" """
...@@ -230,7 +245,8 @@ def test_etcd_ha_failover_sglang_disaggregated( ...@@ -230,7 +245,8 @@ def test_etcd_ha_failover_sglang_disaggregated(
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -251,34 +267,39 @@ def test_etcd_ha_failover_sglang_disaggregated( ...@@ -251,34 +267,39 @@ def test_etcd_ha_failover_sglang_disaggregated(
# Small wait to ensure workers are fully ready # Small wait to ensure workers are fully ready
time.sleep(2) time.sleep(2)
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader # Step 7: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 8: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.sglang @pytest.mark.sglang
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated using SGLang. Test that frontend and worker shut down when single ETCD node is terminated using SGLang.
...@@ -335,7 +356,6 @@ def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models): ...@@ -335,7 +356,6 @@ def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_disaggregated( def test_etcd_non_ha_shutdown_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
...@@ -134,24 +134,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -134,24 +134,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models): def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover for TRT-LLM in aggregated mode. Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in aggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an aggregated TRT-LLM worker 2. Starts NATS, frontend, and an aggregated TRT-LLM worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -168,53 +173,64 @@ def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models): ...@@ -168,53 +173,64 @@ def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
): ):
logger.info("Aggregated TRT-LLM worker started successfully") logger.info("Aggregated TRT-LLM worker started successfully")
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader # Step 6: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 7: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.trtllm_marker @pytest.mark.trtllm_marker
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_disaggregated( def test_etcd_ha_failover_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover for TRT-LLM in disaggregated mode. Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in disaggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers 2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -236,34 +252,39 @@ def test_etcd_ha_failover_trtllm_disaggregated( ...@@ -236,34 +252,39 @@ def test_etcd_ha_failover_trtllm_disaggregated(
# TODO: Fix disagg health checks # TODO: Fix disagg health checks
time.sleep(2) time.sleep(2)
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader # Step 7: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 8: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.trtllm_marker @pytest.mark.trtllm_marker
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode. Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode.
...@@ -323,7 +344,6 @@ def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models): ...@@ -323,7 +344,6 @@ def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_disaggregated( def test_etcd_non_ha_shutdown_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
...@@ -116,24 +116,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -116,24 +116,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_aggregated(request, predownload_models): def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover. Test ETCD High Availability with repeated node failures and recoveries.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and a vLLM worker 2. Starts NATS, frontend, and a vLLM worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -148,53 +153,64 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models): ...@@ -148,53 +153,64 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
with DynamoWorkerProcess(request, etcd_endpoints): with DynamoWorkerProcess(request, etcd_endpoints):
logger.info("Worker started successfully") logger.info("Worker started successfully")
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader # Step 6: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 7: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_disaggregated( def test_etcd_ha_failover_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover in disaggregated mode. Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode vLLM workers 2. Starts NATS, frontend, and both prefill and decode vLLM workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -213,34 +229,39 @@ def test_etcd_ha_failover_vllm_disaggregated( ...@@ -213,34 +229,39 @@ def test_etcd_ha_failover_vllm_disaggregated(
with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False): with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False):
logger.info("Decode worker started successfully") logger.info("Decode worker started successfully")
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader # Step 7: Cycle through each replica to terminate/verify/restart
logger.info("Terminating ETCD leader to test failover") for i in range(num_replicas):
terminated_idx = etcd_cluster.terminate_leader() # Terminate a replica
if terminated_idx is None: logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
pytest.fail("Failed to identify and terminate ETCD leader") etcd_cluster.terminate_replica(i)
logger.info(f"Terminated ETCD node {terminated_idx}") # Send inference request to verify system still works
logger.info(
f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert (
"paris" in result.lower()
), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Step 8: Send second inference request to verify system still works # Restart the terminated replica
logger.info("Sending second inference request (after failover)") logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
result2 = send_inference_request("The capital of France is") etcd_cluster.restart_replica(i)
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated. Test that frontend and worker shut down when single ETCD node is terminated.
...@@ -295,7 +316,6 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models): ...@@ -295,7 +316,6 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_disaggregated( def test_etcd_non_ha_shutdown_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
import logging import logging
import os import os
import shutil import shutil
import subprocess
import tempfile import tempfile
import time import time
from typing import List, Optional from typing import List, Optional
...@@ -62,6 +64,7 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -62,6 +64,7 @@ class EtcdReplicaServer(ManagedProcess):
data_dir: str, data_dir: str,
log_dir: str, log_dir: str,
timeout: int = 30, timeout: int = 30,
cluster_state: str = "new",
): ):
self.name = name self.name = name
self.client_port = client_port self.client_port = client_port
...@@ -81,15 +84,15 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -81,15 +84,15 @@ class EtcdReplicaServer(ManagedProcess):
"--listen-client-urls", "--listen-client-urls",
f"http://0.0.0.0:{client_port}", f"http://0.0.0.0:{client_port}",
"--advertise-client-urls", "--advertise-client-urls",
f"http://localhost:{client_port}", f"http://127.0.0.1:{client_port}",
"--listen-peer-urls", "--listen-peer-urls",
f"http://0.0.0.0:{peer_port}", f"http://0.0.0.0:{peer_port}",
"--initial-advertise-peer-urls", "--initial-advertise-peer-urls",
f"http://localhost:{peer_port}", f"http://127.0.0.1:{peer_port}",
"--initial-cluster", "--initial-cluster",
initial_cluster, initial_cluster,
"--initial-cluster-state", "--initial-cluster-state",
"new", cluster_state,
"--initial-cluster-token", "--initial-cluster-token",
"etcd-cluster", "etcd-cluster",
] ]
...@@ -108,7 +111,7 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -108,7 +111,7 @@ class EtcdReplicaServer(ManagedProcess):
"""Get the status of this ETCD node""" """Get the status of this ETCD node"""
try: try:
response = requests.post( response = requests.post(
f"http://localhost:{self.client_port}/v3/maintenance/status", f"http://127.0.0.1:{self.client_port}/v3/maintenance/status",
json={}, json={},
timeout=2, timeout=2,
) )
...@@ -118,15 +121,19 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -118,15 +121,19 @@ class EtcdReplicaServer(ManagedProcess):
logger.warning(f"Failed to get status for {self.name}: {e}") logger.warning(f"Failed to get status for {self.name}: {e}")
return {} return {}
def is_leader(self) -> bool: def is_leader(self) -> Optional[bool]:
"""Check if this node is the current leader""" """
Check if this node is the current leader.
Returns: True/False on is leader or None if status cannot be retrieved.
"""
status = self.get_status() status = self.get_status()
# In etcd v3 API, we check if this member ID matches the leader ID # In etcd v3 API, we check if this member ID matches the leader ID
if status: if status:
member_id = status.get("header", {}).get("member_id", "") member_id = status.get("header", {}).get("member_id", "")
leader_id = status.get("leader", "") leader_id = status.get("leader", "")
return member_id == leader_id return member_id == leader_id
return False return None
class EtcdCluster: class EtcdCluster:
...@@ -136,13 +143,11 @@ class EtcdCluster: ...@@ -136,13 +143,11 @@ class EtcdCluster:
self, self,
request, request,
num_replicas: int = 3, num_replicas: int = 3,
base_client_port: int = 2379, base_port: int = 2379,
base_peer_port: int = 12380,
): ):
self.request = request self.request = request
self.num_replicas = num_replicas self.num_replicas = num_replicas
self.base_client_port = base_client_port self.base_port = base_port
self.base_peer_port = base_peer_port
self.replicas: List[Optional[EtcdReplicaServer]] = [] self.replicas: List[Optional[EtcdReplicaServer]] = []
self.data_dirs: List[str] = [] self.data_dirs: List[str] = []
self.log_base_dir = f"{request.node.name}_etcd_cluster" self.log_base_dir = f"{request.node.name}_etcd_cluster"
...@@ -156,123 +161,217 @@ class EtcdCluster: ...@@ -156,123 +161,217 @@ class EtcdCluster:
os.makedirs(self.log_base_dir, exist_ok=True) os.makedirs(self.log_base_dir, exist_ok=True)
def start(self): def _get_initial_cluster(self) -> str:
"""Start ETCD cluster with configured number of replicas""" """Build the initial cluster configuration string"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
# Build initial cluster configuration
initial_cluster_parts = [] initial_cluster_parts = []
for i in range(self.num_replicas): for i in range(self.num_replicas):
name = f"etcd-{i}" name = f"etcd-{i}"
peer_port = self.base_peer_port + i peer_port = self.base_port + (2 * i) + 1
initial_cluster_parts.append(f"{name}=http://localhost:{peer_port}") initial_cluster_parts.append(f"{name}=http://127.0.0.1:{peer_port}")
return ",".join(initial_cluster_parts)
initial_cluster = ",".join(initial_cluster_parts)
def _start_replica(self, idx: int, cluster_state: str = "new") -> EtcdReplicaServer:
# Start each replica """Start a single ETCD replica"""
for i in range(self.num_replicas): name = f"etcd-{idx}"
name = f"etcd-{i}" # e.g. base_port = 2379 -> client_port = 2379, 2381, 2383
client_port = self.base_client_port + i # e.g. base_port = 2379 -> peer_port = 2380, 2382, 2384
peer_port = self.base_peer_port + i client_port = self.base_port + (2 * idx)
data_dir = tempfile.mkdtemp(prefix=f"etcd_{i}_") peer_port = self.base_port + (2 * idx) + 1
log_dir = os.path.join(self.log_base_dir, name)
# Create data dir for the node
data_dir = tempfile.mkdtemp(prefix=f"etcd_{idx}_")
if idx < len(self.data_dirs):
self.data_dirs[idx] = data_dir
else:
self.data_dirs.append(data_dir) self.data_dirs.append(data_dir)
os.makedirs(log_dir, exist_ok=True)
logger.info( log_dir = os.path.join(self.log_base_dir, name)
f"Starting {name} on client port {client_port}, peer port {peer_port}" os.makedirs(log_dir, exist_ok=True)
)
replica = EtcdReplicaServer( logger.info(
request=self.request, f"Starting {name} on client port {client_port}, peer port {peer_port}"
name=name, )
client_port=client_port,
peer_port=peer_port,
initial_cluster=initial_cluster,
data_dir=data_dir,
log_dir=log_dir,
)
replica.__enter__()
self.replicas.append(replica)
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize and elect a leader replica = EtcdReplicaServer(
self._wait_for_healthy_cluster(timeout=30) request=self.request,
name=name,
client_port=client_port,
peer_port=peer_port,
initial_cluster=self._get_initial_cluster(),
data_dir=data_dir,
log_dir=log_dir,
cluster_state=cluster_state,
)
leader_idx = self.find_leader() replica.__enter__()
if leader_idx is not None: return replica
logger.info(f"Initial leader elected: etcd-{leader_idx}")
else:
logger.warning("No leader elected yet")
def _wait_for_healthy_cluster(self, timeout: int = 30): def _wait_for_healthy_cluster(self, timeout: int = 30):
"""Wait for all replicas to be healthy and responsive. """Wait for cluster to be healthy and elected leader."""
logger.info("Waiting for cluster to become healthy...")
Args:
timeout: Maximum time to wait in seconds
Raises:
RuntimeError: If cluster doesn't become healthy within timeout
"""
logger.info("Waiting for all replicas to be healthy...")
start_time = time.time() start_time = time.time()
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
time.sleep(1) # Check if a leader is elected indicating cluster health
is_healthy = True
# Check if all replicas are responding leader_id = None
all_healthy = True
for i, replica in enumerate(self.replicas): for i, replica in enumerate(self.replicas):
if replica: if replica:
status = replica.get_status() is_leader = replica.is_leader()
if not status: if is_leader is None:
logger.debug(f"etcd-{i} not yet responsive") is_healthy = False
all_healthy = False
break break
if is_leader is True:
if all_healthy: if leader_id is not None:
logger.info("All replicas are healthy") raise RuntimeError(
f"Multiple leaders detected in ETCD cluster etcd-{leader_id} and etcd-{i}"
)
leader_id = i
if is_healthy and leader_id is not None:
logger.info(f"Cluster is healthy with leader at etcd-{leader_id}")
return return
time.sleep(1)
raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s") raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s")
def find_leader(self) -> Optional[int]: def _replace_member(self, idx: int):
"""Find which replica is currently the leader""" """Remove old member and add new member to the cluster using etcdctl"""
for i, replica in enumerate(self.replicas): # Find a healthy replica to perform member operations
if replica and replica.is_leader(): healthy_replica = None
return i for i, r in enumerate(self.replicas):
return None if r and i != idx:
healthy_replica = r
break
if not healthy_replica:
raise RuntimeError("No healthy replica found to perform member operations")
name = f"etcd-{idx}"
peer_port = self.base_port + (2 * idx) + 1
peer_url = f"http://127.0.0.1:{peer_port}"
# Set ETCDCTL_ENDPOINTS for etcdctl commands
etcdctl_env = os.environ.copy()
etcdctl_env[
"ETCDCTL_ENDPOINTS"
] = f"http://127.0.0.1:{healthy_replica.client_port}"
etcdctl_env["ETCDCTL_API"] = "3"
# First, get member list to find the old member's ID
logger.info(f"Getting member list to find {name}")
try:
result = subprocess.run(
["etcdctl", "member", "list", "--write-out=json"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
members = json.loads(result.stdout).get("members", [])
old_member_id = None
for member in members:
if member.get("name") == name:
old_member_id = member.get("ID")
break
def terminate_leader(self) -> Optional[int]: if old_member_id:
"""Terminate the current leader and return its index""" # Convert member ID to hex format (etcdctl expects hex)
leader_idx = self.find_leader() hex_member_id = format(int(old_member_id), "x")
logger.info(
f"Removing member with ID {old_member_id} (hex: {hex_member_id})"
)
remove_result = subprocess.run(
["etcdctl", "member", "remove", hex_member_id],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if remove_result.returncode != 0:
raise RuntimeError(
f"Failed to remove old member: {remove_result.stderr}"
)
logger.info(f"Successfully removed old member {name}")
except Exception as e:
raise RuntimeError(f"Error during member removal: {e}")
if leader_idx is None: # Add the new member to the cluster
logger.warning("No leader found to terminate") logger.info(f"Adding new member {name} to cluster with peer URL {peer_url}")
return None try:
add_result = subprocess.run(
["etcdctl", "member", "add", name, f"--peer-urls={peer_url}"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if add_result.returncode != 0:
raise RuntimeError(f"Failed to add new member: {add_result.stderr}")
logger.info(f"Successfully added new member {name}")
except Exception as e:
raise RuntimeError(f"Error adding new member: {e}")
def start(self):
"""Start ETCD cluster with configured number of replicas"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
logger.info(f"Terminating current leader: etcd-{leader_idx}") # Start each replica
replica = self.replicas[leader_idx] for i in range(self.num_replicas):
replica = self._start_replica(i, cluster_state="new")
self.replicas.append(replica)
if replica: logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
replica.__exit__(None, None, None)
self.replicas[leader_idx] = None
logger.info(f"Leader etcd-{leader_idx} has been terminated")
return leader_idx # Wait for cluster to stabilize
self._wait_for_healthy_cluster()
def get_client_endpoints(self) -> List[str]: def get_client_endpoints(self) -> List[str]:
"""Get list of active client endpoints""" """Get list of active client endpoints"""
endpoints = [] endpoints = []
for i, replica in enumerate(self.replicas): for i, replica in enumerate(self.replicas):
if replica: # Only include active replicas if replica: # Only include active replicas
client_port = self.base_client_port + i client_port = self.base_port + (2 * i)
endpoints.append(f"http://localhost:{client_port}") endpoints.append(f"http://127.0.0.1:{client_port}")
return endpoints return endpoints
def terminate_replica(self, idx: int):
"""Terminate a specific replica by index."""
if idx < 0 or idx >= self.num_replicas:
raise RuntimeError(f"Invalid replica index: {idx}")
replica = self.replicas[idx]
if not replica:
raise RuntimeError(f"Replica etcd-{idx} is already terminated")
replica.__exit__(None, None, None)
self.replicas[idx] = None
logger.info(f"Terminated replica etcd-{idx}")
def restart_replica(self, idx: int):
"""Restart a terminated replica"""
if idx < 0 or idx >= self.num_replicas:
raise RuntimeError(f"Invalid replica index: {idx}")
if self.replicas[idx] is not None:
raise RuntimeError(f"Replica etcd-{idx} is already running")
# Make sure the cluster is healthy before restarting
self._wait_for_healthy_cluster()
# Remove old member and add new member
self._replace_member(idx)
# Start the replica with existing cluster state
replica = self._start_replica(idx, cluster_state="existing")
self.replicas[idx] = replica
# Wait for cluster to stabilize
self._wait_for_healthy_cluster()
def stop(self): def stop(self):
"""Clean up all replicas and temporary directories""" """Clean up all replicas and temporary directories"""
logger.info("Cleaning up ETCD cluster") logger.info("Cleaning up ETCD cluster")
...@@ -329,10 +428,10 @@ def send_inference_request(prompt: str, max_tokens: int = 50) -> str: ...@@ -329,10 +428,10 @@ def send_inference_request(prompt: str, max_tokens: int = 50) -> str:
return text return text
else: else:
pytest.fail( pytest.fail(
f"Inference request failed with code {response.status_code}: {response.text}" f"[ETCD HA regression?] Inference request failed with code {response.status_code}: {response.text}"
) )
except Exception as e: except Exception as e:
pytest.fail(f"Inference request failed: {e}") pytest.fail(f"[ETCD HA regression?] Inference request failed: {e}")
def wait_for_processes_to_terminate( def wait_for_processes_to_terminate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment