Unverified Commit 432c5b13 authored by J Wyman's avatar J Wyman Committed by GitHub
Browse files

feat: Shutdown DRT when vLLM engine fails (#2698)


Signed-off-by: default avatarJ Wyman <jwyman@nvidia.com>
parent c2f0baa4
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import logging
import os
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime import DistributedRuntime
from dynamo.runtime.logging import configure_dynamo_logging
# Configure Dynamo logging when this module is imported. This must be a call:
# the bare name `configure_dynamo_logging` only evaluates the function object
# and discards it, leaving logging unconfigured.
configure_dynamo_logging()

logger = logging.getLogger(__name__)

# Delay, in seconds, passed to asyncio.sleep() between two engine health checks.
HEALTH_CHECK_INTERVAL = 2
class VllmEngineMonitor:
    """
    Monitors the health of the vLLM engine and initiates a shutdown if the engine is dead.

    A background asyncio task is created in the constructor, so an instance
    must be created while an event loop is running (asyncio.create_task
    requires one). The task calls ``engine_client.check_health()`` and then
    sleeps HEALTH_CHECK_INTERVAL seconds, repeatedly, until it is cancelled or
    the engine reports EngineDeadError.
    """

    def __init__(self, runtime: DistributedRuntime, engine_client: AsyncLLM):
        """
        Validate the arguments and start the health check task.

        Args:
            runtime: Runtime whose ``shutdown()`` is called when the engine is dead.
            engine_client: Engine whose ``check_health()`` is polled.

        Raises:
            ValueError: If either argument is not an instance of the expected type.
        """
        if not isinstance(runtime, DistributedRuntime):
            raise ValueError(
                f"{self.__class__.__name__} requires an instance of DistributedRuntime."
            )
        if not isinstance(engine_client, AsyncLLM):
            raise ValueError(
                f"{self.__class__.__name__} requires an instance of AsyncLLM."
            )

        self.runtime = runtime
        self.engine_client = engine_client
        self._monitor_task = asyncio.create_task(self._check_engine_health())

        logger.info(
            f"{self.__class__.__name__} initialized and health check task started."
        )

    def __del__(self):
        # __init__ can raise ValueError before _monitor_task is assigned, and
        # __del__ still runs on that partially built instance.
        monitor_task = getattr(self, "_monitor_task", None)
        if monitor_task is not None:
            monitor_task.cancel()

    async def _check_engine_health(self):
        """
        Poll the engine until cancelled; on EngineDeadError shut down and exit.

        Raises:
            asyncio.CancelledError: Re-raised when the task is cancelled so the
                loop stops and the task finishes in the cancelled state.
        """
        while True:
            try:
                await self.engine_client.check_health()
                await asyncio.sleep(HEALTH_CHECK_INTERVAL)
            except EngineDeadError as e:
                logger.error(f"vLLM AsyncLLM health check failed: {e}")
                logger.warning("Initiating Dynamo Runtime shutdown.")
                self.runtime.shutdown()
                # Terminate the process immediately with a non-zero status.
                os._exit(1)
            except asyncio.CancelledError:
                # Swallowing the cancellation here would let `while True`
                # continue and keep the task alive forever.
                raise
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
import os
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
...@@ -11,9 +12,11 @@ from typing import AsyncGenerator ...@@ -11,9 +12,11 @@ from typing import AsyncGenerator
import msgspec import msgspec
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from .engine_monitor import VllmEngineMonitor
from .protocol import MyRequestOutput from .protocol import MyRequestOutput
configure_dynamo_logging() configure_dynamo_logging()
...@@ -25,11 +28,13 @@ class BaseWorkerHandler(ABC): ...@@ -25,11 +28,13 @@ class BaseWorkerHandler(ABC):
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
""" """
def __init__(self, component, engine, default_sampling_params): def __init__(self, runtime, component, engine, default_sampling_params):
self.runtime = runtime
self.component = component self.component = component
self.engine_client = engine self.engine_client = engine
self.default_sampling_params = default_sampling_params self.default_sampling_params = default_sampling_params
self.kv_publisher = None self.kv_publisher = None
self.engine_monitor = VllmEngineMonitor(runtime, engine)
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
...@@ -47,44 +52,56 @@ class BaseWorkerHandler(ABC): ...@@ -47,44 +52,56 @@ class BaseWorkerHandler(ABC):
pass pass
async def generate_tokens(self, prompt, sampling_params, request_id): async def generate_tokens(self, prompt, sampling_params, request_id):
gen = self.engine_client.generate(prompt, sampling_params, request_id)
num_output_tokens_so_far = 0
try: try:
async for res in gen: gen = self.engine_client.generate(prompt, sampling_params, request_id)
# res is vllm's RequestOutput
# This is the expected way for a request to end.
# The new token ID will be eos, don't forward it.
if res.finished:
yield {"finish_reason": "stop", "token_ids": []}
break
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0] num_output_tokens_so_far = 0
next_total_toks = len(output.token_ids) try:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]} async for res in gen:
if output.finish_reason: # res is vllm's RequestOutput
out["finish_reason"] = output.finish_reason
if output.stop_reason: # This is the expected way for a request to end.
out["stop_reason"] = output.stop_reason # The new token ID will be eos, don't forward it.
yield out if res.finished:
num_output_tokens_so_far = next_total_toks yield {"finish_reason": "stop", "token_ids": []}
except asyncio.CancelledError: break
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit( if not res.outputs:
"Decode engine was shut down during token generation" yield {"finish_reason": "error", "token_ids": []}
) from None break
output = res.outputs[0]
next_total_toks = len(output.token_ids)
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
yield out
num_output_tokens_so_far = next_total_toks
except asyncio.CancelledError:
# raise EngineShGeneratorExit when engine exits so that frontend can migrate the request
raise GeneratorExit(
"Decode engine was shut down during token generation"
) from None
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class DecodeWorkerHandler(BaseWorkerHandler): class DecodeWorkerHandler(BaseWorkerHandler):
def __init__( def __init__(
self, component, engine, default_sampling_params, prefill_worker_client=None self,
runtime,
component,
engine,
default_sampling_params,
prefill_worker_client=None,
): ):
super().__init__(component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client self.prefill_worker_client = prefill_worker_client
self.can_prefill = 0 self.can_prefill = 0
self._prefill_check_task = None self._prefill_check_task = None
...@@ -99,10 +116,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -99,10 +116,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if self.prefill_worker_client is not None: if self.prefill_worker_client is not None:
self.can_prefill = len(self.prefill_worker_client.instance_ids()) self.can_prefill = len(self.prefill_worker_client.instance_ids())
logger.debug(f"Current Prefill Workers: {self.can_prefill}") logger.debug(f"Current Prefill Workers: {self.can_prefill}")
await asyncio.sleep(5) except asyncio.CancelledError:
logger.warning("Prefill check loop cancelled.")
raise
except Exception as e: except Exception as e:
logger.error(f"Error in prefill check loop: {e}") logger.error(f"Error in prefill check loop: {e}")
await asyncio.sleep(5) # Still sleep on error to avoid tight loop
await asyncio.sleep(5)
def cleanup(self): def cleanup(self):
"""Cancel background tasks.""" """Cancel background tasks."""
...@@ -145,46 +165,53 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -145,46 +165,53 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"request_id": request_id, "request_id": request_id,
} }
# TODO Change to prefill queue # TODO Change to prefill queue
if self.prefill_worker_client is not None: if self.prefill_worker_client is not None:
try: try:
prefill_response = await anext( prefill_response = await anext(
await self.prefill_worker_client.round_robin( await self.prefill_worker_client.round_robin(
prefill_request, context=context prefill_request, context=context
)
) )
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return
raise e
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
) )
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
return
raise e
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
# Modify original sampling_params for decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = prefill_response.kv_transfer_params
# Modify original sampling_params for decode try:
if sampling_params.extra_args is None: async for tok in self.generate_tokens(prompt, sampling_params, request_id):
sampling_params.extra_args = {} if context.is_stopped() or context.is_killed():
sampling_params.extra_args[ await self.engine_client.abort(request_id)
"kv_transfer_params" logger.debug(f"Aborted Request ID: {request_id}")
] = prefill_response.kv_transfer_params # TODO: Raise asyncio.CancelledError into bindings
break
async for tok in self.generate_tokens(prompt, sampling_params, request_id): yield tok
if context.is_stopped() or context.is_killed():
await self.engine_client.abort(request_id)
logger.debug(f"Aborted Request ID: {request_id}")
# TODO: Raise asyncio.CancelledError into bindings
break
yield tok except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
class PrefillWorkerHandler(BaseWorkerHandler): class PrefillWorkerHandler(BaseWorkerHandler):
def __init__(self, component, engine, default_sampling_params): def __init__(self, runtime, component, engine, default_sampling_params):
super().__init__(component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context): async def generate(self, request, context):
request_id = request["request_id"] request_id = request["request_id"]
...@@ -193,7 +220,13 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -193,7 +220,13 @@ class PrefillWorkerHandler(BaseWorkerHandler):
prompt = TokensPrompt(prompt_token_ids=request["token_ids"]) prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = msgspec.convert(request["sampling_params"], SamplingParams) sampling_params = msgspec.convert(request["sampling_params"], SamplingParams)
gen = self.engine_client.generate(prompt, sampling_params, request_id) try:
gen = self.engine_client.generate(prompt, sampling_params, request_id)
except EngineDeadError as e:
logger.error(f"vLLM EngineDeadError: {e}")
logger.warning("Initiating Dynamo Runtime shutdown.")
self.runtime.shutdown()
os._exit(1)
# Generate only 1 token in prefill # Generate only 1 token in prefill
try: try:
......
...@@ -142,7 +142,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -142,7 +142,9 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
# TODO register_prefill in similar vein to register_llm # TODO register_prefill in similar vein to register_llm
handler = PrefillWorkerHandler(component, engine_client, default_sampling_params) handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params
)
try: try:
await asyncio.gather( await asyncio.gather(
...@@ -201,7 +203,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -201,7 +203,11 @@ async def init(runtime: DistributedRuntime, config: Config):
logger.info(f"VllmWorker for {config.model} has been initialized") logger.info(f"VllmWorker for {config.model} has been initialized")
handler = DecodeWorkerHandler( handler = DecodeWorkerHandler(
component, engine_client, default_sampling_params, prefill_worker_client runtime,
component,
engine_client,
default_sampling_params,
prefill_worker_client,
) )
if config.engine_args.enable_prefix_caching: if config.engine_args.enable_prefix_caching:
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
import pytest
import requests
from huggingface_hub import snapshot_download
from tests.utils.deployment_graph import completions_response_handler
from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
class DynamoFrontendProcess(ManagedProcess):
    """Process manager for Dynamo frontend"""

    def __init__(self, request):
        """
        Build the frontend command and pass it to ManagedProcess.

        Args:
            request: pytest ``request`` fixture; ``request.node.name`` is used
                to name the log directory.
        """
        command = ["python", "-m", "dynamo.frontend", "--router-mode", "round-robin"]

        log_dir = f"{request.node.name}_frontend"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(log_dir)
            logger.info(f"Cleaned up existing log directory: {log_dir}")
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass

        super().__init__(
            command=command,
            display_output=True,
            terminate_existing=True,
            log_dir=log_dir,
        )

    def get_pid(self) -> int | None:
        """Get the PID of the frontend process, or None when no process handle is set"""
        # Guard against a missing `proc` attribute, as DynamoWorkerProcess.get_pid
        # does, instead of raising AttributeError.
        proc = getattr(self, "proc", None)
        return proc.pid if proc else None
class DynamoWorkerProcess(ManagedProcess):
    """Process manager for Dynamo worker with vLLM backend"""

    def __init__(self, request, worker_id: str):
        """
        Build the worker command and environment and pass them to ManagedProcess.

        Args:
            request: pytest ``request`` fixture; ``request.node.name`` is used
                to name the log directory.
            worker_id: Label used in the log directory name and in log messages.
        """
        self.worker_id = worker_id

        # Set debug logging environment
        worker_env = {
            **os.environ,
            "DYN_LOG": "debug",
            "DYN_SYSTEM_ENABLED": "true",
            "DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS": '["generate"]',
            "DYN_SYSTEM_PORT": "9345",
        }

        # TODO: Have the managed process take a command name explicitly to distinguish
        #       between processes started with the same command.
        worker_log_dir = f"{request.node.name}_{worker_id}"

        # Clean up any existing log directory from previous runs
        try:
            shutil.rmtree(worker_log_dir)
        except FileNotFoundError:
            # Directory doesn't exist, which is fine
            pass
        else:
            logger.info(f"Cleaned up existing log directory: {worker_log_dir}")

        super().__init__(
            command=[
                "python3",
                "-m",
                "dynamo.vllm",
                "--model",
                "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
                "--enforce-eager",
                "--gpu-memory-utilization",
                "0.45",
                "--max-model-len",
                "8192",
                "--migration-limit",
                "3",
            ],
            env=worker_env,
            health_check_urls=[("http://localhost:9345/health", self.is_ready)],
            timeout=300,
            display_output=True,
            terminate_existing=False,
            log_dir=worker_log_dir,
        )

    def get_pid(self) -> int | None:
        """Get the PID of the worker process"""
        handle = getattr(self, "proc", None)
        if handle:
            return handle.pid
        return None

    def is_ready(self, response) -> bool:
        """
        Check the health of the worker process.

        Returns True only when the response body is JSON whose "status" value
        is "ready"; any other status, or a body that is not valid JSON, is
        logged as a warning and yields False.
        """
        try:
            data = response.json()
            if data.get("status") != "ready":
                logger.warning(
                    f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is not ready: {data.get('status')}"
                )
                return False
            logger.info(
                f"{self.__class__.__name__} {{ name: {self.worker_id} }} status is ready"
            )
            return True
        except ValueError:
            logger.warning(
                f"{self.__class__.__name__} {{ name: {self.worker_id} }} health response is not valid JSON"
            )
            return False
def download_model(
    model_id: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    max_retries: int = 5,
    retry_delay: float = 30,
) -> None:
    """
    Download a model from HuggingFace Hub if not already cached, retrying on failure.

    Args:
        model_id: HuggingFace repository id passed to ``snapshot_download``.
        max_retries: Total number of download attempts; must be at least 1.
        retry_delay: Seconds to sleep between two attempts.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: Whatever the last failed ``snapshot_download`` attempt raised.
    """
    # Without this check a non-positive count would skip the loop entirely and
    # return as if the model had been downloaded.
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")

    logger.info(f"Caching model {model_id}...")

    for attempt in range(1, max_retries + 1):
        try:
            # Download the model to the default cache directory
            # This will skip download if the model is already cached
            snapshot_download(
                repo_id=model_id,
                repo_type="model",
                local_files_only=False,
            )
        except Exception as e:
            if attempt == max_retries:
                # Last attempt failed: log and let the error propagate.
                logger.error(
                    f"Failed to download model {model_id} after {max_retries} attempts: {e}"
                )
                raise
            logger.warning(
                f"Failed to download model {model_id} (attempt {attempt}/{max_retries}): {e}"
            )
            logger.info(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            logger.info(f"Model {model_id} is ready for use")
            return
def send_completion_request(
    prompt: str, max_tokens: int, timeout: int = 120
) -> requests.Response:
    """
    Send a completion request to the frontend.

    Args:
        prompt: Prompt text; only its first 50 characters are logged.
        max_tokens: Value sent as ``max_tokens`` in the request body.
        timeout: Timeout in seconds passed to ``requests.post``.

    Returns:
        The response object returned by ``requests.post``.

    Raises:
        requests.exceptions.RequestException: Logged and re-raised, including
            the ``Timeout`` subclass.
    """
    logger.info(
        f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
    )

    request_body = dict(
        model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        prompt=prompt,
        max_tokens=max_tokens,
    )

    try:
        reply = requests.post(
            "http://localhost:8080/v1/completions",
            headers={"Content-Type": "application/json"},
            json=request_body,
            timeout=timeout,
        )
    except requests.exceptions.Timeout:
        logger.error(f"Request timed out after {timeout} seconds")
        raise
    except requests.exceptions.RequestException as e:
        logger.error(f"Request failed with error: {e}")
        raise
    else:
        logger.info(f"Received response with status code: {reply.status_code}")
        return reply
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.slow
def test_vllm_health_check_active(request, runtime_services):
    """
    End-to-end test that the worker stops after its vLLM engine is killed and a request is sent.

    The test starts a frontend and one worker, sends a request to prove the
    worker is live, kills the worker's ``VLLM::EngineCore`` child process,
    sends a second request, and then fails if the worker process is still
    running.
    """

    # Step 0: Download the model from HuggingFace if not already cached
    download_model()

    # Step 1: Start the frontend
    logger.info("Starting frontend...")
    with DynamoFrontendProcess(request):
        logger.info("Frontend started.")

        # Step 2: Start a worker
        logger.info("Starting worker...")
        with DynamoWorkerProcess(request, "decode") as worker:
            logger.info(f"Worker PID: {worker.get_pid()}")

            time.sleep(12)  # Give the model some time to get started.

            # Step 3: Send a test request to prove the worker is live.
            test_response = send_completion_request("Who are you?", 100, timeout=60)
            completions_response_handler(test_response)
            logger.info("Test request completed successfully")

            # Step 4: Find and kill vLLM engine processes to force the EngineDeadError condition.
            children = worker.subprocesses()
            logger.info(f"Worker children: {[child.pid for child in children]}")
            for child in children:
                cmdline = child.cmdline()
                if cmdline and cmdline[0] == "VLLM::EngineCore":
                    logger.warning(
                        f"Killing vLLM engine process {{ pid: {child.pid}, cmdline: '{' '.join(cmdline)}' }}"
                    )
                    child.kill()
                    break
            else:
                # Nothing was killed, so the remaining steps would not test
                # the engine-death path; fail with an explicit reason.
                pytest.fail("No VLLM::EngineCore child process found to kill.")

            time.sleep(2)  # Give some time for the worker to stabilize

            # Step 5: Send a request triggering the handler to shutdown everything.
            test_response = send_completion_request("How old are you?", 100, timeout=60)
            logger.error(f"Test request failed: {test_response}")

            # Step 6: Ensure the worker process has been stopped as a result of the EngineDeadError condition.
            if worker.is_running():
                pytest.fail(
                    "Worker should not be running after killing vLLM engine process."
                )
@pytest.mark.vllm
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.slow
def test_vllm_health_check_passive(request, runtime_services):
    """
    End-to-end test that the worker stops after its vLLM engine is killed, with no further request.

    The test starts a frontend and one worker, sends a request to prove the
    worker is live, kills the worker's ``VLLM::EngineCore`` child process,
    waits without sending another request, and then fails if the worker
    process is still running.
    """

    # Step 0: Download the model from HuggingFace if not already cached
    download_model()

    # Step 1: Start the frontend
    logger.info("Starting frontend...")
    with DynamoFrontendProcess(request):
        logger.info("Frontend started.")

        # Step 2: Start a worker
        logger.info("Starting worker...")
        with DynamoWorkerProcess(request, "decode") as worker:
            logger.info(f"Worker PID: {worker.get_pid()}")

            time.sleep(12)  # Give the model some time to get started.

            # Step 3: Send a test request to prove the worker is live.
            test_response = send_completion_request("Who are you?", 100, timeout=60)
            completions_response_handler(test_response)
            logger.info("Test request completed successfully")

            # Step 4: Find and kill vLLM engine processes to force the EngineDeadError condition.
            children = worker.subprocesses()
            logger.info(f"Worker children: {[child.pid for child in children]}")
            for child in children:
                cmdline = child.cmdline()
                if cmdline and cmdline[0] == "VLLM::EngineCore":
                    logger.warning(
                        f"Killing vLLM engine process {{ pid: {child.pid}, cmdline: '{' '.join(cmdline)}' }}"
                    )
                    child.kill()
                    break
            else:
                # Nothing was killed, so the final check would not test the
                # engine-death path; fail with an explicit reason.
                pytest.fail("No VLLM::EngineCore child process found to kill.")

            time.sleep(6)  # Give some time for the worker to stabilize

            # Step 5: Ensure the worker process has been stopped as a result of the EngineDeadError condition.
            if worker.is_running():
                pytest.fail(
                    "Worker should not be running after killing vLLM engine process."
                )
...@@ -433,6 +433,27 @@ class ManagedProcess: ...@@ -433,6 +433,27 @@ class ManagedProcess:
# Process may have terminated or become inaccessible during iteration # Process may have terminated or become inaccessible during iteration
pass pass
def is_running(self) -> bool:
    """Report whether a process handle exists and has not exited yet."""
    handle = getattr(self, "proc", None)
    if handle is None:
        return False
    # poll() returning None means the process has no exit status yet.
    return handle.poll() is None
def subprocesses(self) -> list[psutil.Process]:
    """
    Find child processes of the managed process.

    Returns the result of ``children(recursive=True)`` for the managed
    process, or an empty list when there is no process handle, the process
    has already exited, or psutil reports that it no longer exists.
    """
    handle = getattr(self, "proc", None)
    if handle is None or handle.poll() is not None:
        return []

    try:
        return psutil.Process(handle.pid).children(recursive=True)
    except psutil.NoSuchProcess:
        return []
def main(): def main():
with ManagedProcess( with ManagedProcess(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment