# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 import json import logging import os import shutil import subprocess import tempfile import time from typing import List, Optional import pytest import requests from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.engine_process import FRONTEND_PORT from tests.utils.managed_process import ManagedProcess logger = logging.getLogger(__name__) class DynamoFrontendProcess(ManagedProcess): """Process manager for Dynamo frontend with ETCD HA support""" def __init__(self, request, etcd_endpoints: list): command = ["python", "-m", "dynamo.frontend"] # Set debug logging and ETCD endpoints env = os.environ.copy() env["DYN_LOG"] = "debug" env["ETCD_ENDPOINTS"] = ",".join(etcd_endpoints) # Unset DYN_SYSTEM_PORT - frontend doesn't use system metrics server env.pop("DYN_SYSTEM_PORT", None) log_dir = f"{request.node.name}_frontend" # Clean up any existing log directory from previous runs try: shutil.rmtree(log_dir) logger.info(f"Cleaned up existing log directory: {log_dir}") except FileNotFoundError: pass super().__init__( command=command, env=env, display_output=True, terminate_existing=True, log_dir=log_dir, ) class EtcdReplicaServer(ManagedProcess): """Single ETCD replica server in a cluster""" def __init__( self, request, name: str, client_port: int, peer_port: int, initial_cluster: str, data_dir: str, log_dir: str, timeout: int = 30, cluster_state: str = "new", ): self.name = name self.client_port = client_port self.peer_port = peer_port self.data_dir = data_dir etcd_env = os.environ.copy() etcd_env["ETCD_ENDPOINTS"] = "" # Clear any inherited ETCD endpoints etcd_env["ALLOW_NONE_AUTHENTICATION"] = "yes" command = [ "etcd", "--name", name, "--data-dir", data_dir, "--listen-client-urls", f"http://0.0.0.0:{client_port}", "--advertise-client-urls", f"http://127.0.0.1:{client_port}", "--listen-peer-urls", f"http://0.0.0.0:{peer_port}", "--initial-advertise-peer-urls", f"http://127.0.0.1:{peer_port}", "--initial-cluster", initial_cluster, "--initial-cluster-state", cluster_state, "--initial-cluster-token", "etcd-cluster", ] super().__init__( env=etcd_env, command=command, timeout=timeout, display_output=False, terminate_existing=False, data_dir=data_dir, log_dir=log_dir, ) def get_status(self) -> dict: """Get the status of this ETCD node""" try: response = requests.post( f"http://127.0.0.1:{self.client_port}/v3/maintenance/status", json={}, timeout=2, ) if response.status_code == 200: return response.json() except Exception as e: logger.warning(f"Failed to get status for {self.name}: {e}") return {} def is_leader(self) -> Optional[bool]: """ Check if this node is the current leader. Returns: True/False on is leader or None if status cannot be retrieved. """ status = self.get_status() # In etcd v3 API, we check if this member ID matches the leader ID if status: member_id = status.get("header", {}).get("member_id", "") leader_id = status.get("leader", "") return member_id == leader_id return None class EtcdCluster: """Manager for an ETCD cluster with configurable number of replicas""" def __init__( self, request, num_replicas: int = 3, base_port: int = 2379, ): self.request = request self.num_replicas = num_replicas self.base_port = base_port self.replicas: List[Optional[EtcdReplicaServer]] = [] self.data_dirs: List[str] = [] self.log_base_dir = f"{request.node.name}_etcd_cluster" # Clean up any existing log directory try: shutil.rmtree(self.log_base_dir) logger.info(f"Cleaned up existing log directory: {self.log_base_dir}") except FileNotFoundError: pass os.makedirs(self.log_base_dir, exist_ok=True) def _get_initial_cluster(self) -> str: """Build the initial cluster configuration string""" initial_cluster_parts = [] for i in range(self.num_replicas): name = f"etcd-{i}" peer_port = self.base_port + (2 * i) + 1 initial_cluster_parts.append(f"{name}=http://127.0.0.1:{peer_port}") return ",".join(initial_cluster_parts) def _start_replica(self, idx: int, cluster_state: str = "new") -> EtcdReplicaServer: """Start a single ETCD replica""" name = f"etcd-{idx}" # e.g. base_port = 2379 -> client_port = 2379, 2381, 2383 # e.g. base_port = 2379 -> peer_port = 2380, 2382, 2384 client_port = self.base_port + (2 * idx) peer_port = self.base_port + (2 * idx) + 1 # Create data dir for the node data_dir = tempfile.mkdtemp(prefix=f"etcd_{idx}_") if idx < len(self.data_dirs): self.data_dirs[idx] = data_dir else: self.data_dirs.append(data_dir) log_dir = os.path.join(self.log_base_dir, name) os.makedirs(log_dir, exist_ok=True) logger.info( f"Starting {name} on client port {client_port}, peer port {peer_port}" ) replica = EtcdReplicaServer( request=self.request, name=name, client_port=client_port, peer_port=peer_port, initial_cluster=self._get_initial_cluster(), data_dir=data_dir, log_dir=log_dir, cluster_state=cluster_state, ) replica.__enter__() return replica def _wait_for_healthy_cluster(self, timeout: int = 30): """Wait for cluster to be healthy and elected leader.""" logger.info("Waiting for cluster to become healthy...") start_time = time.time() while time.time() - start_time < timeout: # Check if a leader is elected indicating cluster health is_healthy = True leader_id = None for i, replica in enumerate(self.replicas): if replica: is_leader = replica.is_leader() if is_leader is None: is_healthy = False break if is_leader is True: if leader_id is not None: raise RuntimeError( f"Multiple leaders detected in ETCD cluster etcd-{leader_id} and etcd-{i}" ) leader_id = i if is_healthy and leader_id is not None: logger.info(f"Cluster is healthy with leader at etcd-{leader_id}") return time.sleep(1) raise RuntimeError(f"ETCD cluster failed to become healthy within {timeout}s") def _replace_member(self, idx: int): """Remove old member and add new member to the cluster using etcdctl""" # Find a healthy replica to perform member operations healthy_replica = None for i, r in enumerate(self.replicas): if r and i != idx: healthy_replica = r break if not healthy_replica: raise RuntimeError("No healthy replica found to perform member operations") name = f"etcd-{idx}" peer_port = self.base_port + (2 * idx) + 1 peer_url = f"http://127.0.0.1:{peer_port}" # Set ETCDCTL_ENDPOINTS for etcdctl commands etcdctl_env = os.environ.copy() etcdctl_env[ "ETCDCTL_ENDPOINTS" ] = f"http://127.0.0.1:{healthy_replica.client_port}" etcdctl_env["ETCDCTL_API"] = "3" # First, get member list to find the old member's ID logger.info(f"Getting member list to find {name}") try: result = subprocess.run( ["etcdctl", "member", "list", "--write-out=json"], env=etcdctl_env, capture_output=True, text=True, timeout=5, ) if result.returncode == 0: members = json.loads(result.stdout).get("members", []) old_member_id = None for member in members: if member.get("name") == name: old_member_id = member.get("ID") break if old_member_id: # Convert member ID to hex format (etcdctl expects hex) hex_member_id = format(int(old_member_id), "x") logger.info( f"Removing member with ID {old_member_id} (hex: {hex_member_id})" ) remove_result = subprocess.run( ["etcdctl", "member", "remove", hex_member_id], env=etcdctl_env, capture_output=True, text=True, timeout=5, ) if remove_result.returncode != 0: raise RuntimeError( f"Failed to remove old member: {remove_result.stderr}" ) logger.info(f"Successfully removed old member {name}") except Exception as e: raise RuntimeError(f"Error during member removal: {e}") # Add the new member to the cluster logger.info(f"Adding new member {name} to cluster with peer URL {peer_url}") try: add_result = subprocess.run( ["etcdctl", "member", "add", name, f"--peer-urls={peer_url}"], env=etcdctl_env, capture_output=True, text=True, timeout=5, ) if add_result.returncode != 0: raise RuntimeError(f"Failed to add new member: {add_result.stderr}") logger.info(f"Successfully added new member {name}") except Exception as e: raise RuntimeError(f"Error adding new member: {e}") def start(self): """Start ETCD cluster with configured number of replicas""" logger.info(f"Starting {self.num_replicas}-node ETCD cluster") # Start each replica for i in range(self.num_replicas): replica = self._start_replica(i, cluster_state="new") self.replicas.append(replica) logger.info(f"All {self.num_replicas} ETCD replicas started successfully") # Wait for cluster to stabilize self._wait_for_healthy_cluster() def get_client_endpoints(self) -> List[str]: """Get list of active client endpoints""" endpoints = [] for i, replica in enumerate(self.replicas): if replica: # Only include active replicas client_port = self.base_port + (2 * i) endpoints.append(f"http://127.0.0.1:{client_port}") return endpoints def terminate_replica(self, idx: int): """Terminate a specific replica by index.""" if idx < 0 or idx >= self.num_replicas: raise RuntimeError(f"Invalid replica index: {idx}") replica = self.replicas[idx] if not replica: raise RuntimeError(f"Replica etcd-{idx} is already terminated") replica.__exit__(None, None, None) self.replicas[idx] = None logger.info(f"Terminated replica etcd-{idx}") def restart_replica(self, idx: int): """Restart a terminated replica""" if idx < 0 or idx >= self.num_replicas: raise RuntimeError(f"Invalid replica index: {idx}") if self.replicas[idx] is not None: raise RuntimeError(f"Replica etcd-{idx} is already running") # Make sure the cluster is healthy before restarting self._wait_for_healthy_cluster() # Remove old member and add new member self._replace_member(idx) # Start the replica with existing cluster state replica = self._start_replica(idx, cluster_state="existing") self.replicas[idx] = replica # Wait for cluster to stabilize self._wait_for_healthy_cluster() def stop(self): """Clean up all replicas and temporary directories""" logger.info("Cleaning up ETCD cluster") # Stop all running replicas for replica in self.replicas: if replica: try: replica.__exit__(None, None, None) except Exception as e: logger.warning(f"Error stopping replica: {e}") self.replicas = [] # Clean up data directories for data_dir in self.data_dirs: try: shutil.rmtree(data_dir) except Exception as e: logger.warning(f"Error removing data directory {data_dir}: {e}") self.data_dirs = [] def __enter__(self): self.start() return self def __exit__(self, exc_type, exc_val, exc_tb): self.stop() def send_inference_request(prompt: str, max_tokens: int = 50) -> str: """Send a simple inference request to the frontend and return the generated text""" payload = { "model": FAULT_TOLERANCE_MODEL_NAME, "prompt": prompt, "max_tokens": max_tokens, "temperature": 0.0, # Make output deterministic } headers = {"Content-Type": "application/json"} logger.info(f"Sending inference request: '{prompt}'") try: response = requests.post( f"http://localhost:{FRONTEND_PORT}/v1/completions", headers=headers, json=payload, timeout=round(max_tokens * 0.6), ) if response.status_code == 200: result = response.json() text = result.get("choices", [{}])[0].get("text", "") logger.info(f"Inference generated text: '{text.strip()}'") return text else: pytest.fail( f"[ETCD HA regression?] Inference request failed with code {response.status_code}: {response.text}" ) except Exception as e: pytest.fail(f"[ETCD HA regression?] Inference request failed: {e}") def wait_for_processes_to_terminate( processes: dict, timeout: int = 30, poll_interval: int = 1 ) -> None: """ Wait for multiple processes to terminate and fail if they don't within timeout. Args: processes: Dictionary mapping process names to ManagedProcess instances timeout: Maximum time to wait in seconds poll_interval: Time between checks in seconds Raises: pytest.fail: If any process is still running after timeout """ logger.info(f"Waiting for {len(processes)} process(es) to terminate") elapsed = 0 terminated = {name: False for name in processes} while elapsed < timeout: time.sleep(poll_interval) elapsed += poll_interval # Check each process for name, process in processes.items(): if ( not terminated[name] and process.proc and process.proc.poll() is not None ): logger.info(f"{name} process has terminated after {elapsed}s") terminated[name] = True # Exit early if all processes have terminated if all(terminated.values()): return # Check for any processes still running and fail still_running = [name for name, term in terminated.items() if not term] if still_running: pytest.fail( f"Process(es) still running after {elapsed}s: {', '.join(still_running)}" )