Unverified Commit 432c5b13 authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: Shutdown DRT when vLLM engine fails (#2698)


Signed-off-by: default avatarJ Wyman <jwyman@nvidia.com>
parent c2f0baa4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import os
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging
logger = logging.getLogger(__name__)
HEALTH_CHECK_INTERVAL = 2
class VllmEngineMonitor:
"""
Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.
"""
def __init__(self, runtime: DistributedRuntime, engine_client: AsyncLLM):
if not isinstance(runtime, DistributedRuntime):
raise ValueError(
f"{self.__class__.__name__} requires an instance of DistributedRuntime."
)
if not isinstance(engine_client, AsyncLLM):
raise ValueError(
f"{self.__class__.__name__} requires an instance of AsyncLLM."
)
self.runtime = runtime
self.engine_client = engine_client
self._monitor_task = asyncio.create_task(self._check_engine_health())
logger.info(
f"{self.__class__.__name__} initialized and health check task started."
)
def __del__(self):
self._monitor_task.cancel()
async def _check_engine_health(self):
while True:
try:
await self.engine_client.check_health()
await asyncio.sleep(HEALTH_CHECK_INTERVAL)
except EngineDeadError as e:
logger.error(f"vLLM AsyncLLM health check failed: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
except asyncio.CancelledError:
pass
......@@ -3,6 +3,7 @@
import asyncio
import logging
import os
import uuid
from abc import ABC, abstractmethod
from copy import deepcopy
......@@ -11,9 +12,11 @@ from typing import AsyncGenerator
import msgspec
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
from .protocol import MyRequestOutput
configure_dynamo_logging()
......@@ -25,11 +28,13 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints.
"""
def __init__(self, component, engine, default_sampling_params):
def __init__(self, runtime, component, engine, default_sampling_params):
self.runtime = runtime
self.component = component
self.engine_client = engine
self.default_sampling_params = default_sampling_params
self.kv_publisher = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]:
......@@ -47,6 +52,7 @@ class BaseWorkerHandler(ABC):
pass
async def generate_tokens(self, prompt, sampling_params, request_id):
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
num_output_tokens_so_far = 0
......@@ -79,12 +85,23 @@ class BaseWorkerHandler(ABC):
"Decode engine was shut down during token generation"
) from None
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class DecodeWorkerHandler(BaseWorkerHandler):
def __init__(
self, component, engine, default_sampling_params, prefill_worker_client=None
self,
runtime,
component,
engine,
default_sampling_params,
prefill_worker_client=None,
):
super().__init__(component, engine, default_sampling_params)
super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client
self.can_prefill = 0
self._prefill_check_task = None
......@@ -99,10 +116,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if self.prefill_worker_client is not None:
self.can_prefill = len(self.prefill_worker_client.instance_ids())
logger.debug(f"Current Prefill Workers: {self.can_prefill}")
await asyncio.sleep(5)
except asyncio.CancelledError:
logger.warning("Prefill check loop cancelled.")
raise
except Exception as e:
logger.error(f"Error in prefill check loop: {e}")
await asyncio.sleep(5) # Still sleep on error to avoid tight loop
await asyncio.sleep(5)
def cleanup(self):
"""Cancel background tasks."""
......@@ -172,6 +192,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"kv_transfer_params"
] = prefill_response.kv_transfer_params
try:
async for tok in self.generate_tokens(prompt, sampling_params, request_id):
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
......@@ -181,10 +202,16 @@ class DecodeWorkerHandler(BaseWorkerHandler):
yield tok
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, component, engine, default_sampling_params):
super().__init__(component, engine, default_sampling_params)
def __init__(self, runtime, component, engine, default_sampling_params):
super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context):
request_id = request["request_id"]
......@@ -193,7 +220,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill
try:
......
......@@ -142,7 +142,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# TODO register_prefill in similar vein to register_llm
handler = PrefillWorkerHandler(component, engine_client, default_sampling_params)
handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params
)
try:
await asyncio.gather(
......@@ -201,7 +203,11 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info(f"VllmWorker for {config.model} has been initialized")
handler = DecodeWorkerHandler(
component, engine_client, default_sampling_params, prefill_worker_client
runtime,
component,
engine_client,
default_sampling_params,
prefill_worker_client,
)
if config.engine_args.enable_prefix_caching:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
import pytest
import requests
from huggingface_hub import snapshot_download
from tests.utils.deployment_graph import completions_response_handler
from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
class DynamoFrontendProcess(ManagedProcess):
"""Process manager for Dynamo frontend"""
def __init__(self, request):
command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]
log_dir = f"{request.node.name}_frontend"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
display_output=True,
terminate_existing=True,
log_dir=log_dir,
)
def get_pid(self) -> int | None:
"""Get the PID of the worker process"""
return self.proc.pid if self.proc else None
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with vLLM backend"""
def __init__(self, request, worker_id: str):
self.worker_id = worker_id
command = [
"python3",
"-m",
"dynamo.vllm",
"--model",
"deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"--enforce-eager",
"--gpu-memory-utilization",
"0.45",
"--max-model-len",
"8192",
"--migration-limit",
"3",
]
# Set debug logging environment
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = "9345"
# TODO: Have the managed process take a command name explicitly to distinguish
# between processes started with the same command.
log_dir = f"{request.node.name}_{worker_id}"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
env=env,
health_check_urls=[("http://localhost:9345/health", self.is_ready)],
timeout=300,
display_output=True,
terminate_existing=False,
log_dir=log_dir,
)
def get_pid(self) -> int | None:
"""Get the PID of the worker process"""
return self.proc.pid if hasattr(self, "proc") and self.proc else None
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
logger.info(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is ready"
)
return True
logger.warning(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is not ready: {data.get('status')}"
)
except ValueError:
logger.warning(
f"{self.__class__.__name__} {{ name: {self.worker_id} }} health response is not valid JSON"
)
return False
def download_model() -> None:
"""
Download the DeepSeek-R1-Distill-Llama-8B model from HuggingFace Hub if not already cached.
"""
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
logger.info(f"Caching model {model_id}...")
max_retries = 5
retry_delay = 30 # seconds
for attempt in range(max_retries):
try:
# Download the model to the default cache directory
# This will skip download if the model is already cached
snapshot_download(
repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
repo_type="model",
local_files_only=False,
)
logger.info(f"Model {model_id} is ready for use")
return # Success, exit the function
except Exception as e:
if attempt < max_retries - 1: # Not the last attempt
logger.warning(
f"Failed to download model {model_id} (attempt {attempt + 1}/{max_retries}): {e}"
)
logger.info(f"Retrying in {retry_delay} seconds...")
time.sleep(retry_delay)
else: # Last attempt failed
logger.error(
f"Failed to download model {model_id} after {max_retries} attempts: {e}"
)
raise
def send_completion_request(
prompt: str, max_tokens: int, timeout: int = 120
) -> requests.Response:
"""Send a completion request to the frontend"""
payload = {
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"prompt": prompt,
"max_tokens": max_tokens,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
try:
response = requests.post(
"http://localhost:8080/v1/completions",
headers=headers,
json=payload,
timeout=timeout,
)
logger.info(f"Received response with status code: {response.status_code}")
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.slow
def test_vllm_health_check_active(request, runtime_services):
"""
End-to-end test for worker fault tolerance with migration support.
This test verifies that when a worker is killed during request processing,
the system can handle the failure gracefully and migrate the request to
another worker.
"""
# Step 0: Download the model from HuggingFace if not already cached
download_model()
# Step 1: Start the frontend
logger.info("Starting frontend...")
with DynamoFrontendProcess(request):
logger.info("Frontend started.")
# Step 2: Start a worker
logger.info("Starting worker...")
with DynamoWorkerProcess(request, "decode") as worker:
logger.info(f"Worker PID: {worker.get_pid()}")
time.sleep(12) # Give the model some time to get started.
# Step 3: Send a test request to prove the worker is live.
test_response = send_completion_request("Who are you?", 100, timeout=60)
completions_response_handler(test_response)
logger.info("Test request completed successfully")
# Step 4: Find and kill vLLM engine processes to force the EngineDeadError condition.
children = worker.subprocesses()
logger.info(f"Worker children: {[child.pid for child in children]}")
for child in children:
cmdline = child.cmdline()
if len(cmdline) > 0 and cmdline[0] == "VLLM::EngineCore":
logger.warning(
f"Killing vLLM engine process {{ pid: {child.pid}, cmdline: '{' '.join(cmdline)}' }}"
)
child.kill()
break
time.sleep(2) # Give some time for the worker to stabilize
# Step 5: Send a request triggering the handler to shutdown everything.
test_response = send_completion_request("How old are you?", 100, timeout=60)
logger.error(f"Test request failed: {test_response}")
# Step 6: Ensure the worker process has been stopped as a result of the EngineDeadError condition.
if worker.is_running():
pytest.fail(
"Worker should not be running after killing vLLM engine process."
)
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.slow
def test_vllm_health_check_passive(request, runtime_services):
"""
End-to-end test for worker fault tolerance with migration support.
This test verifies that when a worker is killed during request processing,
the system can handle the failure gracefully and migrate the request to
another worker.
"""
# Step 0: Download the model from HuggingFace if not already cached
download_model()
# Step 1: Start the frontend
logger.info("Starting frontend...")
with DynamoFrontendProcess(request):
logger.info("Frontend started.")
# Step 2: Start a worker
logger.info("Starting worker...")
with DynamoWorkerProcess(request, "decode") as worker:
logger.info(f"Worker PID: {worker.get_pid()}")
time.sleep(12) # Give the model some time to get started.
# Step 3: Send a test request to prove the worker is live.
test_response = send_completion_request("Who are you?", 100, timeout=60)
completions_response_handler(test_response)
logger.info("Test request completed successfully")
# Step 4: Find and kill vLLM engine processes to force the EngineDeadError condition.
children = worker.subprocesses()
logger.info(f"Worker children: {[child.pid for child in children]}")
for child in children:
cmdline = child.cmdline()
if len(cmdline) > 0 and cmdline[0] == "VLLM::EngineCore":
logger.warning(
f"Killing vLLM engine process {{ pid: {child.pid}, cmdline: '{' '.join(cmdline)}' }}"
)
child.kill()
break
time.sleep(6) # Give some time for the worker to stabilize
# Step 5: Ensure the worker process has been stopped as a result of the EngineDeadError condition.
if worker.is_running():
pytest.fail(
"Worker should not be running after killing vLLM engine process."
)
......@@ -433,6 +433,27 @@ class ManagedProcess:
# Process may have terminated or become inaccessible during iteration
pass
def is_running(self) -> bool:
"""Check if the process is still running"""
return (
hasattr(self, "proc") and self.proc is not None and self.proc.poll() is None
)
def subprocesses(self) -> list[psutil.Process]:
"""Find child processes of the current process."""
if (
not hasattr(self, "proc")
or self.proc is None
or self.proc.poll() is not None
):
return []
try:
parent = psutil.Process(self.proc.pid)
return parent.children(recursive=True)
except psutil.NoSuchProcess:
return []
def main():
with ManagedProcess(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment