Unverified commit ed8cd590, authored by Jacky, committed by GitHub
Browse files

test: Deterministic ETCD client failover tests (#4363)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 1e120ed0
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::connector::Connector; use super::connector::Connector;
use etcd_client::{LeaseKeepAliveStream, LeaseKeeper};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken; use tokio_util::sync::CancellationToken;
...@@ -45,21 +46,53 @@ pub async fn create_lease( ...@@ -45,21 +46,53 @@ pub async fn create_lease(
async fn keep_alive( async fn keep_alive(
connector: Arc<Connector>, connector: Arc<Connector>,
lease_id: u64, lease_id: u64,
mut ttl: u64, ttl: u64,
token: CancellationToken, token: CancellationToken,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Deadline when the lease expires
let mut deadline = Instant::now() + Duration::from_secs(ttl); let mut deadline = Instant::now() + Duration::from_secs(ttl);
loop { let mut reconnect = true;
while reconnect {
// Try to establish or re-establish the keep-alive stream // Try to establish or re-establish the keep-alive stream
let (sender, receiver) =
match new_keep_alive_stream(&connector, lease_id, &deadline, &token).await? {
Some(stream) => stream,
None => break, // cancelled
};
// Keep-alive loop with the established stream
reconnect = keep_alive_with_stream(
&connector,
sender,
receiver,
lease_id,
&mut deadline,
&token,
)
.await?;
}
Ok(())
}
/// Establish a new keep-alive stream with automatic retry and reconnection.
///
/// Returns:
/// `Ok(Some((LeaseKeeper, LeaseKeepAliveStream)))` on success.
/// `Ok(None)` if cancelled.
/// `Err` for unrecoverable errors such as deadline exceeded.
async fn new_keep_alive_stream(
connector: &Arc<Connector>,
lease_id: u64,
deadline: &Instant,
token: &CancellationToken,
) -> anyhow::Result<Option<(LeaseKeeper, LeaseKeepAliveStream)>> {
loop {
let mut lease_client = connector.get_client().lease_client(); let mut lease_client = connector.get_client().lease_client();
let (mut heartbeat_sender, mut heartbeat_receiver) = match lease_client match lease_client.keep_alive(lease_id as i64).await {
.keep_alive(lease_id as i64)
.await
{
Ok((sender, receiver)) => { Ok((sender, receiver)) => {
tracing::debug!(lease_id, "Established keep-alive stream"); tracing::debug!(lease_id, "Established keep-alive stream");
(sender, receiver) return Ok(Some((sender, receiver))); // success
} }
Err(e) => { Err(e) => {
tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream"); tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream");
...@@ -68,58 +101,71 @@ async fn keep_alive( ...@@ -68,58 +101,71 @@ async fn keep_alive(
tokio::select! { tokio::select! {
biased; biased;
reconnect_result = connector.reconnect(deadline) => { reconnect_result = connector.reconnect(*deadline) => {
match reconnect_result { match reconnect_result {
Err(e) => return Err(e), Err(e) => return Err(e), // cannot reconnect
_ => continue, _ => continue, // retry
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
tracing::debug!(lease_id, "Cancellation token triggered during reconnection"); tracing::debug!(lease_id, "Cancellation token triggered during reconnection");
return Ok(()); return Ok(None); // cancelled
} }
} }
} }
}; };
}
}
// Keep-alive loop with the established stream /// Keep-alive loop that maintains the lease using the provided sender and receiver.
///
/// Returns:
/// `Ok(true)` for recoverable errors such as stream closure that warrant reconnection attempts.
/// `Ok(false)` if cancelled.
/// `Err` for unrecoverable errors such as lease already expired.
async fn keep_alive_with_stream(
connector: &Arc<Connector>,
mut sender: LeaseKeeper,
mut receiver: LeaseKeepAliveStream,
lease_id: u64,
deadline: &mut Instant,
token: &CancellationToken,
) -> anyhow::Result<bool> {
loop { loop {
if deadline < std::time::Instant::now() { let next_renewal = deadline
anyhow::bail!( .saturating_duration_since(Instant::now())
"Unable to refresh lease - deadline exceeded. Check etcd server status" .div_f64(2.0);
);
}
tokio::select! { tokio::select! {
biased; biased;
status = heartbeat_receiver.message() => { status = receiver.message() => {
match status { match status {
Ok(Some(resp)) => { Ok(Some(resp)) => {
tracing::trace!(lease_id, "keep alive response received: {:?}", resp); tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
// Update deadline from response
// Update ttl and deadline from response let ttl = resp.ttl();
ttl = resp.ttl() as u64; if ttl <= 0 {
deadline = Instant::now() + Duration::from_secs(ttl); tracing::error!(lease_id, "Keep-alive lease expired");
if resp.ttl() == 0 {
anyhow::bail!("Unable to maintain lease - expired or revoked. Check etcd server status"); anyhow::bail!("Unable to maintain lease - expired or revoked. Check etcd server status");
} }
*deadline = Instant::now() + Duration::from_secs(ttl as u64);
} }
Ok(None) => { Ok(None) => {
tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended"); tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended");
break; return Ok(true); // Exit to reconnect
} }
Err(e) => { Err(e) => {
tracing::warn!(lease_id, error = %e, "Keep-alive stream error"); tracing::warn!(lease_id, error = %e, "Keep-alive stream error");
break; return Ok(true); // Exit to reconnect
} }
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
tracing::debug!(lease_id, "cancellation token triggered; revoking lease"); tracing::debug!(lease_id, "cancellation token triggered; revoking lease");
let mut lease_client = connector.get_client().lease_client();
if let Err(e) = lease_client.revoke(lease_id as i64).await { if let Err(e) = lease_client.revoke(lease_id as i64).await {
tracing::warn!( tracing::warn!(
lease_id, lease_id,
...@@ -127,25 +173,17 @@ async fn keep_alive( ...@@ -127,25 +173,17 @@ async fn keep_alive(
"Failed to revoke lease during cancellation. Cleanup may be incomplete." "Failed to revoke lease during cancellation. Cleanup may be incomplete."
); );
} }
return Ok(()); return Ok(false);
} }
_ = tokio::time::sleep(Duration::from_secs(ttl / 2)) => { _ = tokio::time::sleep(next_renewal) => {
tracing::trace!(lease_id, "sending keep alive"); tracing::trace!(lease_id, "sending keep alive");
if let Err(e) = sender.keep_alive().await {
// if we get a error issuing the heartbeat, set the ttl to 0
// this will allow us to poll the response stream once and the cancellation
// token once, then immediately try to tick the heartbeat
// this will repeat until either the heartbeat is reestablished or the deadline
// is exceeded
if let Err(e) = heartbeat_sender.keep_alive().await {
tracing::warn!( tracing::warn!(
lease_id, lease_id,
error = %e, error = %e,
"Unable to send lease heartbeat. Check etcd server status" "Unable to send lease heartbeat. Check etcd server status"
); );
ttl = 0;
}
} }
} }
} }
......
...@@ -148,24 +148,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -148,24 +148,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_aggregated(request, predownload_models): def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover using SGLang. Test ETCD High Availability with repeated node failures and recoveries using SGLang.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an SGLang worker 2. Starts NATS, frontend, and an SGLang worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -182,46 +187,56 @@ def test_etcd_ha_failover_sglang_aggregated(request, predownload_models): ...@@ -182,46 +187,56 @@ def test_etcd_ha_failover_sglang_aggregated(request, predownload_models):
# Small wait to ensure worker is fully ready # Small wait to ensure worker is fully ready
time.sleep(2) time.sleep(2)
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.sglang @pytest.mark.sglang
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_sglang_disaggregated( def test_etcd_ha_failover_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover in disaggregated mode using SGLang. Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode using SGLang.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode SGLang workers 2. Starts NATS, frontend, and both prefill and decode SGLang workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs. Note: This test requires 2 GPUs to run decode and prefill workers on separate GPUs.
""" """
...@@ -230,7 +245,8 @@ def test_etcd_ha_failover_sglang_disaggregated( ...@@ -230,7 +245,8 @@ def test_etcd_ha_failover_sglang_disaggregated(
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -251,34 +267,39 @@ def test_etcd_ha_failover_sglang_disaggregated( ...@@ -251,34 +267,39 @@ def test_etcd_ha_failover_sglang_disaggregated(
# Small wait to ensure workers are fully ready # Small wait to ensure workers are fully ready
time.sleep(2) time.sleep(2)
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.sglang @pytest.mark.sglang
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated using SGLang. Test that frontend and worker shut down when single ETCD node is terminated using SGLang.
...@@ -335,7 +356,6 @@ def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models): ...@@ -335,7 +356,6 @@ def test_etcd_non_ha_shutdown_sglang_aggregated(request, predownload_models):
@pytest.mark.gpu_2 @pytest.mark.gpu_2
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_sglang_disaggregated( def test_etcd_non_ha_shutdown_sglang_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
...@@ -134,24 +134,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -134,24 +134,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models): def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover for TRT-LLM in aggregated mode. Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in aggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and an aggregated TRT-LLM worker 2. Starts NATS, frontend, and an aggregated TRT-LLM worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -168,53 +173,64 @@ def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models): ...@@ -168,53 +173,64 @@ def test_etcd_ha_failover_trtllm_aggregated(request, predownload_models):
): ):
logger.info("Aggregated TRT-LLM worker started successfully") logger.info("Aggregated TRT-LLM worker started successfully")
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.trtllm_marker @pytest.mark.trtllm_marker
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_trtllm_disaggregated( def test_etcd_ha_failover_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover for TRT-LLM in disaggregated mode. Test ETCD High Availability with repeated node failures and recoveries for TRT-LLM in disaggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers 2. Starts NATS, frontend, and both prefill and decode TRT-LLM workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -236,34 +252,39 @@ def test_etcd_ha_failover_trtllm_disaggregated( ...@@ -236,34 +252,39 @@ def test_etcd_ha_failover_trtllm_disaggregated(
# TODO: Fix disagg health checks # TODO: Fix disagg health checks
time.sleep(2) time.sleep(2)
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.trtllm_marker @pytest.mark.trtllm_marker
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode. Test that frontend and worker shut down when single ETCD node is terminated for TRT-LLM in aggregated mode.
...@@ -323,7 +344,6 @@ def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models): ...@@ -323,7 +344,6 @@ def test_etcd_non_ha_shutdown_trtllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_trtllm_disaggregated( def test_etcd_non_ha_shutdown_trtllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
...@@ -116,24 +116,29 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -116,24 +116,29 @@ class DynamoWorkerProcess(ManagedProcess):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_aggregated(request, predownload_models): def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
""" """
Test ETCD High Availability with leader failover. Test ETCD High Availability with repeated node failures and recoveries.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and a vLLM worker 2. Starts NATS, frontend, and a vLLM worker
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -148,53 +153,64 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models): ...@@ -148,53 +153,64 @@ def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
with DynamoWorkerProcess(request, etcd_endpoints): with DynamoWorkerProcess(request, etcd_endpoints):
logger.info("Worker started successfully") logger.info("Worker started successfully")
# Step 5: Send first inference request to verify system is working # Step 5: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 6: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 7: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_ha_failover_vllm_disaggregated( def test_etcd_ha_failover_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
""" """
Test ETCD High Availability with leader failover in disaggregated mode. Test ETCD High Availability with repeated node failures and recoveries in disaggregated mode.
This test: This test:
1. Starts a 3-node ETCD cluster 1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and both prefill and decode vLLM workers 2. Starts NATS, frontend, and both prefill and decode vLLM workers
3. Sends an inference request to verify the system works 3. Cycles through each of the 3 replicas:
4. Terminates the ETCD leader node - Terminate the replica by index
5. Sends another inference request to verify the system still works - Send inference request to verify system still works
- Restart the terminated node
This ensures testing of:
- ETCD leader termination
- Frontend/worker disconnection from their connected ETCD replica
""" """
# Step 1: Start NATS server # Step 1: Start NATS server
with NatsServer(request): with NatsServer(request):
logger.info("NATS server started successfully") logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster # Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster: num_replicas = 3
with EtcdCluster(request, num_replicas=num_replicas) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully") logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes # Get the endpoints for all ETCD nodes
...@@ -213,34 +229,39 @@ def test_etcd_ha_failover_vllm_disaggregated( ...@@ -213,34 +229,39 @@ def test_etcd_ha_failover_vllm_disaggregated(
with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False): with DynamoWorkerProcess(request, etcd_endpoints, is_prefill=False):
logger.info("Decode worker started successfully") logger.info("Decode worker started successfully")
# Step 6: Send first inference request to verify system is working # Step 6: Send initial inference request to verify system is working
logger.info("Sending first inference request (before failover)") logger.info("Sending initial inference request")
result1 = send_inference_request("What is 2+2? The answer is") result = send_inference_request("What is 2+2? The answer is")
assert ( assert (
"4" in result1.lower() or "four" in result1.lower() "4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'" ), f"Expected '4' or 'four' in response, got: '{result}'"
# Step 7: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}") # Step 7: Cycle through each replica to terminate/verify/restart
for i in range(num_replicas):
# Terminate a replica
logger.info(f"Iteration {i}: Terminating replica etcd-{i}")
etcd_cluster.terminate_replica(i)
# Step 8: Send second inference request to verify system still works # Send inference request to verify system still works
logger.info("Sending second inference request (after failover)") logger.info(
result2 = send_inference_request("The capital of France is") f"Iteration {i}: Sending inference request after termination"
)
result = send_inference_request(
"The capital of France is", max_tokens=20
)
assert ( assert (
"paris" in result2.lower() "paris" in result.lower()
), f"Expected 'Paris' in response, got: '{result2}'" ), f"Iteration {i}: Expected 'Paris' in response, got: '{result}'"
# Restart the terminated replica
logger.info(f"Iteration {i}: Restarting replica etcd-{i}")
etcd_cluster.restart_replica(i)
@pytest.mark.vllm @pytest.mark.vllm
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models): def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
""" """
Test that frontend and worker shut down when single ETCD node is terminated. Test that frontend and worker shut down when single ETCD node is terminated.
...@@ -295,7 +316,6 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models): ...@@ -295,7 +316,6 @@ def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
@pytest.mark.gpu_1 @pytest.mark.gpu_1
@pytest.mark.e2e @pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME) @pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.skip(reason="Broken, temporarily disabled")
def test_etcd_non_ha_shutdown_vllm_disaggregated( def test_etcd_non_ha_shutdown_vllm_disaggregated(
request, predownload_models, set_ucx_tls_no_mm request, predownload_models, set_ucx_tls_no_mm
): ):
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json
import logging import logging
import os import os
import shutil import shutil
import subprocess
import tempfile import tempfile
import time import time
from typing import List, Optional from typing import List, Optional
...@@ -62,6 +64,7 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -62,6 +64,7 @@ class EtcdReplicaServer(ManagedProcess):
data_dir: str, data_dir: str,
log_dir: str, log_dir: str,
timeout: int = 30, timeout: int = 30,
cluster_state: str = "new",
): ):
self.name = name self.name = name
self.client_port = client_port self.client_port = client_port
...@@ -81,15 +84,15 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -81,15 +84,15 @@ class EtcdReplicaServer(ManagedProcess):
"--listen-client-urls", "--listen-client-urls",
f"http://0.0.0.0:{client_port}", f"http://0.0.0.0:{client_port}",
"--advertise-client-urls", "--advertise-client-urls",
f"http://localhost:{client_port}", f"http://127.0.0.1:{client_port}",
"--listen-peer-urls", "--listen-peer-urls",
f"http://0.0.0.0:{peer_port}", f"http://0.0.0.0:{peer_port}",
"--initial-advertise-peer-urls", "--initial-advertise-peer-urls",
f"http://localhost:{peer_port}", f"http://127.0.0.1:{peer_port}",
"--initial-cluster", "--initial-cluster",
initial_cluster, initial_cluster,
"--initial-cluster-state", "--initial-cluster-state",
"new", cluster_state,
"--initial-cluster-token", "--initial-cluster-token",
"etcd-cluster", "etcd-cluster",
] ]
...@@ -108,7 +111,7 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -108,7 +111,7 @@ class EtcdReplicaServer(ManagedProcess):
"""Get the status of this ETCD node""" """Get the status of this ETCD node"""
try: try:
response = requests.post( response = requests.post(
f"http://localhost:{self.client_port}/v3/maintenance/status", f"http://127.0.0.1:{self.client_port}/v3/maintenance/status",
json={}, json={},
timeout=2, timeout=2,
) )
...@@ -118,15 +121,19 @@ class EtcdReplicaServer(ManagedProcess): ...@@ -118,15 +121,19 @@ class EtcdReplicaServer(ManagedProcess):
logger.warning(f"Failed to get status for {self.name}: {e}") logger.warning(f"Failed to get status for {self.name}: {e}")
return {} return {}
def is_leader(self) -> bool: def is_leader(self) -> Optional[bool]:
"""Check if this node is the current leader""" """
Check if this node is the current leader.
Returns: True if this node is the leader, False if it is not, or None if the status cannot be retrieved.
"""
status = self.get_status() status = self.get_status()
# In etcd v3 API, we check if this member ID matches the leader ID # In etcd v3 API, we check if this member ID matches the leader ID
if status: if status:
member_id = status.get("header", {}).get("member_id", "") member_id = status.get("header", {}).get("member_id", "")
leader_id = status.get("leader", "") leader_id = status.get("leader", "")
return member_id == leader_id return member_id == leader_id
return False return None
class EtcdCluster: class EtcdCluster:
...@@ -136,13 +143,11 @@ class EtcdCluster: ...@@ -136,13 +143,11 @@ class EtcdCluster:
self, self,
request, request,
num_replicas: int = 3, num_replicas: int = 3,
base_client_port: int = 2379, base_port: int = 2379,
base_peer_port: int = 12380,
): ):
self.request = request self.request = request
self.num_replicas = num_replicas self.num_replicas = num_replicas
self.base_client_port = base_client_port self.base_port = base_port
self.base_peer_port = base_peer_port
self.replicas: List[Optional[EtcdReplicaServer]] = [] self.replicas: List[Optional[EtcdReplicaServer]] = []
self.data_dirs: List[str] = [] self.data_dirs: List[str] = []
self.log_base_dir = f"{request.node.name}_etcd_cluster" self.log_base_dir = f"{request.node.name}_etcd_cluster"
...@@ -156,28 +161,31 @@ class EtcdCluster: ...@@ -156,28 +161,31 @@ class EtcdCluster:
os.makedirs(self.log_base_dir, exist_ok=True) os.makedirs(self.log_base_dir, exist_ok=True)
def start(self): def _get_initial_cluster(self) -> str:
"""Start ETCD cluster with configured number of replicas""" """Build the initial cluster configuration string"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
# Build initial cluster configuration
initial_cluster_parts = [] initial_cluster_parts = []
for i in range(self.num_replicas): for i in range(self.num_replicas):
name = f"etcd-{i}" name = f"etcd-{i}"
peer_port = self.base_peer_port + i peer_port = self.base_port + (2 * i) + 1
initial_cluster_parts.append(f"{name}=http://localhost:{peer_port}") initial_cluster_parts.append(f"{name}=http://127.0.0.1:{peer_port}")
return ",".join(initial_cluster_parts)
initial_cluster = ",".join(initial_cluster_parts)
def _start_replica(self, idx: int, cluster_state: str = "new") -> EtcdReplicaServer:
"""Start a single ETCD replica"""
name = f"etcd-{idx}"
# e.g. base_port = 2379 -> client_port = 2379, 2381, 2383
# e.g. base_port = 2379 -> peer_port = 2380, 2382, 2384
client_port = self.base_port + (2 * idx)
peer_port = self.base_port + (2 * idx) + 1
# Create data dir for the node
data_dir = tempfile.mkdtemp(prefix=f"etcd_{idx}_")
if idx < len(self.data_dirs):
self.data_dirs[idx] = data_dir
else:
self.data_dirs.append(data_dir)
# Start each replica
for i in range(self.num_replicas):
name = f"etcd-{i}"
client_port = self.base_client_port + i
peer_port = self.base_peer_port + i
data_dir = tempfile.mkdtemp(prefix=f"etcd_{i}_")
log_dir = os.path.join(self.log_base_dir, name) log_dir = os.path.join(self.log_base_dir, name)
self.data_dirs.append(data_dir)
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
logger.info( logger.info(
...@@ -189,90 +197,181 @@ class EtcdCluster: ...@@ -189,90 +197,181 @@ class EtcdCluster:
name=name, name=name,
client_port=client_port, client_port=client_port,
peer_port=peer_port, peer_port=peer_port,
initial_cluster=initial_cluster, initial_cluster=self._get_initial_cluster(),
data_dir=data_dir, data_dir=data_dir,
log_dir=log_dir, log_dir=log_dir,
cluster_state=cluster_state,
) )
replica.__enter__() replica.__enter__()
self.replicas.append(replica) return replica
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize and elect a leader
self._wait_for_healthy_cluster(timeout=30)
leader_idx = self.find_leader()
if leader_idx is not None:
logger.info(f"Initial leader elected: etcd-{leader_idx}")
else:
logger.warning("No leader elected yet")
def _wait_for_healthy_cluster(self, timeout: int = 30): def _wait_for_healthy_cluster(self, timeout: int = 30):
"""Wait for all replicas to be healthy and responsive. """Wait for cluster to be healthy and elected leader."""
logger.info("Waiting for cluster to become healthy...")
Args:
timeout: Maximum time to wait in seconds
Raises:
RuntimeError: If cluster doesn't become healthy within timeout
"""
logger.info("Waiting for all replicas to be healthy...")
start_time = time.time() start_time = time.time()
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
time.sleep(1) # Check if a leader is elected indicating cluster health
is_healthy = True
# Check if all replicas are responding leader_id = None
all_healthy = True
for i, replica in enumerate(self.replicas): for i, replica in enumerate(self.replicas):
if replica: if replica:
status = replica.get_status() is_leader = replica.is_leader()
if not status: if is_leader is None:
logger.debug(f"etcd-{i} not yet responsive") is_healthy = False
all_healthy = False
break break
if is_leader is True:
if leader_id is not None:
raise RuntimeError(
f"Multiple leaders detected in ETCD cluster etcd-{leader_id} and etcd-{i}"
)
leader_id = i
if all_healthy: if is_healthy and leader_id is not None:
logger.info("All replicas are healthy") logger.info(f"Cluster is healthy with leader at etcd-{leader_id}")
return return
time.sleep(1)
raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s") raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s")
def find_leader(self) -> Optional[int]: def _replace_member(self, idx: int):
"""Find which replica is currently the leader""" """Remove old member and add new member to the cluster using etcdctl"""
for i, replica in enumerate(self.replicas): # Find a healthy replica to perform member operations
if replica and replica.is_leader(): healthy_replica = None
return i for i, r in enumerate(self.replicas):
return None if r and i != idx:
healthy_replica = r
break
def terminate_leader(self) -> Optional[int]: if not healthy_replica:
"""Terminate the current leader and return its index""" raise RuntimeError("No healthy replica found to perform member operations")
leader_idx = self.find_leader()
if leader_idx is None: name = f"etcd-{idx}"
logger.warning("No leader found to terminate") peer_port = self.base_port + (2 * idx) + 1
return None peer_url = f"http://127.0.0.1:{peer_port}"
logger.info(f"Terminating current leader: etcd-{leader_idx}") # Set ETCDCTL_ENDPOINTS for etcdctl commands
replica = self.replicas[leader_idx] etcdctl_env = os.environ.copy()
etcdctl_env[
"ETCDCTL_ENDPOINTS"
] = f"http://127.0.0.1:{healthy_replica.client_port}"
etcdctl_env["ETCDCTL_API"] = "3"
if replica: # First, get member list to find the old member's ID
replica.__exit__(None, None, None) logger.info(f"Getting member list to find {name}")
self.replicas[leader_idx] = None try:
logger.info(f"Leader etcd-{leader_idx} has been terminated") result = subprocess.run(
["etcdctl", "member", "list", "--write-out=json"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if result.returncode == 0:
members = json.loads(result.stdout).get("members", [])
old_member_id = None
for member in members:
if member.get("name") == name:
old_member_id = member.get("ID")
break
if old_member_id:
# Convert member ID to hex format (etcdctl expects hex)
hex_member_id = format(int(old_member_id), "x")
logger.info(
f"Removing member with ID {old_member_id} (hex: {hex_member_id})"
)
remove_result = subprocess.run(
["etcdctl", "member", "remove", hex_member_id],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if remove_result.returncode != 0:
raise RuntimeError(
f"Failed to remove old member: {remove_result.stderr}"
)
logger.info(f"Successfully removed old member {name}")
except Exception as e:
raise RuntimeError(f"Error during member removal: {e}")
# Add the new member to the cluster
logger.info(f"Adding new member {name} to cluster with peer URL {peer_url}")
try:
add_result = subprocess.run(
["etcdctl", "member", "add", name, f"--peer-urls={peer_url}"],
env=etcdctl_env,
capture_output=True,
text=True,
timeout=5,
)
if add_result.returncode != 0:
raise RuntimeError(f"Failed to add new member: {add_result.stderr}")
logger.info(f"Successfully added new member {name}")
except Exception as e:
raise RuntimeError(f"Error adding new member: {e}")
def start(self):
"""Start ETCD cluster with configured number of replicas"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
return leader_idx # Start each replica
for i in range(self.num_replicas):
replica = self._start_replica(i, cluster_state="new")
self.replicas.append(replica)
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize
self._wait_for_healthy_cluster()
def get_client_endpoints(self) -> List[str]: def get_client_endpoints(self) -> List[str]:
"""Get list of active client endpoints""" """Get list of active client endpoints"""
endpoints = [] endpoints = []
for i, replica in enumerate(self.replicas): for i, replica in enumerate(self.replicas):
if replica: # Only include active replicas if replica: # Only include active replicas
client_port = self.base_client_port + i client_port = self.base_port + (2 * i)
endpoints.append(f"http://localhost:{client_port}") endpoints.append(f"http://127.0.0.1:{client_port}")
return endpoints return endpoints
def terminate_replica(self, idx: int):
"""Terminate a specific replica by index."""
if idx < 0 or idx >= self.num_replicas:
raise RuntimeError(f"Invalid replica index: {idx}")
replica = self.replicas[idx]
if not replica:
raise RuntimeError(f"Replica etcd-{idx} is already terminated")
replica.__exit__(None, None, None)
self.replicas[idx] = None
logger.info(f"Terminated replica etcd-{idx}")
def restart_replica(self, idx: int):
"""Restart a terminated replica"""
if idx < 0 or idx >= self.num_replicas:
raise RuntimeError(f"Invalid replica index: {idx}")
if self.replicas[idx] is not None:
raise RuntimeError(f"Replica etcd-{idx} is already running")
# Make sure the cluster is healthy before restarting
self._wait_for_healthy_cluster()
# Remove old member and add new member
self._replace_member(idx)
# Start the replica with existing cluster state
replica = self._start_replica(idx, cluster_state="existing")
self.replicas[idx] = replica
# Wait for cluster to stabilize
self._wait_for_healthy_cluster()
def stop(self): def stop(self):
"""Clean up all replicas and temporary directories""" """Clean up all replicas and temporary directories"""
logger.info("Cleaning up ETCD cluster") logger.info("Cleaning up ETCD cluster")
...@@ -329,10 +428,10 @@ def send_inference_request(prompt: str, max_tokens: int = 50) -> str: ...@@ -329,10 +428,10 @@ def send_inference_request(prompt: str, max_tokens: int = 50) -> str:
return text return text
else: else:
pytest.fail( pytest.fail(
f"Inference request failed with code {response.status_code}: {response.text}" f"[ETCD HA regression?] Inference request failed with code {response.status_code}: {response.text}"
) )
except Exception as e: except Exception as e:
pytest.fail(f"Inference request failed: {e}") pytest.fail(f"[ETCD HA regression?] Inference request failed: {e}")
def wait_for_processes_to_terminate( def wait_for_processes_to_terminate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment