Unverified Commit 9b6c4f91 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: ETCD high availability client failover - lease keep-alive resilience (#3868)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 9c5bdf83
...@@ -20,31 +20,18 @@ use etcd_client::{ ...@@ -20,31 +20,18 @@ use etcd_client::{
pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient}; pub use etcd_client::{ConnectOptions, KeyValue, LeaseClient};
use tokio::time::{Duration, interval}; use tokio::time::{Duration, interval};
mod connector;
mod lease; mod lease;
mod lock; mod lock;
mod path; mod path;
use connector::Connector;
use lease::*; use lease::*;
pub use lock::*; pub use lock::*;
pub use path::*; pub use path::*;
use super::utils::build_in_runtime; use super::utils::build_in_runtime;
/// ETCD Client
#[derive(Clone)]
pub struct Client {
client: etcd_client::Client,
primary_lease: u64,
runtime: Runtime,
rt: Arc<tokio::runtime::Runtime>,
}
impl std::fmt::Debug for Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "etcd::Client primary_lease={}", self.primary_lease)
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Lease { pub struct Lease {
/// ETCD lease ID /// ETCD lease ID
...@@ -86,6 +73,21 @@ impl Lease { ...@@ -86,6 +73,21 @@ impl Lease {
} }
} }
/// ETCD Client
#[derive(Clone)]
pub struct Client {
connector: Arc<Connector>,
primary_lease: u64,
runtime: Runtime,
rt: Arc<tokio::runtime::Runtime>,
}
impl std::fmt::Debug for Client {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "etcd::Client primary_lease={}", self.primary_lease)
}
}
impl Client { impl Client {
pub fn builder() -> ClientOptionsBuilder { pub fn builder() -> ClientOptionsBuilder {
ClientOptionsBuilder::default() ClientOptionsBuilder::default()
...@@ -102,12 +104,13 @@ impl Client { ...@@ -102,12 +104,13 @@ impl Client {
pub async fn new(config: ClientOptions, runtime: Runtime) -> Result<Self> { pub async fn new(config: ClientOptions, runtime: Runtime) -> Result<Self> {
let token = runtime.primary_token(); let token = runtime.primary_token();
let ((client, lease_id), rt) = build_in_runtime( let ((connector, lease_id), rt) = build_in_runtime(
async move { async move {
let client = etcd_client::Client::connect( let etcd_urls = config.etcd_url.clone();
config.etcd_url.clone(), let connect_options = config.etcd_connect_options.clone();
config.etcd_connect_options,
) // Create the connector
let connector = Connector::new(etcd_urls, connect_options)
.await .await
.with_context(|| { .with_context(|| {
format!( format!(
...@@ -117,9 +120,7 @@ impl Client { ...@@ -117,9 +120,7 @@ impl Client {
})?; })?;
let lease_id = if config.attach_lease { let lease_id = if config.attach_lease {
let lease_client = client.lease_client(); let lease = create_lease(connector.clone(), 10, token)
let lease = create_lease(lease_client, 10, token)
.await .await
.with_context(|| { .with_context(|| {
format!( format!(
...@@ -133,23 +134,24 @@ impl Client { ...@@ -133,23 +134,24 @@ impl Client {
0 0
}; };
Ok((client, lease_id)) Ok((connector, lease_id))
}, },
1, 1,
) )
.await?; .await?;
Ok(Client { Ok(Client {
client, connector,
primary_lease: lease_id, primary_lease: lease_id,
rt, rt,
runtime, runtime,
}) })
} }
/// Get a reference to the underlying [`etcd_client::Client`] instance. /// Get a clone of the underlying [`etcd_client::Client`] instance.
pub(crate) fn etcd_client(&self) -> &etcd_client::Client { /// This returns a clone since the client is behind an RwLock.
&self.client pub fn etcd_client(&self) -> etcd_client::Client {
self.connector.get_client()
} }
/// Get the primary lease ID. /// Get the primary lease ID.
...@@ -169,16 +171,16 @@ impl Client { ...@@ -169,16 +171,16 @@ impl Client {
/// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`]. /// This [`Lease`] will be tied to the [`Runtime`], specifically a child [`CancellationToken`].
pub async fn create_lease(&self, ttl: u64) -> Result<Lease> { pub async fn create_lease(&self, ttl: u64) -> Result<Lease> {
let token = self.runtime.child_token(); let token = self.runtime.child_token();
let lease_client = self.client.lease_client();
self.rt self.rt
.spawn(create_lease(lease_client, ttl, token)) .spawn(create_lease(self.connector.clone(), ttl, token))
.await? .await?
} }
// Revoke an etcd lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke // Revoke an etcd lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
pub async fn revoke_lease(&self, lease_id: u64) -> Result<()> { pub async fn revoke_lease(&self, lease_id: u64) -> Result<()> {
let lease_client = self.client.lease_client(); self.rt
self.rt.spawn(revoke_lease(lease_client, lease_id)).await? .spawn(revoke_lease(self.connector.clone(), lease_id))
.await?
} }
pub async fn kv_create(&self, key: &str, value: Vec<u8>, lease_id: Option<u64>) -> Result<()> { pub async fn kv_create(&self, key: &str, value: Vec<u8>, lease_id: Option<u64>) -> Result<()> {
...@@ -193,7 +195,7 @@ impl Client { ...@@ -193,7 +195,7 @@ impl Client {
]); ]);
// Execute the transaction // Execute the transaction
let result = self.client.kv_client().txn(txn).await?; let result = self.connector.get_client().kv_client().txn(txn).await?;
if result.succeeded() { if result.succeeded() {
Ok(()) Ok(())
...@@ -232,7 +234,7 @@ impl Client { ...@@ -232,7 +234,7 @@ impl Client {
]); ]);
// Execute the transaction // Execute the transaction
let result = self.client.kv_client().txn(txn).await?; let result = self.connector.get_client().kv_client().txn(txn).await?;
// We have to enumerate the response paths to determine if the transaction succeeded // We have to enumerate the response paths to determine if the transaction succeeded
if result.succeeded() { if result.succeeded() {
...@@ -266,7 +268,8 @@ impl Client { ...@@ -266,7 +268,8 @@ impl Client {
let id = lease_id.unwrap_or(self.lease_id()); let id = lease_id.unwrap_or(self.lease_id());
let put_options = PutOptions::new().with_lease(id as i64); let put_options = PutOptions::new().with_lease(id as i64);
let _ = self let _ = self
.client .connector
.get_client()
.kv_client() .kv_client()
.put(key.as_ref(), value.as_ref(), Some(put_options)) .put(key.as_ref(), value.as_ref(), Some(put_options))
.await?; .await?;
...@@ -282,7 +285,8 @@ impl Client { ...@@ -282,7 +285,8 @@ impl Client {
let options = options let options = options
.unwrap_or_default() .unwrap_or_default()
.with_lease(self.primary_lease().id() as i64); .with_lease(self.primary_lease().id() as i64);
self.client self.connector
.get_client()
.kv_client() .kv_client()
.put(key.as_ref(), value.as_ref(), Some(options)) .put(key.as_ref(), value.as_ref(), Some(options))
.await .await
...@@ -294,7 +298,12 @@ impl Client { ...@@ -294,7 +298,12 @@ impl Client {
key: impl Into<Vec<u8>>, key: impl Into<Vec<u8>>,
options: Option<GetOptions>, options: Option<GetOptions>,
) -> Result<Vec<KeyValue>> { ) -> Result<Vec<KeyValue>> {
let mut get_response = self.client.kv_client().get(key, options).await?; let mut get_response = self
.connector
.get_client()
.kv_client()
.get(key, options)
.await?;
Ok(get_response.take_kvs()) Ok(get_response.take_kvs())
} }
...@@ -303,7 +312,8 @@ impl Client { ...@@ -303,7 +312,8 @@ impl Client {
key: impl Into<Vec<u8>>, key: impl Into<Vec<u8>>,
options: Option<DeleteOptions>, options: Option<DeleteOptions>,
) -> Result<u64> { ) -> Result<u64> {
self.client self.connector
.get_client()
.kv_client() .kv_client()
.delete(key, options) .delete(key, options)
.await .await
...@@ -313,7 +323,8 @@ impl Client { ...@@ -313,7 +323,8 @@ impl Client {
pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> { pub async fn kv_get_prefix(&self, prefix: impl AsRef<str>) -> Result<Vec<KeyValue>> {
let mut get_response = self let mut get_response = self
.client .connector
.get_client()
.kv_client() .kv_client()
.get(prefix.as_ref(), Some(GetOptions::new().with_prefix())) .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
.await?; .await?;
...@@ -328,7 +339,7 @@ impl Client { ...@@ -328,7 +339,7 @@ impl Client {
key: impl Into<Vec<u8>>, key: impl Into<Vec<u8>>,
lease_id: Option<u64>, lease_id: Option<u64>,
) -> Result<LockResponse> { ) -> Result<LockResponse> {
let mut lock_client = self.client.lock_client(); let mut lock_client = self.connector.get_client().lock_client();
let id = lease_id.unwrap_or(self.lease_id()); let id = lease_id.unwrap_or(self.lease_id());
let options = LockOptions::new().with_lease(id as i64); let options = LockOptions::new().with_lease(id as i64);
lock_client lock_client
...@@ -339,7 +350,7 @@ impl Client { ...@@ -339,7 +350,7 @@ impl Client {
/// Release a distributed lock using the key from the LockResponse /// Release a distributed lock using the key from the LockResponse
pub async fn unlock(&self, lock_key: impl Into<Vec<u8>>) -> Result<()> { pub async fn unlock(&self, lock_key: impl Into<Vec<u8>>) -> Result<()> {
let mut lock_client = self.client.lock_client(); let mut lock_client = self.connector.get_client().lock_client();
lock_client lock_client
.unlock(lock_key) .unlock(lock_key)
.await .await
...@@ -367,8 +378,9 @@ impl Client { ...@@ -367,8 +378,9 @@ impl Client {
prefix: impl AsRef<str> + std::fmt::Display, prefix: impl AsRef<str> + std::fmt::Display,
include_existing: bool, include_existing: bool,
) -> Result<PrefixWatcher> { ) -> Result<PrefixWatcher> {
let mut kv_client = self.client.kv_client(); let client = self.connector.get_client();
let mut watch_client = self.client.watch_client(); let mut kv_client = client.kv_client();
let mut watch_client = client.watch_client();
let mut get_response = kv_client let mut get_response = kv_client
.get(prefix.as_ref(), Some(GetOptions::new().with_prefix())) .get(prefix.as_ref(), Some(GetOptions::new().with_prefix()))
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::{ErrorContext, Result, error};
use etcd_client::ConnectOptions;
use parking_lot::RwLock;
use std::{sync::Arc, time::Duration};
use tokio::{sync::Mutex, time::sleep};
/// Manages ETCD client connections with reconnection support
pub struct Connector {
/// The actual ETCD client, protected by RwLock for safe updates during reconnection
/// WARNING: Do not recursively acquire a read lock when the current thread already holds one
client: RwLock<etcd_client::Client>,
/// Configuration for connecting to ETCD
etcd_urls: Vec<String>,
connect_options: Option<ConnectOptions>,
/// Tracks the current backoff duration and last successful connect time
/// The Mutex ensures only one reconnect operation runs at a time
backoff_state: Mutex<BackoffState>,
}
impl Connector {
/// Create a new connector with an established connection
pub async fn new(
etcd_urls: Vec<String>,
connect_options: Option<ConnectOptions>,
) -> Result<Arc<Self>> {
// Connect to ETCD
let client = Self::connect(&etcd_urls, &connect_options).await?;
Ok(Arc::new(Self {
client: RwLock::new(client),
etcd_urls,
connect_options,
backoff_state: Mutex::new(BackoffState::default()),
}))
}
/// Connect to ETCD cluster
async fn connect(
etcd_urls: &[String],
connect_options: &Option<ConnectOptions>,
) -> Result<etcd_client::Client> {
etcd_client::Client::connect(etcd_urls.to_vec(), connect_options.clone())
.await
.with_context(|| {
format!(
"Unable to connect to etcd server at {}. Check etcd server status",
etcd_urls.join(", ")
)
})
}
/// Get a clone of the current ETCD client
pub fn get_client(&self) -> etcd_client::Client {
self.client.read().clone()
}
/// Reconnect to ETCD cluster with retry logic
/// Respects the deadline and returns error if exceeded
///
/// Backoff behavior:
/// - Starts at 0 (immediate reconnect) if this is the first reconnect or enough time has passed
/// since the last reconnect
/// - Increments exponentially for continuous failures
/// - Resets to 0 only when: this is a new call AND current_time > last_connect_time + residual_backoff
///
/// The mutex ensures only one reconnect operation runs at a time globally
pub async fn reconnect(&self, deadline: std::time::Instant) -> Result<()> {
let mut backoff_state = self.backoff_state.lock().await;
tracing::warn!("Reconnecting to ETCD cluster at: {:?}", self.etcd_urls);
backoff_state.attempt_reset();
loop {
backoff_state.apply_backoff(deadline).await;
if std::time::Instant::now() >= deadline {
return Err(error!(
"Unable to reconnect to ETCD cluster: deadline exceeded"
));
}
match Self::connect(&self.etcd_urls, &self.connect_options).await {
Ok(new_client) => {
tracing::info!("Successfully reconnected to ETCD cluster");
// Update the client behind the lock
let mut client_guard = self.client.write();
*client_guard = new_client;
return Ok(());
}
Err(e) => {
tracing::warn!(
"Reconnection failed (remaining time: {:?}): {}",
deadline.saturating_duration_since(std::time::Instant::now()),
e
);
}
}
}
}
/// Get the ETCD URLs
pub fn etcd_urls(&self) -> &[String] {
&self.etcd_urls
}
/// Get the connection options
pub fn connect_options(&self) -> &Option<ConnectOptions> {
&self.connect_options
}
}
#[derive(Debug)]
struct BackoffState {
/// Initial backoff duration for reconnection attempts
pub initial_backoff: Duration,
/// Minimum backoff duration for reconnection attempts
pub min_backoff: Duration,
/// Maximum backoff duration for reconnection attempts
pub max_backoff: Duration,
/// Current backoff duration (starts at 0 for immediate reconnect)
current_backoff: Duration,
/// Last time a connection establishment was attempted
last_connect_attempt: std::time::Instant,
}
impl Default for BackoffState {
fn default() -> Self {
Self {
initial_backoff: Duration::from_millis(500),
min_backoff: Duration::from_millis(50),
max_backoff: Duration::from_secs(5),
current_backoff: Duration::ZERO,
last_connect_attempt: std::time::Instant::now(),
}
}
}
impl BackoffState {
/// Reset backoff to 0 if enough time has passed since the last connection
pub fn attempt_reset(&mut self) {
if std::time::Instant::now() > self.last_connect_attempt + self.current_backoff {
tracing::debug!("Resetting backoff to 0 (first reconnect or enough time has passed)");
self.current_backoff = Duration::ZERO;
}
}
/// Apply backoff and update backoff state for possible next connection attempt
pub async fn apply_backoff(&mut self, deadline: std::time::Instant) {
if self.current_backoff > Duration::ZERO {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
let backoff = std::cmp::min(self.current_backoff, remaining / 2);
let backoff = std::cmp::min(backoff, self.max_backoff);
let backoff = std::cmp::max(backoff, self.min_backoff);
self.current_backoff = backoff * 2;
tracing::debug!(
"Applying backoff of {:?} (remaining time: {:?})",
backoff,
remaining
);
sleep(backoff).await;
} else {
self.current_backoff = self.initial_backoff;
}
self.last_connect_attempt = std::time::Instant::now();
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use super::connector::Connector;
use super::*; use super::*;
use std::sync::Arc;
use std::time::Duration;
use tokio::time::{sleep, timeout};
/// Create a [`Lease`] with a given time-to-live (TTL) attached to the [`CancellationToken`]. /// Create a [`Lease`] with a given time-to-live (TTL) attached to the [`CancellationToken`].
pub async fn create_lease( pub async fn create_lease(
mut lease_client: LeaseClient, connector: Arc<Connector>,
ttl: u64, ttl: u64,
token: CancellationToken, token: CancellationToken,
) -> Result<Lease> { ) -> Result<Lease> {
let mut lease_client = connector.get_client().lease_client();
let lease = lease_client.grant(ttl as i64, None).await?; let lease = lease_client.grant(ttl as i64, None).await?;
let id = lease.id() as u64; let id = lease.id() as u64;
...@@ -17,7 +22,7 @@ pub async fn create_lease( ...@@ -17,7 +22,7 @@ pub async fn create_lease(
let clone = token.clone(); let clone = token.clone();
tokio::spawn(async move { tokio::spawn(async move {
match keep_alive(lease_client, id, ttl, child).await { match keep_alive(connector, id, ttl, child).await {
Ok(_) => tracing::trace!("keep alive task exited successfully"), Ok(_) => tracing::trace!("keep alive task exited successfully"),
Err(e) => { Err(e) => {
tracing::error!( tracing::error!(
...@@ -36,7 +41,8 @@ pub async fn create_lease( ...@@ -36,7 +41,8 @@ pub async fn create_lease(
} }
/// Revoke a lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke /// Revoke a lease given its lease id. A wrapper over etcd_client::LeaseClient::revoke
pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: u64) -> Result<()> { pub async fn revoke_lease(connector: Arc<Connector>, lease_id: u64) -> Result<()> {
let mut lease_client = connector.get_client().lease_client();
match lease_client.revoke(lease_id as i64).await { match lease_client.revoke(lease_id as i64).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(e) => { Err(e) => {
...@@ -46,25 +52,52 @@ pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: u64) -> Resul ...@@ -46,25 +52,52 @@ pub async fn revoke_lease(mut lease_client: LeaseClient, lease_id: u64) -> Resul
} }
} }
/// Task to keep leases alive. /// Task to keep leases alive with reconnection support.
/// ///
/// If this task returns an error, the cancellation token will be invoked on the runtime. /// If this task returns an error, the cancellation token will be invoked on the runtime.
/// If async fn keep_alive(
pub async fn keep_alive( connector: Arc<Connector>,
client: LeaseClient,
lease_id: u64, lease_id: u64,
ttl: u64, mut ttl: u64,
token: CancellationToken, token: CancellationToken,
) -> Result<()> { ) -> Result<()> {
let mut ttl = ttl;
let mut deadline = create_deadline(ttl)?; let mut deadline = create_deadline(ttl)?;
let mut client = client; loop {
let (mut heartbeat_sender, mut heartbeat_receiver) = client.keep_alive(lease_id as i64).await?; // Try to establish or re-establish the keep-alive stream
let mut lease_client = connector.get_client().lease_client();
let (mut heartbeat_sender, mut heartbeat_receiver) = match lease_client
.keep_alive(lease_id as i64)
.await
{
Ok((sender, receiver)) => {
tracing::debug!(lease_id, "Established keep-alive stream");
(sender, receiver)
}
Err(e) => {
tracing::warn!(lease_id, error = %e, "Failed to establish keep-alive stream");
// Try to reconnect with the deadline, but also check for cancellation
tokio::select! {
biased;
reconnect_result = connector.reconnect(deadline) => {
match reconnect_result {
Err(e) => return Err(e),
_ => continue,
}
}
_ = token.cancelled() => {
tracing::debug!(lease_id, "Cancellation token triggered during reconnection");
return Ok(());
}
}
}
};
// Keep-alive loop with the established stream
loop { loop {
// if the deadline is exceeded, then we have failed to issue a heartbeat in time
// we may be permanently disconnected from the etcd server, so we are now officially done
if deadline < std::time::Instant::now() { if deadline < std::time::Instant::now() {
return Err(error!( return Err(error!(
"Unable to refresh lease - deadline exceeded. Check etcd server status" "Unable to refresh lease - deadline exceeded. Check etcd server status"
...@@ -75,33 +108,49 @@ pub async fn keep_alive( ...@@ -75,33 +108,49 @@ pub async fn keep_alive(
biased; biased;
status = heartbeat_receiver.message() => { status = heartbeat_receiver.message() => {
if let Some(resp) = status? { match status {
Ok(Some(resp)) => {
tracing::trace!(lease_id, "keep alive response received: {:?}", resp); tracing::trace!(lease_id, "keep alive response received: {:?}", resp);
// update ttl and deadline // Update ttl and deadline from response
ttl = resp.ttl() as u64; ttl = resp.ttl() as u64;
deadline = create_deadline(ttl)?; deadline = create_deadline(ttl)?;
if resp.ttl() == 0 { if resp.ttl() == 0 {
return Err(error!("Unable to maintain lease - expired or revoked. Check etcd server status")); return Err(error!("Unable to maintain lease - expired or revoked. Check etcd server status"));
} }
}
Ok(None) => {
tracing::warn!(lease_id, "Keep-alive stream unexpectedly ended");
break;
}
Err(e) => {
tracing::warn!(lease_id, error = %e, "Keep-alive stream error");
break;
}
} }
} }
_ = token.cancelled() => { _ = token.cancelled() => {
tracing::trace!(lease_id, "cancellation token triggered; revoking lease"); tracing::debug!(lease_id, "cancellation token triggered; revoking lease");
let _ = client.revoke(lease_id as i64).await?; if let Err(e) = lease_client.revoke(lease_id as i64).await {
tracing::warn!(
lease_id,
error = %e,
"Failed to revoke lease during cancellation. Cleanup may be incomplete."
);
}
return Ok(()); return Ok(());
} }
_ = tokio::time::sleep(tokio::time::Duration::from_secs(ttl / 2)) => { _ = tokio::time::sleep(Duration::from_secs(ttl / 2)) => {
tracing::trace!(lease_id, "sending keep alive"); tracing::trace!(lease_id, "sending keep alive");
// if we get a error issuing the heartbeat, set the ttl to 0 // if we get a error issuing the heartbeat, set the ttl to 0
// this will allow us to poll the response stream once and the cancellation token once, then // this will allow us to poll the response stream once and the cancellation
// immediately try to tick the heartbeat // token once, then immediately try to tick the heartbeat
// this will repeat until either the heartbeat is reestablished or the deadline is exceeded // this will repeat until either the heartbeat is reestablished or the deadline
// is exceeded
if let Err(e) = heartbeat_sender.keep_alive().await { if let Err(e) = heartbeat_sender.keep_alive().await {
tracing::warn!( tracing::warn!(
lease_id, lease_id,
...@@ -111,7 +160,7 @@ pub async fn keep_alive( ...@@ -111,7 +160,7 @@ pub async fn keep_alive(
ttl = 0; ttl = 0;
} }
} }
}
} }
} }
} }
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import pytest
from tests.conftest import NatsServer
from tests.fault_tolerance.etcd_ha.utils import (
DynamoFrontendProcess,
EtcdCluster,
send_inference_request,
wait_for_processes_to_terminate,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with vLLM backend and ETCD HA support"""
def __init__(self, request, etcd_endpoints: list):
command = [
"python3",
"-m",
"dynamo.vllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--enforce-eager",
"--gpu-memory-utilization",
"0.45",
"--max-model-len",
"8192",
]
# Health checks - frontend model registration
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set debug logging and ETCD endpoints
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
log_dir = f"{request.node.name}_worker"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
pass
super().__init__(
command=command,
env=env,
health_check_urls=health_check_urls,
timeout=120,
display_output=True,
terminate_existing=False,
stragglers=[
"VLLM::EngineCore",
],
straggler_commands=[
"-m dynamo.vllm",
],
log_dir=log_dir,
)
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
@pytest.mark.xfail(
strict=False,
reason="Lease watch failover not yet implemented, only lease keep alive failover is implemented",
)
def test_etcd_ha_failover_vllm_aggregated(request, predownload_models):
"""
Test ETCD High Availability with leader failover.
This test:
1. Starts a 3-node ETCD cluster
2. Starts NATS, frontend, and a vLLM worker
3. Sends an inference request to verify the system works
4. Terminates the ETCD leader node
5. Sends another inference request to verify the system still works
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start 3-node ETCD cluster
with EtcdCluster(request) as etcd_cluster:
logger.info("3-node ETCD cluster started successfully")
# Get the endpoints for all ETCD nodes
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoints: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoints
with DynamoFrontendProcess(request, etcd_endpoints):
logger.info("Frontend started successfully")
# Step 4: Start a vLLM worker
with DynamoWorkerProcess(request, etcd_endpoints):
logger.info("Worker started successfully")
# Step 5: Send first inference request to verify system is working
logger.info("Sending first inference request (before failover)")
result1 = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result1.lower() or "four" in result1.lower()
), f"Expected '4' or 'four' in response, got: '{result1}'"
# Step 6: Identify and terminate the ETCD leader
logger.info("Terminating ETCD leader to test failover")
terminated_idx = etcd_cluster.terminate_leader()
if terminated_idx is None:
pytest.fail("Failed to identify and terminate ETCD leader")
logger.info(f"Terminated ETCD node {terminated_idx}")
# Step 7: Send second inference request to verify system still works
logger.info("Sending second inference request (after failover)")
result2 = send_inference_request("The capital of France is")
assert (
"paris" in result2.lower()
), f"Expected 'Paris' in response, got: '{result2}'"
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_etcd_non_ha_shutdown_vllm_aggregated(request, predownload_models):
"""
Test that frontend and worker shut down when single ETCD node is terminated.
This test:
1. Starts a single ETCD node (no cluster)
2. Starts NATS, frontend, and a vLLM worker
3. Sends an inference request to verify the system works
4. Terminates the single ETCD node
5. Verifies that frontend and worker shut down gracefully
"""
# Step 1: Start NATS server
with NatsServer(request):
logger.info("NATS server started successfully")
# Step 2: Start single ETCD node using EtcdCluster with num_replicas=1
with EtcdCluster(request, num_replicas=1) as etcd_cluster:
logger.info("Single ETCD node started successfully")
# Get the endpoint for the single ETCD node
etcd_endpoints = etcd_cluster.get_client_endpoints()
logger.info(f"ETCD endpoint: {etcd_endpoints}")
# Step 3: Start the frontend with ETCD endpoint
with DynamoFrontendProcess(request, etcd_endpoints) as frontend:
logger.info("Frontend started successfully")
# Step 4: Start a vLLM worker
with DynamoWorkerProcess(request, etcd_endpoints) as worker:
logger.info("Worker started successfully")
# Step 5: Send inference request to verify system is working
logger.info("Sending inference request")
result = send_inference_request("What is 2+2? The answer is")
assert (
"4" in result.lower() or "four" in result.lower()
), f"Expected '4' or 'four' in response, got: '{result}'"
logger.info("System is working correctly with single ETCD node")
# Step 6: Terminate the ETCD node
logger.info("Terminating single ETCD node")
etcd_cluster.stop()
# Step 7: Wait and verify frontend and worker detect the loss
wait_for_processes_to_terminate(
{"Worker": worker, "Frontend": frontend}
)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import tempfile
import time
from typing import List, Optional
import pytest
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
class DynamoFrontendProcess(ManagedProcess):
"""Process manager for Dynamo frontend with ETCD HA support"""
def __init__(self, request, etcd_endpoints: list):
command = ["python", "-m", "dynamo.frontend"]
# Set debug logging and ETCD endpoints
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints)
log_dir = f"{request.node.name}_frontend"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
pass
super().__init__(
command=command,
env=env,
display_output=True,
terminate_existing=True,
log_dir=log_dir,
)
class EtcdReplicaServer(ManagedProcess):
"""Single ETCD replica server in a cluster"""
def __init__(
self,
request,
name: str,
client_port: int,
peer_port: int,
initial_cluster: str,
data_dir: str,
log_dir: str,
timeout: int = 30,
):
self.name = name
self.client_port = client_port
self.peer_port = peer_port
self.data_dir = data_dir
etcd_env = os.environ.copy()
etcd_env["ETCD_ENDPOINTS"] = "" # Clear any inherited ETCD endpoints
etcd_env["ALLOW_NONE_AUTHENTICATION"] = "yes"
command = [
"etcd",
"--name",
name,
"--data-dir",
data_dir,
"--listen-client-urls",
f"http://0.0.0.0:{client_port}",
"--advertise-client-urls",
f"http://localhost:{client_port}",
"--listen-peer-urls",
f"http://0.0.0.0:{peer_port}",
"--initial-advertise-peer-urls",
f"http://localhost:{peer_port}",
"--initial-cluster",
initial_cluster,
"--initial-cluster-state",
"new",
"--initial-cluster-token",
"etcd-cluster",
]
super().__init__(
env=etcd_env,
command=command,
timeout=timeout,
display_output=False,
terminate_existing=False,
data_dir=data_dir,
log_dir=log_dir,
)
def get_status(self) -> dict:
"""Get the status of this ETCD node"""
try:
response = requests.post(
f"http://localhost:{self.client_port}/v3/maintenance/status",
json={},
timeout=2,
)
if response.status_code == 200:
return response.json()
except Exception as e:
logger.warning(f"Failed to get status for {self.name}: {e}")
return {}
def is_leader(self) -> bool:
"""Check if this node is the current leader"""
status = self.get_status()
# In etcd v3 API, we check if this member ID matches the leader ID
if status:
member_id = status.get("header", {}).get("member_id", "")
leader_id = status.get("leader", "")
return member_id == leader_id
return False
class EtcdCluster:
"""Manager for an ETCD cluster with configurable number of replicas"""
def __init__(
self,
request,
num_replicas: int = 3,
base_client_port: int = 2379,
base_peer_port: int = 12380,
):
self.request = request
self.num_replicas = num_replicas
self.base_client_port = base_client_port
self.base_peer_port = base_peer_port
self.replicas: List[Optional[EtcdReplicaServer]] = []
self.data_dirs: List[str] = []
self.log_base_dir = f"{request.node.name}_etcd_cluster"
# Clean up any existing log directory
try:
shutil.rmtree(self.log_base_dir)
logger.info(f"Cleaned up existing log directory: {self.log_base_dir}")
except FileNotFoundError:
pass
os.makedirs(self.log_base_dir, exist_ok=True)
def start(self):
"""Start ETCD cluster with configured number of replicas"""
logger.info(f"Starting {self.num_replicas}-node ETCD cluster")
# Build initial cluster configuration
initial_cluster_parts = []
for i in range(self.num_replicas):
name = f"etcd-{i}"
peer_port = self.base_peer_port + i
initial_cluster_parts.append(f"{name}=http://localhost:{peer_port}")
initial_cluster = ",".join(initial_cluster_parts)
# Start each replica
for i in range(self.num_replicas):
name = f"etcd-{i}"
client_port = self.base_client_port + i
peer_port = self.base_peer_port + i
data_dir = tempfile.mkdtemp(prefix=f"etcd_{i}_")
log_dir = os.path.join(self.log_base_dir, name)
self.data_dirs.append(data_dir)
os.makedirs(log_dir, exist_ok=True)
logger.info(
f"Starting {name} on client port {client_port}, peer port {peer_port}"
)
replica = EtcdReplicaServer(
request=self.request,
name=name,
client_port=client_port,
peer_port=peer_port,
initial_cluster=initial_cluster,
data_dir=data_dir,
log_dir=log_dir,
)
replica.__enter__()
self.replicas.append(replica)
logger.info(f"All {self.num_replicas} ETCD replicas started successfully")
# Wait for cluster to stabilize and elect a leader
self._wait_for_healthy_cluster(timeout=30)
leader_idx = self.find_leader()
if leader_idx is not None:
logger.info(f"Initial leader elected: etcd-{leader_idx}")
else:
logger.warning("No leader elected yet")
def _wait_for_healthy_cluster(self, timeout: int = 30):
"""Wait for all replicas to be healthy and responsive.
Args:
timeout: Maximum time to wait in seconds
Raises:
RuntimeError: If cluster doesn't become healthy within timeout
"""
logger.info("Waiting for all replicas to be healthy...")
start_time = time.time()
while time.time() - start_time < timeout:
time.sleep(1)
# Check if all replicas are responding
all_healthy = True
for i, replica in enumerate(self.replicas):
if replica:
status = replica.get_status()
if not status:
logger.debug(f"etcd-{i} not yet responsive")
all_healthy = False
break
if all_healthy:
logger.info("All replicas are healthy")
return
raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s")
def find_leader(self) -> Optional[int]:
"""Find which replica is currently the leader"""
for i, replica in enumerate(self.replicas):
if replica and replica.is_leader():
return i
return None
def terminate_leader(self) -> Optional[int]:
"""Terminate the current leader and return its index"""
leader_idx = self.find_leader()
if leader_idx is None:
logger.warning("No leader found to terminate")
return None
logger.info(f"Terminating current leader: etcd-{leader_idx}")
replica = self.replicas[leader_idx]
if replica:
replica.__exit__(None, None, None)
self.replicas[leader_idx] = None
logger.info(f"Leader etcd-{leader_idx} has been terminated")
return leader_idx
def get_client_endpoints(self) -> List[str]:
"""Get list of active client endpoints"""
endpoints = []
for i, replica in enumerate(self.replicas):
if replica: # Only include active replicas
client_port = self.base_client_port + i
endpoints.append(f"http://localhost:{client_port}")
return endpoints
def stop(self):
"""Clean up all replicas and temporary directories"""
logger.info("Cleaning up ETCD cluster")
# Stop all running replicas
for replica in self.replicas:
if replica:
try:
replica.__exit__(None, None, None)
except Exception as e:
logger.warning(f"Error stopping replica: {e}")
self.replicas = []
# Clean up data directories
for data_dir in self.data_dirs:
try:
shutil.rmtree(data_dir)
except Exception as e:
logger.warning(f"Error removing data directory {data_dir}: {e}")
self.data_dirs = []
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
def send_inference_request(prompt: str, max_tokens: int = 50) -> str:
"""Send a simple inference request to the frontend and return the generated text"""
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": max_tokens,
"temperature": 0.0, # Make output deterministic
}
headers = {"Content-Type": "application/json"}
logger.info(f"Sending inference request: '{prompt}'")
try:
response = requests.post(
f"http://localhost:{FRONTEND_PORT}/v1/completions",
headers=headers,
json=payload,
timeout=round(max_tokens * 0.6),
)
if response.status_code == 200:
result = response.json()
text = result.get("choices", [{}])[0].get("text", "")
logger.info(f"Inference generated text: '{text.strip()}'")
return text
else:
pytest.fail(
f"Inference request failed with code {response.status_code}: {response.text}"
)
except Exception as e:
pytest.fail(f"Inference request failed: {e}")
def wait_for_processes_to_terminate(
processes: dict, timeout: int = 30, poll_interval: int = 1
) -> None:
"""
Wait for multiple processes to terminate and fail if they don't within timeout.
Args:
processes: Dictionary mapping process names to ManagedProcess instances
timeout: Maximum time to wait in seconds
poll_interval: Time between checks in seconds
Raises:
pytest.fail: If any process is still running after timeout
"""
logger.info(f"Waiting for {len(processes)} process(es) to terminate")
elapsed = 0
terminated = {name: False for name in processes}
while elapsed < timeout:
time.sleep(poll_interval)
elapsed += poll_interval
# Check each process
for name, process in processes.items():
if (
not terminated[name]
and process.proc
and process.proc.poll() is not None
):
logger.info(f"{name} process has terminated after {elapsed}s")
terminated[name] = True
# Exit early if all processes have terminated
if all(terminated.values()):
return
# Check for any processes still running and fail
still_running = [name for name, term in terminated.items() if not term]
if still_running:
pytest.fail(
f"Process(es) still running after {elapsed}s: {', '.join(still_running)}"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment