"lib/bindings/vscode:/vscode.git/clone" did not exist on "177d662f86099602465dfeede907f5709608fdc6"
Unverified Commit f99b78f0 authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy typing for trtllm (#6860)

parent 8cef50c6
...@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"} ...@@ -29,6 +29,7 @@ VALID_TRTLLM_CONNECTORS = {"none", "kvbm"}
class Config(DynamoRuntimeConfig, DynamoTrtllmConfig): class Config(DynamoRuntimeConfig, DynamoTrtllmConfig):
component: str component: str
use_kv_events: bool use_kv_events: bool
connector: list[str] # Redeclare for mypy (inherited from DynamoRuntimeConfig)
def validate(self) -> None: def validate(self) -> None:
DynamoRuntimeConfig.validate(self) DynamoRuntimeConfig.validate(self)
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Dynamo TRT-LLM backend configuration ArgGroup.""" """Dynamo TRT-LLM backend configuration ArgGroup."""
import argparse
from typing import Optional from typing import Optional
from tensorrt_llm.llmapi import BuildConfig from tensorrt_llm.llmapi import BuildConfig
...@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ...@@ -20,7 +21,7 @@ DEFAULT_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
class DynamoTrtllmArgGroup(ArgGroup): class DynamoTrtllmArgGroup(ArgGroup):
"""TensorRT-LLM-specific Dynamo wrapper configuration.""" """TensorRT-LLM-specific Dynamo wrapper configuration."""
def add_arguments(self, parser) -> None: def add_arguments(self, parser: argparse.ArgumentParser) -> None:
parser.add_argument( parser.add_argument(
"--version", "--version",
action="version", action="version",
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import asyncio import asyncio
import logging import logging
import threading import threading
from collections.abc import AsyncGenerator
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
...@@ -377,13 +378,13 @@ class EncodeHelper: ...@@ -377,13 +378,13 @@ class EncodeHelper:
@staticmethod @staticmethod
async def process_encode_request( async def process_encode_request(
request: Dict[str, Any], request: Dict[str, Any],
multimodal_processor, multimodal_processor: Any,
connector: Optional[nixl_connect.Connector], connector: Optional[nixl_connect.Connector],
tokenizer=None, tokenizer: Any = None,
model_dir=None, model_dir: Optional[str] = None,
model_type=None, model_type: Optional[str] = None,
engine=None, engine: Any = None,
): ) -> AsyncGenerator[dict, None]:
""" """
Process an ENCODE-mode request. Dispatches to the appropriate flow. Process an ENCODE-mode request. Dispatches to the appropriate flow.
...@@ -447,7 +448,7 @@ class EncodeHelper: ...@@ -447,7 +448,7 @@ class EncodeHelper:
# if the model's tokenizer_config chat template emits them). # if the model's tokenizer_config chat template emits them).
token_ids = request.get("token_ids") token_ids = request.get("token_ids")
async for response in EncodeHelper._process_full_epd_flow( async for response in EncodeHelper._process_full_epd_flow(
token_ids, token_ids, # type: ignore
image_urls, image_urls,
tokenizer, tokenizer,
model_dir, model_dir,
......
...@@ -4,8 +4,9 @@ ...@@ -4,8 +4,9 @@
import enum import enum
import logging import logging
import time import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional from typing import Any, Optional
from tensorrt_llm import LLM, MultimodalEncoder from tensorrt_llm import LLM, MultimodalEncoder
from tensorrt_llm.llmapi.llm import BaseLLM from tensorrt_llm.llmapi.llm import BaseLLM
...@@ -31,9 +32,9 @@ class Backend(str, enum.Enum): ...@@ -31,9 +32,9 @@ class Backend(str, enum.Enum):
class TensorRTLLMEngine: class TensorRTLLMEngine:
def __init__( def __init__(
self, self,
engine_args, engine_args: dict[str, Any],
disaggregation_mode: Optional[DisaggregationMode] = None, disaggregation_mode: Optional[DisaggregationMode] = None,
): ) -> None:
self._llm: Optional[LLM] = None self._llm: Optional[LLM] = None
self.disaggregation_mode = ( self.disaggregation_mode = (
disaggregation_mode disaggregation_mode
...@@ -63,7 +64,7 @@ class TensorRTLLMEngine: ...@@ -63,7 +64,7 @@ class TensorRTLLMEngine:
"""Whether the multimodal encoder LLM is initialized.""" """Whether the multimodal encoder LLM is initialized."""
return self._llm is not None return self._llm is not None
async def initialize(self): async def initialize(self) -> None:
if not self._llm: if not self._llm:
if self.disaggregation_mode == DisaggregationMode.ENCODE: if self.disaggregation_mode == DisaggregationMode.ENCODE:
# Initialize the multimodal encoder for full EPD # Initialize the multimodal encoder for full EPD
...@@ -75,7 +76,7 @@ class TensorRTLLMEngine: ...@@ -75,7 +76,7 @@ class TensorRTLLMEngine:
# Skip MultimodalEncoder for architectures that handle vision # Skip MultimodalEncoder for architectures that handle vision
# encoding inside the main model (e.g. Llama4). # encoding inside the main model (e.g. Llama4).
if self._is_unsupported_encoder_arch(model): if self._is_unsupported_encoder_arch(model): # type: ignore
return return
max_batch_size = self.engine_args.get("max_batch_size", 1) max_batch_size = self.engine_args.get("max_batch_size", 1)
...@@ -93,7 +94,7 @@ class TensorRTLLMEngine: ...@@ -93,7 +94,7 @@ class TensorRTLLMEngine:
# (model path, backend settings, KV cache config, disaggregation settings, etc.) # (model path, backend settings, KV cache config, disaggregation settings, etc.)
self._llm = self._llm_cls(**self.engine_args) self._llm = self._llm_cls(**self.engine_args)
async def cleanup(self): async def cleanup(self) -> None:
if self._llm: if self._llm:
try: try:
self._llm.shutdown() self._llm.shutdown()
...@@ -166,9 +167,9 @@ class TensorRTLLMEngine: ...@@ -166,9 +167,9 @@ class TensorRTLLMEngine:
@asynccontextmanager @asynccontextmanager
async def get_llm_engine( async def get_llm_engine(
engine_args, engine_args: dict[str, Any],
disaggregation_mode: Optional[DisaggregationMode] = None, disaggregation_mode: Optional[DisaggregationMode] = None,
component_gauges=None, component_gauges: Any = None,
) -> AsyncGenerator[TensorRTLLMEngine, None]: ) -> AsyncGenerator[TensorRTLLMEngine, None]:
"""Get TensorRT-LLM engine instance with load time tracking. """Get TensorRT-LLM engine instance with load time tracking.
......
...@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends. ...@@ -8,6 +8,7 @@ This module defines the default health check payload for TRT-LLM backends.
""" """
import logging import logging
from typing import Any
from dynamo.health_check import HealthCheckPayload from dynamo.health_check import HealthCheckPayload
...@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload): ...@@ -55,7 +56,7 @@ class TrtllmHealthCheckPayload(HealthCheckPayload):
Provides TRT-LLM defaults and inherits environment override support from base class. Provides TRT-LLM defaults and inherits environment override support from base class.
""" """
def __init__(self, tokenizer=None): def __init__(self, tokenizer: Any = None) -> None:
""" """
Initialize TRT-LLM health check payload with TRT-LLM-specific defaults. Initialize TRT-LLM health check payload with TRT-LLM-specific defaults.
......
...@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor): ...@@ -31,9 +31,9 @@ class TrtllmDynamoLogitsAdapter(LogitsProcessor):
req_ids: int, req_ids: int,
logits: torch.Tensor, logits: torch.Tensor,
ids: List[List[int]], ids: List[List[int]],
stream_ptr, stream_ptr: Optional[int],
client_id: Optional[int] = None, client_id: Optional[int] = None,
): ) -> None:
""" """
TensorRT-LLM logits processor interface. TensorRT-LLM logits processor interface.
......
...@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol): ...@@ -40,7 +40,12 @@ class TokenizerProtocol(Protocol):
the tokenizer's decode method not being found on a generic 'object' type. the tokenizer's decode method not being found on a generic 'object' type.
""" """
def decode(self, token_ids: List[int]) -> str: def decode(
self,
token_ids: List[int],
skip_special_tokens: bool = True,
clean_up_tokenization_spaces: bool = True,
) -> str:
... ...
......
...@@ -26,9 +26,10 @@ import threading ...@@ -26,9 +26,10 @@ import threading
import time import time
import traceback import traceback
import weakref import weakref
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from queue import Queue from queue import Queue
from typing import Awaitable, Callable, Dict, Optional, Union from typing import Any, Awaitable, Callable, Dict, Optional, Union
import msgpack import msgpack
import zmq import zmq
...@@ -87,7 +88,7 @@ class ZmqKvEventPublisher: ...@@ -87,7 +88,7 @@ class ZmqKvEventPublisher:
Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume. Publishes events from TensorRT-LLM engine to ZMQ for consolidator to consume.
""" """
def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""): def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = "") -> None:
""" """
Initialize ZMQ publisher. Initialize ZMQ publisher.
...@@ -120,7 +121,7 @@ class ZmqKvEventPublisher: ...@@ -120,7 +121,7 @@ class ZmqKvEventPublisher:
block_mm_infos: Optional[list[dict | None]] = None, block_mm_infos: Optional[list[dict | None]] = None,
attention_dp_rank: int = 0, attention_dp_rank: int = 0,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
): ) -> None:
"""Publish a BlockStored event. """Publish a BlockStored event.
Note: event_id is managed internally via self.sequence counter. Note: event_id is managed internally via self.sequence counter.
...@@ -133,7 +134,7 @@ class ZmqKvEventPublisher: ...@@ -133,7 +134,7 @@ class ZmqKvEventPublisher:
# Create event in the same format as vLLM's ZmqEventPublisher: # Create event in the same format as vLLM's ZmqEventPublisher:
# All blocks should have the same size (kv_block_size) # All blocks should have the same size (kv_block_size)
event = { event: dict[str, Any] = {
"type": "BlockStored", "type": "BlockStored",
"block_hashes": block_hashes_signed, "block_hashes": block_hashes_signed,
"parent_block_hash": parent_hash_signed, "parent_block_hash": parent_hash_signed,
...@@ -149,7 +150,9 @@ class ZmqKvEventPublisher: ...@@ -149,7 +150,9 @@ class ZmqKvEventPublisher:
self._publish_event(event, attention_dp_rank) self._publish_event(event, attention_dp_rank)
def publish_removed(self, block_hashes: list[int], attention_dp_rank: int = 0): def publish_removed(
self, block_hashes: list[int], attention_dp_rank: int = 0
) -> None:
"""Publish a BlockRemoved event. """Publish a BlockRemoved event.
Note: event_id is managed internally via self.sequence counter. Note: event_id is managed internally via self.sequence counter.
...@@ -164,7 +167,7 @@ class ZmqKvEventPublisher: ...@@ -164,7 +167,7 @@ class ZmqKvEventPublisher:
self._publish_event(event, attention_dp_rank) self._publish_event(event, attention_dp_rank)
def publish_all_cleared(self): def publish_all_cleared(self) -> None:
"""Publish an AllBlocksCleared event.""" """Publish an AllBlocksCleared event."""
event = {"type": "AllBlocksCleared"} event = {"type": "AllBlocksCleared"}
self._publish_event(event) self._publish_event(event)
...@@ -197,7 +200,7 @@ class ZmqKvEventPublisher: ...@@ -197,7 +200,7 @@ class ZmqKvEventPublisher:
except Exception as e: except Exception as e:
logging.error(f"Failed to publish ZMQ event: {e}", exc_info=True) logging.error(f"Failed to publish ZMQ event: {e}", exc_info=True)
def shutdown(self): def shutdown(self) -> None:
"""Shutdown the ZMQ publisher.""" """Shutdown the ZMQ publisher."""
if self.socket: if self.socket:
self.socket.close() self.socket.close()
...@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread): ...@@ -229,10 +232,10 @@ class ManagedThread(threading.Thread):
self._stop_event = threading.Event() self._stop_event = threading.Event()
def set_loop(self, loop: asyncio.AbstractEventLoop): def set_loop(self, loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop self.loop = loop
def run(self): def run(self) -> None:
while not self._stop_event.is_set(): while not self._stop_event.is_set():
task: Optional[ task: Optional[
Union[Callable[..., Awaitable[bool]], weakref.WeakMethod] Union[Callable[..., Awaitable[bool]], weakref.WeakMethod]
...@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread): ...@@ -272,7 +275,7 @@ class ManagedThread(threading.Thread):
logging.info(f"Thread {self.name} stopped.") logging.info(f"Thread {self.name} stopped.")
def stop(self): def stop(self) -> None:
self._stop_event.set() self._stop_event.set()
if self._current_future and not self._current_future.done(): if self._current_future and not self._current_future.done():
self._current_future.cancel() self._current_future.cancel()
...@@ -297,16 +300,16 @@ class Publisher: ...@@ -297,16 +300,16 @@ class Publisher:
def __init__( def __init__(
self, self,
endpoint, endpoint: Any,
engine, engine: Any,
worker_id, worker_id: Any,
kv_block_size, kv_block_size: int,
metrics_labels, metrics_labels: Any,
component_gauges: LLMBackendMetrics, component_gauges: LLMBackendMetrics,
zmq_endpoint: Optional[str] = None, zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False, enable_local_indexer: bool = False,
metrics_collector=None, metrics_collector: Any = None,
): ) -> None:
self.endpoint = endpoint self.endpoint = endpoint
self.engine = engine self.engine = engine
self.worker_id = worker_id self.worker_id = worker_id
...@@ -324,7 +327,7 @@ class Publisher: ...@@ -324,7 +327,7 @@ class Publisher:
self.processing_initial_created_events = True self.processing_initial_created_events = True
# Needed by the events and metrics publishers # Needed by the events and metrics publishers
self.metrics_publisher = None self.metrics_publisher: Optional[WorkerMetricsPublisher] = None
self.kv_event_publishers: Optional[ self.kv_event_publishers: Optional[
Dict[int, KvEventPublisher] Dict[int, KvEventPublisher]
] = None # One per attention_dp_rank ] = None # One per attention_dp_rank
...@@ -359,7 +362,7 @@ class Publisher: ...@@ -359,7 +362,7 @@ class Publisher:
return return
await self.metrics_publisher.create_endpoint(self.endpoint) await self.metrics_publisher.create_endpoint(self.endpoint)
def initialize(self): def initialize(self) -> None:
# Setup the metrics publisher # Setup the metrics publisher
self.metrics_publisher = WorkerMetricsPublisher() self.metrics_publisher = WorkerMetricsPublisher()
self._init_publish_metrics_thread() self._init_publish_metrics_thread()
...@@ -474,6 +477,7 @@ class Publisher: ...@@ -474,6 +477,7 @@ class Publisher:
kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"] kv_total_blocks = stat["kvCacheStats"]["maxNumBlocks"]
logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}") logging.debug(f"Publishing stats: kv_active_blocks: {kv_active_blocks}")
# TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus) # TRT-LLM doesn't use data parallelism currently (dp_rank=None for NATS, "0" for Prometheus)
assert self.metrics_publisher is not None
self.metrics_publisher.publish(None, kv_active_blocks) self.metrics_publisher.publish(None, kv_active_blocks)
# Publish Prometheus metrics # Publish Prometheus metrics
...@@ -680,7 +684,7 @@ class Publisher: ...@@ -680,7 +684,7 @@ class Publisher:
elif data["type"] == "created" and self.processing_initial_created_events: elif data["type"] == "created" and self.processing_initial_created_events:
self.update_max_window_size(event) self.update_max_window_size(event)
def start(self): def start(self) -> None:
if ( if (
self.publish_kv_cache_events_thread self.publish_kv_cache_events_thread
and not self.publish_kv_cache_events_thread.is_alive() and not self.publish_kv_cache_events_thread.is_alive()
...@@ -698,13 +702,13 @@ class Publisher: ...@@ -698,13 +702,13 @@ class Publisher:
self.publish_stats_thread.start() self.publish_stats_thread.start()
logging.debug("Started stats thread") logging.debug("Started stats thread")
def check_error_queue(self): def check_error_queue(self) -> Optional[Exception]:
if not self.error_queue.empty(): if not self.error_queue.empty():
logging.error("Error in publishers error queue") logging.error("Error in publishers error queue")
return self.error_queue.get() return self.error_queue.get()
return None return None
async def cleanup(self): async def cleanup(self) -> None:
"""Cleanup threads and resources""" """Cleanup threads and resources"""
self._stop_event.set() self._stop_event.set()
# Add timeout to prevent hanging # Add timeout to prevent hanging
...@@ -729,7 +733,7 @@ class Publisher: ...@@ -729,7 +733,7 @@ class Publisher:
if self.zmq_kv_event_publisher: if self.zmq_kv_event_publisher:
self.zmq_kv_event_publisher.shutdown() self.zmq_kv_event_publisher.shutdown()
def update_max_window_size(self, event): def update_max_window_size(self, event: dict) -> None:
if "window_size" in event: if "window_size" in event:
window_size = event["window_size"] window_size = event["window_size"]
if self.max_window_size is None or window_size > self.max_window_size: if self.max_window_size is None or window_size > self.max_window_size:
...@@ -744,7 +748,7 @@ class Publisher: ...@@ -744,7 +748,7 @@ class Publisher:
# TRTLLM emits a "created" event at the very beginning when it creates the KV cache, # TRTLLM emits a "created" event at the very beginning when it creates the KV cache,
# so we can use the "created" event to identify the max_window_size of the global # so we can use the "created" event to identify the max_window_size of the global
# attention layer in the model engine. # attention layer in the model engine.
def should_drop_event(self, event): def should_drop_event(self, event: dict) -> bool:
# There are two cases for KV event filtering: # There are two cases for KV event filtering:
# #
# 1. If "window_size" is NOT in the KV event: # 1. If "window_size" is NOT in the KV event:
...@@ -768,16 +772,16 @@ class Publisher: ...@@ -768,16 +772,16 @@ class Publisher:
@asynccontextmanager @asynccontextmanager
async def get_publisher( async def get_publisher(
endpoint, endpoint: Any,
engine, engine: Any,
worker_id, worker_id: Any,
kv_block_size, kv_block_size: int,
metrics_labels, metrics_labels: Any,
component_gauges: LLMBackendMetrics, component_gauges: LLMBackendMetrics,
zmq_endpoint: Optional[str] = None, zmq_endpoint: Optional[str] = None,
enable_local_indexer: bool = False, enable_local_indexer: bool = False,
metrics_collector=None, metrics_collector: Any = None,
): ) -> AsyncGenerator[Publisher, None]:
publisher = Publisher( publisher = Publisher(
endpoint, endpoint,
engine, engine,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""Handler for aggregated (prefill + decode) mode with optional encoder disaggregation.""" """Handler for aggregated (prefill + decode) mode with optional encoder disaggregation."""
import logging import logging
from collections.abc import AsyncGenerator
from typing import Optional from typing import Optional
from dynamo._core import Context from dynamo._core import Context
...@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase): ...@@ -33,7 +34,9 @@ class AggregatedHandler(HandlerBase):
super().__init__(config) super().__init__(config)
self._encoder_cache = encoder_cache self._encoder_cache = encoder_cache
async def generate(self, request: dict, context: Context): async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
"""Generate response, optionally using remote encoder for multimodal.""" """Generate response, optionally using remote encoder for multimodal."""
logging.debug(f"AggregatedHandler Request ID: {context.id()}") logging.debug(f"AggregatedHandler Request ID: {context.id()}")
......
...@@ -18,9 +18,10 @@ import dataclasses ...@@ -18,9 +18,10 @@ import dataclasses
import logging import logging
import os import os
import re import re
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Any, AsyncGenerator, Optional, Union from typing import Any, Optional, Union
import torch import torch
from tensorrt_llm.executor.result import GenerationResult from tensorrt_llm.executor.result import GenerationResult
...@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -103,7 +104,7 @@ class HandlerBase(BaseGenerativeHandler):
self.shutdown_event = config.shutdown_event self.shutdown_event = config.shutdown_event
self.disable_request_abort = config.disable_request_abort self.disable_request_abort = config.disable_request_abort
def check_error(self, result: dict): def check_error(self, result: dict) -> bool:
""" """
Check if there is an error in the result. Check if there is an error in the result.
""" """
...@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -194,7 +195,7 @@ class HandlerBase(BaseGenerativeHandler):
Raise GeneratorExit if shutdown event is triggered. Raise GeneratorExit if shutdown event is triggered.
""" """
try: try:
cancellation_triggers = [ cancellation_triggers: list[asyncio.Future[Any]] = [
context.async_killed_or_stopped(), # Request cancellation context.async_killed_or_stopped(), # Request cancellation
] ]
# Shutdown cancellation # Shutdown cancellation
...@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -437,7 +438,7 @@ class HandlerBase(BaseGenerativeHandler):
Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata) Tuple of (disaggregated_params, ep_disaggregated_params, epd_metadata)
""" """
disaggregated_params = None disaggregated_params = None
epd_metadata = {} epd_metadata: dict[str, Any] = {}
# PREFILL mode: setup context_only params # PREFILL mode: setup context_only params
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
...@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler): ...@@ -608,7 +609,7 @@ class HandlerBase(BaseGenerativeHandler):
context: Context, context: Context,
embeddings: Optional[Union[torch.Tensor, dict]] = None, embeddings: Optional[Union[torch.Tensor, dict]] = None,
ep_disaggregated_params: Optional[DisaggregatedParams] = None, ep_disaggregated_params: Optional[DisaggregatedParams] = None,
): ) -> AsyncGenerator[dict, None]:
""" """
Generate responses based on the disaggregation mode in the request. Generate responses based on the disaggregation mode in the request.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from collections.abc import AsyncGenerator
from typing import Optional from typing import Optional
from dynamo._core import Context from dynamo._core import Context
...@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase): ...@@ -65,7 +66,9 @@ class EncodeHandler(HandlerBase):
self.model_type = self.multimodal_processor.model_type self.model_type = self.multimodal_processor.model_type
self.tokenizer = self.multimodal_processor.tokenizer self.tokenizer = self.multimodal_processor.tokenizer
async def generate(self, request: dict, context: Context): async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
logging.debug(f"New Request ID: {context.id()}") logging.debug(f"New Request ID: {context.id()}")
if self.multimodal_processor is None: if self.multimodal_processor is None:
logging.error("encode handler: no multimodal_processor configured") logging.error("encode handler: no multimodal_processor configured")
...@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase): ...@@ -121,7 +124,9 @@ class PrefillHandler(HandlerBase):
encode_response, self.connector encode_response, self.connector
) )
async def generate(self, request: dict, context: Context): async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
""" """
Prefill worker: process prompt and return disaggregated_params. Prefill worker: process prompt and return disaggregated_params.
Frontend routes to decode workers automatically. Frontend routes to decode workers automatically.
...@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase): ...@@ -195,7 +200,9 @@ class DecodeHandler(HandlerBase):
def __init__(self, config: RequestHandlerConfig): def __init__(self, config: RequestHandlerConfig):
super().__init__(config) super().__init__(config)
async def generate(self, request: dict, context: Context): async def generate(
self, request: dict, context: Context
) -> AsyncGenerator[dict, None]:
""" """
Decode worker: generate tokens using disaggregated_params from prefill. Decode worker: generate tokens using disaggregated_params from prefill.
If disaggregated_params is present, prefill was done. Otherwise generate normally. If disaggregated_params is present, prefill was done. Otherwise generate normally.
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import asyncio import asyncio
import re as re_mod import re as re_mod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
from unittest import mock from unittest import mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
...@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice: ...@@ -284,7 +285,7 @@ class TestGuidedDecodingFromToolChoice:
def test_empty_choice_ignored(self): def test_empty_choice_ignored(self):
"""Empty choice list should not produce a regex.""" """Empty choice list should not produce a regex."""
sampling_params = MockSamplingParams() sampling_params = MockSamplingParams()
request = { request: dict[str, Any] = {
"sampling_options": { "sampling_options": {
"guided_decoding": { "guided_decoding": {
"choice": [], "choice": [],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment