Unverified Commit 77e66ae5 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

feat: Request Cancellation TRT-LLM (#3193)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent a13c4cb6
...@@ -13,17 +13,20 @@ ...@@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import copy import copy
import logging import logging
import os import os
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Optional, Union from typing import Any, AsyncGenerator, Optional, Union
import torch import torch
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
from dynamo._core import Context
from dynamo.logits_processing.examples import HelloWorldLogitsProcessor from dynamo.logits_processing.examples import HelloWorldLogitsProcessor
from dynamo.nixl_connect import Connector from dynamo.nixl_connect import Connector
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
...@@ -100,14 +103,71 @@ class HandlerBase: ...@@ -100,14 +103,71 @@ class HandlerBase:
result["finish_reason"] == "stop" or result["finish_reason"] == "error" result["finish_reason"] == "stop" or result["finish_reason"] == "error"
) )
async def _handle_cancellation(self, generation_result: Any, context: Context):
"""Background task to handle cancellation by monitoring context state."""
try:
# Wait asynchronously for cancellation signal instead of polling
await context.async_killed_or_stopped()
# Call abort_request on the executor through the LLM instance
if hasattr(self.engine.llm, "_executor") and self.engine.llm._executor:
# Get the internal request ID from the generation result
internal_request_id = getattr(generation_result, "request_id", None)
if internal_request_id is not None:
# TODO: Can this be an official abort method in TRT-LLM?
self.engine.llm._executor.abort_request(internal_request_id)
logging.debug(f"Aborted Request ID: {context.id()}")
else:
logging.error(
f"Could not retrieve internal request ID for abort: {context.id()}"
)
else:
logging.error(
f"TensorRT-LLM executor not found for abort request: {context.id()}"
)
except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes
pass
@asynccontextmanager
async def _cancellation_monitor(
self, generation_result: Any, context: Context
) -> AsyncGenerator[asyncio.Task, None]:
"""
Context manager for monitoring request cancellation.
Automatically creates a background task to monitor for cancellation and
cleans it up when the context exits.
Yields:
asyncio.Task: The cancellation monitoring task
"""
cancellation_task = asyncio.create_task(
self._handle_cancellation(generation_result, context)
)
try:
yield cancellation_task
finally:
# Clean up the background cancellation task
if not cancellation_task.done():
cancellation_task.cancel()
try:
await cancellation_task
except asyncio.CancelledError:
pass
async def generate_locally( async def generate_locally(
self, request: dict, embeddings: Optional[Union[torch.Tensor, dict]] = None self,
request: dict,
context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None,
): ):
""" """
Generate responses based on the disaggregation mode in the request. Generate responses based on the disaggregation mode in the request.
Args: Args:
request: The request dictionary containing generation parameters request: The request dictionary containing generation parameters
context: Context object for cancellation handling
embeddings: Optional tensor or dict containing embeddings for multimodal processing embeddings: Optional tensor or dict containing embeddings for multimodal processing
""" """
logging.debug(f"Request: {request}") logging.debug(f"Request: {request}")
...@@ -192,50 +252,57 @@ class HandlerBase: ...@@ -192,50 +252,57 @@ class HandlerBase:
sampling_params.logits_processor = adapters sampling_params.logits_processor = adapters
# NEW: Updated engine call to include multimodal data # NEW: Updated engine call to include multimodal data
async for res in self.engine.llm.generate_async( generation_result = self.engine.llm.generate_async(
inputs=processed_input, # Use the correctly extracted inputs inputs=processed_input, # Use the correctly extracted inputs
sampling_params=sampling_params, sampling_params=sampling_params,
disaggregated_params=disaggregated_params, disaggregated_params=disaggregated_params,
streaming=streaming, streaming=streaming,
): )
# TRTLLM engine needs to start generating tokens first before stats
# can be retrieved. # Use the context manager to handle cancellation monitoring
if self.first_generation and self.publisher: async with self._cancellation_monitor(generation_result, context):
self.publisher.start() async for res in generation_result:
self.first_generation = False # TRTLLM engine needs to start generating tokens first before stats
# can be retrieved.
# Upon completion, send a final chunk with "stop" as the finish reason. if self.first_generation and self.publisher:
# This signals to the client that the stream has ended. self.publisher.start()
if res.finished and self.disaggregation_mode != DisaggregationMode.PREFILL: self.first_generation = False
# Upon completion, send a final chunk with "stop" as the finish reason.
# This signals to the client that the stream has ended.
if (
res.finished
and self.disaggregation_mode != DisaggregationMode.PREFILL
):
if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response(
request_id, model_name
)
yield final_out
if not res.outputs:
yield {"finish_reason": "error", "token_ids": []}
break
output = res.outputs[0]
# The engine returns all tokens generated so far. We must calculate the new
# tokens generated in this iteration to create the "delta".
next_total_toks = len(output.token_ids)
if self.multimodal_processor: if self.multimodal_processor:
final_out = self.multimodal_processor.get_stop_response( out = self.multimodal_processor.create_response_chunk(
request_id, model_name output, num_output_tokens_so_far, request_id, model_name
) )
yield final_out else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if not res.outputs: if output.finish_reason:
yield {"finish_reason": "error", "token_ids": []} out["finish_reason"] = output.finish_reason
break if output.stop_reason:
out["stop_reason"] = output.stop_reason
output = res.outputs[0] if self.disaggregation_mode == DisaggregationMode.PREFILL:
# The engine returns all tokens generated so far. We must calculate the new # Return the disaggregated params only when operating in prefill mode.
# tokens generated in this iteration to create the "delta". out["disaggregated_params"] = asdict(
next_total_toks = len(output.token_ids) DisaggregatedParamsCodec.encode(output.disaggregated_params)
if self.multimodal_processor: )
out = self.multimodal_processor.create_response_chunk( # Yield the chunk to the client and update the token count for the next iteration.
output, num_output_tokens_so_far, request_id, model_name yield out
) num_output_tokens_so_far = next_total_toks
else:
out = {"token_ids": output.token_ids[num_output_tokens_so_far:]}
if output.finish_reason:
out["finish_reason"] = output.finish_reason
if output.stop_reason:
out["stop_reason"] = output.stop_reason
if self.disaggregation_mode == DisaggregationMode.PREFILL:
# Return the disaggregated params only when operating in prefill mode.
out["disaggregated_params"] = asdict(
DisaggregatedParamsCodec.encode(output.disaggregated_params)
)
# Yield the chunk to the client and update the token count for the next iteration.
yield out
num_output_tokens_so_far = next_total_toks
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import copy import copy
import logging import logging
from dynamo._core import Context
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.encode_helper import EncodeHelper from dynamo.trtllm.encode_helper import EncodeHelper
from dynamo.trtllm.request_handlers.handler_base import ( from dynamo.trtllm.request_handlers.handler_base import (
...@@ -66,9 +67,10 @@ class AggregatedHandler(HandlerBase): ...@@ -66,9 +67,10 @@ class AggregatedHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
async def generate(self, request: dict): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
# Implement all steps locally. # Implement all steps locally.
async for res in self.generate_locally(request): async for res in self.generate_locally(request, context):
yield res yield res
...@@ -80,7 +82,8 @@ class EncodeHandler(HandlerBase): ...@@ -80,7 +82,8 @@ class EncodeHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
async def generate(self, request: dict): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
if self.connector: if self.connector:
# Use helper method to process embedding request # Use helper method to process embedding request
async for response in EncodeHelper.process_embedding_request( async for response in EncodeHelper.process_embedding_request(
...@@ -122,11 +125,12 @@ class PrefillHandler(HandlerBase): ...@@ -122,11 +125,12 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector encode_response, self.connector
) )
async def remote_decode(self, request: dict): async def remote_decode(self, request: dict, context: Context):
async for res in await self.next_client.round_robin(request): async for res in await self.next_client.round_robin(request, context=context):
yield res.data() yield res.data()
async def generate(self, request: dict): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
logging.debug(f"PrefillHandler.generate received request: {request}") logging.debug(f"PrefillHandler.generate received request: {request}")
embeddings_tensor = None embeddings_tensor = None
...@@ -145,12 +149,18 @@ class PrefillHandler(HandlerBase): ...@@ -145,12 +149,18 @@ class PrefillHandler(HandlerBase):
prefill_request = copy.deepcopy(request) prefill_request = copy.deepcopy(request)
prefill_response = None prefill_response = None
response_count = 0 response_count = 0
async for res in self.generate_locally(prefill_request, embeddings_tensor): async for res in self.generate_locally(
prefill_request, context, embeddings_tensor
):
prefill_response = res prefill_response = res
response_count += 1 response_count += 1
if response_count > 1: if response_count > 1:
raise ValueError("Prefill response should be generated only once.") raise ValueError("Prefill response should be generated only once.")
if context.is_stopped() or context.is_killed():
# Local generate abort monitor will print debug log, so only returning here.
return
if ( if (
self.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST self.disaggregation_strategy == DisaggregationStrategy.PREFILL_FIRST
and not self.check_error(prefill_response) and not self.check_error(prefill_response)
...@@ -161,8 +171,12 @@ class PrefillHandler(HandlerBase): ...@@ -161,8 +171,12 @@ class PrefillHandler(HandlerBase):
request["disaggregated_params"] = prefill_response[ request["disaggregated_params"] = prefill_response[
"disaggregated_params" "disaggregated_params"
] ]
async for res in self.remote_decode(request): async for res in self.remote_decode(request, context):
yield res yield res
if context.is_stopped() or context.is_killed():
logging.debug(f"Aborted Remote Request ID: {context.id()}")
return
else: else:
# Return response to the decode handler. # Return response to the decode handler.
yield prefill_response yield prefill_response
...@@ -176,11 +190,12 @@ class DecodeHandler(HandlerBase): ...@@ -176,11 +190,12 @@ class DecodeHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
async def remote_prefill(self, request: dict): async def remote_prefill(self, request: dict, context: Context):
async for res in await self.next_client.round_robin(request): async for res in await self.next_client.round_robin(request, context=context):
yield res yield res
async def generate(self, request: dict): async def generate(self, request: dict, context: Context):
logging.debug(f"New Request ID: {context.id()}")
if self.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST: if self.disaggregation_strategy == DisaggregationStrategy.DECODE_FIRST:
prefill_response = None prefill_response = None
# If operating under decode_first strategy, the decode handler needs to trigger # If operating under decode_first strategy, the decode handler needs to trigger
...@@ -188,12 +203,16 @@ class DecodeHandler(HandlerBase): ...@@ -188,12 +203,16 @@ class DecodeHandler(HandlerBase):
response_count = 0 response_count = 0
# Do not yield the prefill response directly. # Do not yield the prefill response directly.
# Instead, capture it and extract the state. # Instead, capture it and extract the state.
async for res in self.remote_prefill(request): async for res in self.remote_prefill(request, context):
prefill_response = res prefill_response = res
response_count += 1 response_count += 1
if response_count > 1: if response_count > 1:
raise ValueError("Prefill response should be generated only once.") raise ValueError("Prefill response should be generated only once.")
if context.is_stopped() or context.is_killed():
logging.debug(f"Aborted Remote Request ID: {context.id()}")
return
response_data = ( response_data = (
prefill_response.data() if prefill_response is not None else None prefill_response.data() if prefill_response is not None else None
) )
...@@ -204,5 +223,5 @@ class DecodeHandler(HandlerBase): ...@@ -204,5 +223,5 @@ class DecodeHandler(HandlerBase):
if prefill_response is not None and response_data is not None: if prefill_response is not None and response_data is not None:
request["disaggregated_params"] = response_data["disaggregated_params"] request["disaggregated_params"] = response_data["disaggregated_params"]
async for res in self.generate_locally(request): async for res in self.generate_locally(request, context):
yield res yield res
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import shutil
import time
import pytest
from tests.fault_tolerance.cancellation.utils import (
DynamoFrontendProcess,
read_log_content,
send_request_and_cancel,
strip_ansi_codes,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__)
class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with TensorRT-LLM backend"""
def __init__(self, request, mode: str = "prefill_and_decode", strategy: str = ""):
"""
Initialize TensorRT-LLM worker process.
Args:
request: pytest request object
mode: One of "prefill_and_decode", "prefill", "decode"
strategy: One of "decode_first", "prefill_first"
"""
command = [
"python3",
"-m",
"dynamo.trtllm",
"--model",
FAULT_TOLERANCE_MODEL_NAME,
"--disaggregation-mode",
mode,
"--free-gpu-memory-fraction",
"0.45",
"--max-seq-len",
"8192",
"--migration-limit",
"3",
]
if mode != "prefill_and_decode":
with open("test_request_cancellation_trtllm_config.yaml", "w") as f:
f.write("cache_transceiver_config:\n backend: DEFAULT\n")
f.write("disable_overlap_scheduler: true\n")
command += [
"--extra-engine-args",
"test_request_cancellation_trtllm_config.yaml",
"--disaggregation-strategy",
strategy,
]
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set port based on worker type
if mode == "prefill":
port = "8082"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
elif mode == "decode":
port = "8081"
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
else: # prefill_and_decode
port = "8081"
# Set debug logging environment
env = os.environ.copy()
env["DYN_LOG"] = "debug"
env["DYN_SYSTEM_ENABLED"] = "true"
env["DYN_SYSTEM_USE_ENDPOINT_HEALTH_STATUS"] = '["generate"]'
env["DYN_SYSTEM_PORT"] = port
# Set log directory based on worker type
log_dir = f"{request.node.name}_{mode}_worker"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
env=env,
health_check_urls=health_check_urls,
timeout=300,
display_output=True,
terminate_existing=False,
log_dir=log_dir,
)
self.mode = mode
def get_pid(self):
"""Get the PID of the worker process"""
return self.proc.pid if self.proc else None
def is_ready(self, response) -> bool:
"""Check the health of the worker process"""
try:
data = response.json()
if data.get("status") == "ready":
logger.info(f"{self.mode.capitalize()} worker status is ready")
return True
logger.warning(
f"{self.mode.capitalize()} worker status is not ready: {data.get('status')}"
)
except ValueError:
logger.warning(
f"{self.mode.capitalize()} worker health response is not valid JSON"
)
return False
def verify_request_cancelled(
frontend_process: DynamoFrontendProcess,
worker_process: DynamoWorkerProcess,
remote_worker_process: DynamoWorkerProcess | None = None,
frontend_log_offset: int = 0,
worker_log_offset: int = 0,
remote_worker_log_offset: int = 0,
assert_request_reach_remote_worker: bool = False,
assert_cancel_at_remote_worker: bool = False,
) -> tuple[int, int, int]:
"""Verify the logs contain expected cancellation messages"""
# Check worker log for cancellation pattern
worker_log_content = read_log_content(worker_process._log_path)
new_worker_content = worker_log_content[worker_log_offset:]
# Find the LAST occurrence of "New Request ID: <id>" line (health checks may log earlier ones)
request_id = None
for line in reversed(new_worker_content.split("\n")):
# Strip ANSI codes and whitespace for pattern matching
clean_line = strip_ansi_codes(line).strip()
if "New Request ID: " in clean_line:
# Extract ID from the last delimiter occurrence on the line
parts = clean_line.rsplit("New Request ID: ", 1)
if len(parts) > 1:
request_id = parts[-1].strip()
break
if request_id is None:
pytest.fail("Could not find 'New Request ID: <id>' pattern in worker log")
has_worker_cancellation = False
cancellation_pattern = f"Aborted {'Remote ' if assert_cancel_at_remote_worker else ''}Request ID: {request_id}"
for line in new_worker_content.split("\n"):
# Strip ANSI codes and whitespace for pattern matching
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(cancellation_pattern):
has_worker_cancellation = True
break
if not has_worker_cancellation:
pytest.fail(f"Could not find '{cancellation_pattern}' pattern in worker log")
# Check remote worker log if provided
if remote_worker_process is not None:
remote_worker_log_content = read_log_content(remote_worker_process._log_path)
new_remote_worker_content = remote_worker_log_content[remote_worker_log_offset:]
# Check if the same request ID reached remote worker
if assert_request_reach_remote_worker:
has_reach_remote = False
remote_reach_pattern = f"New Request ID: {request_id}"
for line in new_remote_worker_content.split("\n"):
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(remote_reach_pattern):
has_reach_remote = True
break
if not has_reach_remote:
pytest.fail(
f"Could not find '{remote_reach_pattern}' pattern in remote worker log"
)
# Check if the same request ID was cancelled at remote worker
if assert_cancel_at_remote_worker:
has_remote_cancel = False
remote_cancel_pattern = f"Aborted Request ID: {request_id}"
for line in remote_worker_log_content.split("\n"):
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(remote_cancel_pattern):
has_remote_cancel = True
break
if not has_remote_cancel:
pytest.fail(
f"Could not find '{remote_cancel_pattern}' pattern in remote worker log"
)
# Check frontend log for cancellation issued pattern
frontend_log_content = read_log_content(frontend_process._log_path)
new_frontend_content = frontend_log_content[frontend_log_offset:]
has_kill_message = False
kill_message = "issued control message Kill to sender"
for line in new_frontend_content.split("\n"):
# Strip ANSI codes and whitespace for pattern matching
clean_line = strip_ansi_codes(line).strip()
if clean_line.endswith(kill_message):
has_kill_message = True
break
if not has_kill_message:
pytest.fail("Could not find cancellation issued in frontend log")
return (
len(frontend_log_content),
len(worker_log_content),
(0 if remote_worker_process is None else len(remote_worker_log_content)),
)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_aggregated(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation functionality in aggregated mode.
This test verifies that when a request is cancelled by the client,
the system properly handles the cancellation and cleans up resources
on the worker side in aggregated (prefill_and_decode) mode.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start an aggregated worker
logger.info("Starting aggregated worker...")
worker = DynamoWorkerProcess(request, mode="prefill_and_decode")
with worker:
logger.info(f"Aggregated Worker PID: {worker.get_pid()}")
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2)
# Step 3: Test request cancellation
frontend_log_offset, worker_log_offset = 0, 0
test_scenarios = [
("completion", "Completion request cancellation"),
("chat_completion", "Chat completion request cancellation"),
(
"chat_completion_stream",
"Chat completion stream request cancellation",
),
]
for i, (request_type, description) in enumerate(test_scenarios, 1):
logger.info(f"Testing {description.lower()}...")
send_request_and_cancel(request_type)
logger.info(
"Checking for cancellation messages in worker and frontend logs..."
)
time.sleep(0.05) # time for cancellation to propagate
frontend_log_offset, worker_log_offset, _ = verify_request_cancelled(
frontend,
worker,
frontend_log_offset=frontend_log_offset,
worker_log_offset=worker_log_offset,
)
logger.info(f"{description} detected successfully")
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_decode_first_decode_cancel(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation during decode phase with decode_first strategy.
This test verifies that when a request is cancelled by the client during the decode phase,
the system properly handles the cancellation and cleans up resources
on the decode worker side in a disaggregated setup using decode_first strategy.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(
request, mode="prefill", strategy="decode_first"
)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(
request, mode="decode", strategy="decode_first"
)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2)
# Step 4: Test request cancellation for completion scenario only
logger.info(
"Testing completion request cancellation in decode worker (decode phase)..."
)
send_request_and_cancel("completion")
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(
frontend,
decode_worker,
prefill_worker,
assert_request_reach_remote_worker=True,
assert_cancel_at_remote_worker=False,
)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_decode_first_remote_prefill_cancel(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation during remote prefill phase with decode_first strategy.
This test verifies that when a request is cancelled by the client during the remote prefill phase,
the system properly handles the cancellation and cleans up resources
on both the decode and prefill workers in a disaggregated setup using decode_first strategy.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(
request, mode="prefill", strategy="decode_first"
)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# Step 3: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(
request, mode="decode", strategy="decode_first"
)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2)
# Step 4: Test request cancellation during remote prefill phase
logger.info(
"Testing completion request cancellation during remote prefill phase..."
)
send_request_and_cancel("completion", timeout=0.1, use_long_prompt=True)
logger.info(
"Checking for cancellation messages in decode and prefill worker and frontend logs..."
)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(
frontend,
decode_worker,
prefill_worker,
assert_request_reach_remote_worker=True,
assert_cancel_at_remote_worker=True,
)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_prefill_first_prefill_cancel(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation during prefill phase with prefill_first strategy.
This test verifies that when a request is cancelled by the client during the prefill phase,
the system properly handles the cancellation and cleans up resources
on the prefill worker side in a disaggregated setup using prefill_first strategy.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(
request, mode="decode", strategy="prefill_first"
)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 3: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(
request, mode="prefill", strategy="prefill_first"
)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2)
# Step 4: Test request cancellation during prefill phase
logger.info(
"Testing completion request cancellation during prefill phase..."
)
send_request_and_cancel("completion", timeout=0.1, use_long_prompt=True)
logger.info(
"Checking for cancellation messages in prefill and decode worker and frontend logs..."
)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(
frontend,
prefill_worker,
decode_worker,
assert_request_reach_remote_worker=False,
assert_cancel_at_remote_worker=False,
)
@pytest.mark.trtllm_marker
@pytest.mark.gpu_1
@pytest.mark.e2e
@pytest.mark.model(FAULT_TOLERANCE_MODEL_NAME)
def test_request_cancellation_trtllm_prefill_first_remote_decode_cancel(
request, runtime_services, predownload_models
):
"""
End-to-end test for request cancellation during remote decode phase with prefill_first strategy.
This test verifies that when a request is cancelled by the client during the remote decode phase,
the system properly handles the cancellation and cleans up resources
on both the prefill and decode workers in a disaggregated setup using prefill_first strategy.
"""
# Step 1: Start the frontend
with DynamoFrontendProcess(request) as frontend:
logger.info("Frontend started successfully")
# Step 2: Start the decode worker
logger.info("Starting decode worker...")
decode_worker = DynamoWorkerProcess(
request, mode="decode", strategy="prefill_first"
)
with decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 3: Start the prefill worker
logger.info("Starting prefill worker...")
prefill_worker = DynamoWorkerProcess(
request, mode="prefill", strategy="prefill_first"
)
with prefill_worker:
logger.info(f"Prefill Worker PID: {prefill_worker.get_pid()}")
# TODO: Why wait after worker ready fixes frontend 404 / 500 flakiness?
time.sleep(2)
# Step 4: Test request cancellation during remote decode phase
logger.info(
"Testing completion request cancellation during remote decode phase..."
)
send_request_and_cancel("completion")
logger.info(
"Checking for cancellation messages in prefill and decode worker and frontend logs..."
)
time.sleep(0.05) # time for cancellation to propagate
verify_request_cancelled(
frontend,
prefill_worker,
decode_worker,
assert_request_reach_remote_worker=True,
assert_cancel_at_remote_worker=True,
)
...@@ -3,13 +3,17 @@ ...@@ -3,13 +3,17 @@
import logging import logging
import os import os
import re
import shutil import shutil
import time import time
import pytest import pytest
import requests
from tests.fault_tolerance.cancellation.utils import (
DynamoFrontendProcess,
read_log_content,
send_request_and_cancel,
strip_ansi_codes,
)
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess from tests.utils.managed_process import ManagedProcess
...@@ -18,35 +22,6 @@ from tests.utils.payloads import check_health_generate, check_models_api ...@@ -18,35 +22,6 @@ from tests.utils.payloads import check_health_generate, check_models_api
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class DynamoFrontendProcess(ManagedProcess):
"""Process manager for Dynamo frontend"""
def __init__(self, request):
command = ["python", "-m", "dynamo.frontend"]
# Set debug logging environment
env = os.environ.copy()
env["DYN_LOG"] = "debug"
log_dir = f"{request.node.name}_frontend"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
env=env,
display_output=True,
terminate_existing=True,
log_dir=log_dir,
)
class DynamoWorkerProcess(ManagedProcess): class DynamoWorkerProcess(ManagedProcess):
"""Process manager for Dynamo worker with vLLM backend""" """Process manager for Dynamo worker with vLLM backend"""
...@@ -137,138 +112,6 @@ class DynamoWorkerProcess(ManagedProcess): ...@@ -137,138 +112,6 @@ class DynamoWorkerProcess(ManagedProcess):
return False return False
def send_completion_request(
prompt: str, max_tokens: int, timeout: int | float = 120
) -> requests.Response:
"""Send a completion request to the frontend"""
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": max_tokens,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
session = requests.Session()
try:
response = session.post(
"http://localhost:8000/v1/completions",
headers=headers,
json=payload,
timeout=timeout,
)
logger.info(f"Received response with status code: {response.status_code}")
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
def send_chat_completion_request(
prompt: str, max_tokens: int, timeout: int | float = 120, stream: bool = False
) -> requests.Response:
"""Send a chat completion request to the frontend"""
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
"stream": stream,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending chat completion request (stream={stream}) with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
session = requests.Session()
try:
response = session.post(
"http://localhost:8000/v1/chat/completions",
headers=headers,
json=payload,
timeout=timeout,
stream=stream,
)
logger.info(f"Received response with status code: {response.status_code}")
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
def send_request_and_cancel(
request_type: str = "completion",
timeout: int | float = 1,
use_long_prompt: bool = False,
):
"""Send a request with short timeout to trigger cancellation"""
logger.info(f"Sending {request_type} request to be cancelled...")
prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
try:
if request_type == "completion":
response = send_completion_request(prompt, 8000, timeout)
elif request_type == "chat_completion":
response = send_chat_completion_request(prompt, 8000, timeout, False)
elif request_type == "chat_completion_stream":
response = send_chat_completion_request(prompt, 8000, timeout, True)
# Read a few responses and then disconnect
if response.status_code == 200:
itr_count, max_itr = 0, 5
try:
for res in response.iter_lines():
logger.info(f"Received response {itr_count + 1}: {res[:50]}...")
itr_count += 1
if itr_count >= max_itr:
break
time.sleep(0.1)
except Exception as e:
pytest.fail(f"Stream reading failed: {e}")
response.close()
raise Exception("Closed response")
else:
pytest.fail(f"Unknown request type: {request_type}")
pytest.fail(
f"{request_type} request completed unexpectedly - should have been cancelled"
)
except Exception as e:
logger.info(f"{request_type} request was cancelled: {e}")
def read_log_content(log_path: str | None) -> str:
"""Read log content from a file"""
if log_path is None:
pytest.fail("Log path is None - cannot read log content")
try:
with open(log_path, "r") as f:
return f.read()
except Exception as e:
pytest.fail(f"Could not read log file {log_path}: {e}")
def strip_ansi_codes(text: str) -> str:
"""Remove ANSI color codes from text"""
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", text)
def verify_request_cancelled( def verify_request_cancelled(
frontend_process: DynamoFrontendProcess, frontend_process: DynamoFrontendProcess,
worker_process: DynamoWorkerProcess, worker_process: DynamoWorkerProcess,
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import os
import re
import shutil
import time
import pytest
import requests
from tests.utils.constants import FAULT_TOLERANCE_MODEL_NAME
from tests.utils.engine_process import FRONTEND_PORT
from tests.utils.managed_process import ManagedProcess
logger = logging.getLogger(__name__)
class DynamoFrontendProcess(ManagedProcess):
"""Process manager for Dynamo frontend"""
def __init__(self, request):
command = ["python", "-m", "dynamo.frontend"]
# Set debug logging environment
env = os.environ.copy()
env["DYN_LOG"] = "debug"
log_dir = f"{request.node.name}_frontend"
# Clean up any existing log directory from previous runs
try:
shutil.rmtree(log_dir)
logger.info(f"Cleaned up existing log directory: {log_dir}")
except FileNotFoundError:
# Directory doesn't exist, which is fine
pass
super().__init__(
command=command,
env=env,
display_output=True,
terminate_existing=True,
log_dir=log_dir,
)
def send_completion_request(
prompt: str, max_tokens: int, timeout: int | float = 120
) -> requests.Response:
"""Send a completion request to the frontend"""
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"prompt": prompt,
"max_tokens": max_tokens,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending completion request with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
session = requests.Session()
try:
response = session.post(
f"http://localhost:{FRONTEND_PORT}/v1/completions",
headers=headers,
json=payload,
timeout=timeout,
)
logger.info(f"Received response with status code: {response.status_code}")
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
def send_chat_completion_request(
prompt: str, max_tokens: int, timeout: int | float = 120, stream: bool = False
) -> requests.Response:
"""Send a chat completion request to the frontend"""
payload = {
"model": FAULT_TOLERANCE_MODEL_NAME,
"messages": [{"role": "user", "content": prompt}],
"max_tokens": max_tokens,
"stream": stream,
}
headers = {"Content-Type": "application/json"}
logger.info(
f"Sending chat completion request (stream={stream}) with prompt: '{prompt[:50]}...' and max_tokens: {max_tokens}"
)
session = requests.Session()
try:
response = session.post(
f"http://localhost:{FRONTEND_PORT}/v1/chat/completions",
headers=headers,
json=payload,
timeout=timeout,
stream=stream,
)
logger.info(f"Received response with status code: {response.status_code}")
return response
except requests.exceptions.Timeout:
logger.error(f"Request timed out after {timeout} seconds")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Request failed with error: {e}")
raise
def send_request_and_cancel(
request_type: str = "completion",
timeout: int | float = 1,
use_long_prompt: bool = False,
):
"""Send a request with short timeout to trigger cancellation"""
logger.info(f"Sending {request_type} request to be cancelled...")
prompt = "Tell me a very long and detailed story about the history of artificial intelligence, including all major milestones, researchers, and breakthroughs?"
if use_long_prompt:
prompt += " Make sure it is" + " long" * 8000 + "!"
try:
if request_type == "completion":
response = send_completion_request(prompt, 8000, timeout)
elif request_type == "chat_completion":
response = send_chat_completion_request(prompt, 8000, timeout, False)
elif request_type == "chat_completion_stream":
response = send_chat_completion_request(prompt, 8000, timeout, True)
# Read a few responses and then disconnect
if response.status_code == 200:
itr_count, max_itr = 0, 5
try:
for res in response.iter_lines():
logger.info(f"Received response {itr_count + 1}: {res[:50]}...")
itr_count += 1
if itr_count >= max_itr:
break
time.sleep(0.1)
except Exception as e:
pytest.fail(f"Stream reading failed: {e}")
response.close()
raise Exception("Closed response")
else:
pytest.fail(f"Unknown request type: {request_type}")
pytest.fail(
f"{request_type} request completed unexpectedly - should have been cancelled"
)
except Exception as e:
logger.info(f"{request_type} request was cancelled: {e}")
def read_log_content(log_path: str | None) -> str:
"""Read log content from a file"""
if log_path is None:
pytest.fail("Log path is None - cannot read log content")
try:
with open(log_path, "r") as f:
return f.read()
except Exception as e:
pytest.fail(f"Could not read log file {log_path}: {e}")
def strip_ansi_codes(text: str) -> str:
"""Remove ANSI color codes from text"""
ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
return ansi_escape.sub("", text)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment