Unverified commit 87190db0, authored by Jacky, committed by GitHub
Browse files

refactor: Simplify TRT-LLM request abort calling logic (#3296)


Signed-off-by: Jacky <18255193+kthui@users.noreply.github.com>
parent 56d20f53
...@@ -20,9 +20,10 @@ import os ...@@ -20,9 +20,10 @@ import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from enum import Enum from enum import Enum
from typing import Any, AsyncGenerator, Optional, Union from typing import AsyncGenerator, Optional, Union
import torch import torch
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams from tensorrt_llm.llmapi.llm import SamplingParams
...@@ -103,34 +104,23 @@ class HandlerBase: ...@@ -103,34 +104,23 @@ class HandlerBase:
result["finish_reason"] == "stop" or result["finish_reason"] == "error" result["finish_reason"] == "stop" or result["finish_reason"] == "error"
) )
async def _handle_cancellation(self, generation_result: Any, context: Context): async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context
):
"""Background task to handle cancellation by monitoring context state.""" """Background task to handle cancellation by monitoring context state."""
try: try:
# Wait asynchronously for cancellation signal instead of polling # Wait asynchronously for cancellation signal instead of polling
await context.async_killed_or_stopped() await context.async_killed_or_stopped()
# Call abort_request on the executor through the LLM instance # Abort the generation
if hasattr(self.engine.llm, "_executor") and self.engine.llm._executor: generation_result.abort()
# Get the internal request ID from the generation result
internal_request_id = getattr(generation_result, "request_id", None)
if internal_request_id is not None:
# TODO: Can this be an official abort method in TRT-LLM?
self.engine.llm._executor.abort_request(internal_request_id)
logging.debug(f"Aborted Request ID: {context.id()}") logging.debug(f"Aborted Request ID: {context.id()}")
else:
logging.error(
f"Could not retrieve internal request ID for abort: {context.id()}"
)
else:
logging.error(
f"TensorRT-LLM executor not found for abort request: {context.id()}"
)
except asyncio.CancelledError: except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes # Task was cancelled, which is expected when generation completes
pass pass
@asynccontextmanager @asynccontextmanager
async def _cancellation_monitor( async def _cancellation_monitor(
self, generation_result: Any, context: Context self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]: ) -> AsyncGenerator[asyncio.Task, None]:
""" """
Context manager for monitoring request cancellation. Context manager for monitoring request cancellation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment