Unverified Commit f99b78f0 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy typing for trtllm (#6860)

parent 8cef50c6
......@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class Config(DynamoRuntimeConfig, DynamoTrtllmConfig):
component: str
use_kv_events: bool
connector: list[str] # Redeclare for mypy (inherited from DynamoRuntimeConfig)
def validate(self) -> None:
DynamoRuntimeConfig.validate(self)
......
......@@ -3,6 +3,7 @@
"""Dynamo TRT-LLM backend configuration ArgGroup."""
import argparse
from typing import Optional
from tensorrt_llm.llmapi import BuildConfig
......@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class DynamoTrtllmArgGroup(ArgGroup):
"""TensorRT-LLM-specific Dynamo wrapper configuration."""
def add_arguments(self, parser) -> None:
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--version",
action="version",
......
......@@ -4,6 +4,7 @@
import asyncio
import logging
import threading
from collections.abc import AsyncGenerator
from dataclasses import asdict
from typing import Any, Dict, Optional, Union
......@@ -377,13 +378,13 @@ class EncodeHelper:
@staticmethod
async def process_encode_request(
request: Dict[str, Any],
multimodal_processor,
multimodal_processor: Any,
connector: Optional[nixl_connect.Connector],
tokenizer=None,
model_dir=None,
model_type=None,
engine=None,
):
tokenizer: Any = None,
model_dir: Optional[str] = None,
model_type: Optional[str] = None,
engine: Any = None,
) -> AsyncGenerator[dict, None]:
"""
Process an ENCODE-mode request. Dispatches to the appropriate flow.
......@@ -447,7 +448,7 @@ class EncodeHelper:
# if the model's tokenizer_config chat template emits them).
token_ids = request.get("token_ids")
async for response in EncodeHelper._process_full_epd_flow(
token_ids,
token_ids, # type: ignore
image_urls,
tokenizer,
model_dir,
......
......@@ -4,8 +4,9 @@
import enum
import logging
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from typing import Any, Optional
from tensorrt_llm import LLM, MultimodalEncoder
from tensorrt_llm.llmapi.llm import BaseLLM
......@@ -31,9 +32,9 @@ class Backend(str, enum.Enum):
class TensorRTLLMEngine:
def __init__(
self,
engine_args,
engine_args: dict[str, Any],
disaggregation_mode: Optional[DisaggregationMode] = None,
):
) -> None:
self._llm: Optional[LLM] = None
self.disaggregation_mode = (
disaggregation_mode
......@@ -63,7 +64,7 @@ class TensorRTLLMEngine:
"""Whether the multimodal encoder LLM is initialized."""
return self._llm is not None
async def initialize(self):
async def initialize(self) -> None:
if not self._llm:
if self.disaggregation_mode == DisaggregationMode.ENCODE:
# Initialize the multimodal encoder for full EPD
......@@ -75,7 +76,7 @@ class TensorRTLLMEngine:
# Skip MultimodalEncoder for architectures that handle vision
# encoding inside the main model (e.g. Llama4).
if self._is_unsupported_encoder_arch(model):
if self._is_unsupported_encoder_arch(model): # type: ignore
return
max_batch_size = self.engine_args.get("max_batch_size", 1)
......@@ -93,7 +94,7 @@ class TensorRTLLMEngine:
# (model path, backend settings, KV cache config, disaggregation settings, etc.)
self._llm = self._llm_cls(**self.engine_args)
async def cleanup(self):
async def cleanup(self) -> None:
if self._llm:
try:
self._llm.shutdown()
......@@ -166,9 +167,9 @@ class TensorRTLLMEngine:
@asynccontextmanager
async def get_llm_engine(
engine_args,
engine_args: dict[str, Any],
disaggregation_mode: Optional[DisaggregationMode] = None,
component_gauges=None,
component_gauges: Any = None,
) -> AsyncGenerator[TensorRTLLMEngine, None]:
"""Get TensorRT-LLM engine instance with load time tracking.
......
......@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends.
"""
import logging
from typing import Any
from dynamo.health_check import HealthCheckPayload
......@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class.
"""
def __init__(self, tokenizer=None):
def __init__(self, tokenizer: Any = None) -> None:
"""
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
......
......@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
req_ids: int,
logits: torch.Tensor,
ids: List[List[int]],
stream_ptr,
stream_ptr: Optional[int],
client_id: Optional[int] = None,
):
) -> None:
"""
TensorRT-LLM logits processor interface.
......
......@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol):
the tokenizer's decode method not being found on a generic 'object' type.
"""
def decode(self, token_ids: List[int]) -> str:
def decode(
self,
token_ids: List[int],
skip_special_tokens: bool = True,
clean_up_tokenization_spaces: bool = True,
) -> str:
...
......
......@@ -26,9 +26,10 @@ import threading
import time
import traceback
import weakref
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from queue import Queue
from typing import Awaitable, Callable, Dict, Optional, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Union
import msgpack
import zmq
......@@ -87,7 +88,7 @@ class ZmqKvEventPublisher:
Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
"""
def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""):
def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = "") -> None:
"""
Initialize ZMQ publisher.
......@@ -120,7 +121,7 @@ class ZmqKvEventPublisher:
block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0,
lora_name: Optional[str] = None,
):
) -> None:
"""Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter.
......@@ -133,7 +134,7 @@ class ZmqKvEventPublisher:
# Create event in the same format as vLLM's ZmqEventPublisher:
# All blocks should have the same size (kv_block_size)
event = {
event: dict[str, Any] = {
"type": "BlockStored",
"block_hashes": block_hashes_signed,
"parent_block_hash": parent_hash_signed,
......@@ -149,7 +150,9 @@ class ZmqKvEventPublisher:
self._publish_event(event, attention_dp_rank)
def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0):
def publish_removed(
self, block_hashes: list[int], attention_dp_rank: int = 0
) -> None:
"""Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter.
......@@ -164,7 +167,7 @@ class ZmqKvEventPublisher:
self._publish_event(event, attention_dp_rank)
def publish_all_cleared(self):
def publish_all_cleared(self) -> None:
"""Publish an AllBlocksCleared event."""
event = {"type": "AllBlocksCleared"}
self._publish_event(event)
......@@ -197,7 +200,7 @@ class ZmqKvEventPublisher:
except Exception as e:
logging.error(f"Failed to publish ZMQ event: {e}", exc_info=True)
def shutdown(self):
def shutdown(self) -> None:
"""Shutdown the ZMQ publisher."""
if self.socket:
self.socket.close()
......@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread):
self._stop_event = threading.Event()
def set_loop(self, loop: asyncio.AbstractEventLoop):
def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop
def run(self):
def run(self) -> None:
while not self._stop_event.is_set():
task: Optional[
Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]
......@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread):
logging.info(f"Thread {self.name} stopped.")
def stop(self):
def stop(self) -> None:
self._stop_event.set()
if self._current_future and not self._current_future.done():
self._current_future.cancel()
......@@ -297,16 +300,16 @@ class Publisher:
def __init__(
self,
endpoint,
engine,
worker_id,
kv_block_size,
metrics_labels,
endpoint: Any,
engine: Any,
worker_id: Any,
kv_block_size: int,
metrics_labels: Any,
component_gauges: LLMBackendMetrics,
zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
metrics_collector=None,
):
metrics_collector: Any = None,
) -> None:
self.endpoint = endpoint
self.engine = engine
self.worker_id = worker_id
......@@ -324,7 +327,7 @@ class Publisher:
self.processing_initial_created_events = True
# Needed by the events and metrics publishers
self.metrics_publisher = None
self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
self.kv_event_publishers: Optional[
Dict[int, KvEventPublisher]
] = None # One per attention_dp_rank
......@@ -359,7 +362,7 @@ class Publisher:
return
await self.metrics_publisher.create_endpoint(self.endpoint)
def initialize(self):
def initialize(self) -> None:
# Setup the metrics publisher
self.metrics_publisher = WorkerMetricsPublisher()
self._init_publish_metrics_thread()
......@@ -474,6 +477,7 @@ class Publisher:
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}")
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert self.metrics_publisher is not None
self.metrics_publisher.publish(None, kv_active_blocks)
# Publish Prometheus metrics
......@@ -680,7 +684,7 @@ class Publisher:
elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event)
def start(self):
def start(self) -> None:
if (
self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive()
......@@ -698,13 +702,13 @@ class Publisher:
self.publish_stats_thread.start()
logging.debug("Started stats thread")
def check_error_queue(self):
def check_error_queue(self) -> Optional[Exception]:
if not self.error_queue.empty():
logging.error("Error in publishers error queue")
return self.error_queue.get()
return None
async def cleanup(self):
async def cleanup(self) -> None:
"""Cleanup threads and resources"""
self._stop_event.set()
# Add timeout to prevent hanging
......@@ -729,7 +733,7 @@ class Publisher:
if self.zmq_kv_event_publisher:
self.zmq_kv_event_publisher.shutdown()
def update_max_window_size(self, event):
def update_max_window_size(self, event: dict) -> None:
if "window_size" in event:
window_size = event["window_size"]
if self.max_window_size is None or window_size > self.max_window_size:
......@@ -744,7 +748,7 @@ class Publisher:
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# so we can use the "created" event to identify the max_window_size of the global
# attention layer in the model engine.
def should_drop_event(self, event):
def should_drop_event(self, event: dict) -> bool:
# There are two cases for KV event filtering:
#
# 1. If "window_size" is NOT in the KV event:
......@@ -768,16 +772,16 @@ class Publisher:
@asynccontextmanager
async def get_publisher(
endpoint,
engine,
worker_id,
kv_block_size,
metrics_labels,
endpoint: Any,
engine: Any,
worker_id: Any,
kv_block_size: int,
metrics_labels: Any,
component_gauges: LLMBackendMetrics,
zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False,
metrics_collector=None,
):
metrics_collector: Any = None,
) -> AsyncGenerator[Publisher, None]:
publisher = Publisher(
endpoint,
engine,
......
......@@ -4,6 +4,7 @@
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import logging
from collections.abc import AsyncGenerator
from typing import Optional
from dynamo._core import Context
......@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase):
super().__init__(config)
self._encoder_cache = encoder_cache
async def generate(self, request: dict, context: Context):
async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
"""Generate response, optionally using remote encoder for multimodal."""
logging.debug(f"AggregatedHandler Request ID: {context.id()}")
......
......@@ -18,9 +18,10 @@ import dataclasses
import logging
import os
import re
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, Optional, Union
import torch
from tensorrt_llm.executor.result import GenerationResult
......@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler):
self.shutdown_event = config.shutdown_event
self.disable_request_abort = config.disable_request_abort
def check_error(self, result: dict):
def check_error(self, result: dict) -> bool:
"""
Check if there is an error in the result.
"""
......@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler):
Raise GeneratorExit if shutdown event is triggered.
"""
try:
cancellation_triggers = [
cancellation_triggers: list[asyncio.Future[Any]] = [
context.async_killed_or_stopped(), # Request cancellation
]
# Shutdown cancellation
......@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler):
Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
"""
disaggregated_params = None
epd_metadata = {}
epd_metadata: dict[str, Any] = {}
# PREFILL mode: setup context_only params
if self.disaggregation_mode == DisaggregationMode.PREFILL:
......@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler):
context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None,
):
) -> AsyncGenerator[dict, None]:
"""
Generate responses based on the disaggregation mode in the request.
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from collections.abc import AsyncGenerator
from typing import Optional
from dynamo._core import Context
......@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase):
self.model_type = self.multimodal_processor.model_type
self.tokenizer = self.multimodal_processor.tokenizer
async def generate(self, request: dict, context: Context):
async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
logging.debug(f"New Request ID: {context.id()}")
if self.multimodal_processor is None:
logging.error("encode handler: no multimodal_processor configured")
......@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector
)
async def generate(self, request: dict, context: Context):
async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
"""
Prefill worker: process prompt and return disaggregated_params.
Frontend routes to decode workers automatically.
......@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig):
super().__init__(config)
async def generate(self, request: dict, context: Context):
async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
"""
Decode worker: generate tokens using disaggregated_params from prefill.
If disaggregated_params is present, prefill was done. Otherwise generate normally.
......
......@@ -4,6 +4,7 @@
import asyncio
import re as re_mod
from dataclasses import dataclass
from typing import Any
from unittest import mock
from unittest.mock import MagicMock
......@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice:
def test_empty_choice_ignored(self):
"""Empty choice list should not produce a regex."""
sampling_params = MockSamplingParams()
request = {
request: dict[str, Any] = {
"sampling_options": {
"guided_decoding": {
"choice": [],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment