"vllm/vscode:/vscode.git/clone" did not exist on "285178b3b824d70b46b351daa8f8942d23da264a"
Unverified Commit 87190db0 authored by Jacky's avatar Jacky Committed by GitHub
Browse files

refactor: Simplify TRT-LLM request abort calling logic (#3296)


Signed-off-by: default avatarJacky <18255193+kthui@users.noreply.github.com>
parent 56d20f53
......@@ -20,9 +20,10 @@ import os
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any, AsyncGenerator, Optional, Union
from typing import AsyncGenerator, Optional, Union
import torch
from tensorrt_llm.executor.result import GenerationResult
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
from tensorrt_llm.llmapi.llm import SamplingParams
......@@ -103,34 +104,23 @@ class HandlerBase:
result["finish_reason"] == "stop" or result["finish_reason"] == "error"
)
async def _handle_cancellation(self, generation_result: Any, context: Context):
async def _handle_cancellation(
self, generation_result: GenerationResult, context: Context
):
"""Background task to handle cancellation by monitoring context state."""
try:
# Wait asynchronously for cancellation signal instead of polling
await context.async_killed_or_stopped()
# Call abort_request on the executor through the LLM instance
if hasattr(self.engine.llm, "_executor") and self.engine.llm._executor:
# Get the internal request ID from the generation result
internal_request_id = getattr(generation_result, "request_id", None)
if internal_request_id is not None:
# TODO: Can this be an official abort method in TRT-LLM?
self.engine.llm._executor.abort_request(internal_request_id)
# Abort the generation
generation_result.abort()
logging.debug(f"Aborted Request ID: {context.id()}")
else:
logging.error(
f"Could not retrieve internal request ID for abort: {context.id()}"
)
else:
logging.error(
f"TensorRT-LLM executor not found for abort request: {context.id()}"
)
except asyncio.CancelledError:
# Task was cancelled, which is expected when generation completes
pass
@asynccontextmanager
async def _cancellation_monitor(
self, generation_result: Any, context: Context
self, generation_result: GenerationResult, context: Context
) -> AsyncGenerator[asyncio.Task, None]:
"""
Context manager for monitoring request cancellation.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment