Unverified commit 2831bfec, authored by jh-nv, committed by GitHub
Browse files

chore: add mypy typing to vllm (#6858)

parent 0f01e724
......@@ -29,7 +29,7 @@ class VllmEngineMonitor:
self,
runtime: DistributedRuntime,
engine_client: AsyncLLM,
shutdown_event: asyncio.Event = None,
shutdown_event: asyncio.Event | None = None,
):
if not isinstance(runtime, DistributedRuntime):
raise ValueError(
......
......@@ -261,7 +261,7 @@ def build_sampling_params_openai(
return sampling_params
def get_dp_range_for_worker(vllm_config: VllmConfig) -> range:
def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
"""
Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
......@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC):
self.enable_multimodal = enable_multimodal
self.enable_frontend_decoding = enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized
self._nixl_connector = None
self._nixl_connector: nixl_connect.Connector | None = None
self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {}
......
......@@ -7,7 +7,7 @@ import logging
import os
import tempfile
import time
from typing import Optional
from typing import Any, Optional
import uvloop
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
......@@ -37,17 +37,6 @@ from dynamo.llm import (
fetch_model,
register_model,
)
# Optional imports for frontend decoding support
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
from dynamo.runtime import DistributedRuntime, Endpoint
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.worker_factory import WorkerFactory
......@@ -63,6 +52,18 @@ from .health_check import (
)
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
# Optional imports for frontend decoding support
MediaDecoder: type | None = None
MediaFetcher: type | None = None
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
configure_dynamo_logging()
logger = logging.getLogger(__name__)
shutdown_endpoints: list = []
......@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None:
run_headless(args)
async def worker():
async def worker() -> None:
config = parse_args()
dump_config(config.dump_config_to, config)
......@@ -198,7 +199,9 @@ async def worker():
logger.debug("Worker function completed, exiting...")
def setup_metrics_collection(config: Config, generate_endpoint, logger):
def setup_metrics_collection(
config: Config, generate_endpoint: Endpoint, logger: logging.Logger
) -> None:
"""Set up metrics collection for vLLM and LMCache metrics.
In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
......@@ -294,8 +297,9 @@ def setup_kv_event_publisher(
vllm_config: VllmConfig,
consolidator_enabled: bool = False,
consolidator_port: Optional[int] = 5558,
) -> Optional[KvEventPublisher]:
) -> Optional[list[KvEventPublisher]]:
"""
list[KvEventPublisher] | None
Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args:
......@@ -365,7 +369,9 @@ def setup_kv_event_publisher(
return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None):
def setup_vllm_engine(
config: Config, stat_logger: Optional[StatLoggerFactory] = None
) -> tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]:
# vLLM v0.11.0 bug: vllm/v1/metrics/prometheus.py:79 passes TemporaryDirectory object
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly.
......@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None):
async def register_vllm_model(
model_input: ModelInput,
model_type: ModelType,
generate_endpoint,
generate_endpoint: Endpoint,
config: Config,
engine_client: AsyncLLM,
vllm_config: VllmConfig,
):
) -> None:
"""
Helper function to register a vLLM model with runtime configuration.
......@@ -563,6 +569,7 @@ async def register_vllm_model(
"--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
)
assert MediaDecoder is not None and MediaFetcher is not None
media_decoder = MediaDecoder()
media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
# media_decoder.enable_video({})
......@@ -590,8 +597,10 @@ async def init_prefill(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
checkpoint_restore_engine=None,
):
checkpoint_restore_engine: Optional[
tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
"""
Instantiate and serve
"""
......@@ -690,7 +699,7 @@ async def init_prefill(
# (long-term reason): prefill engine should pull from a global queue so there are
# only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint(
handler.generate,
handler.generate, # type: ignore
graceful_shutdown=True,
# In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[
......@@ -706,10 +715,16 @@ async def init_prefill(
health_check_payload=health_check_payload,
),
clear_endpoint.serve_endpoint(
handler.clear_kv_blocks,
handler.clear_kv_blocks, # type: ignore
metrics_labels=[
(prometheus_names.labels.MODEL, config.served_model_name),
(prometheus_names.labels.MODEL_NAME, config.served_model_name),
(
prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
],
),
)
......@@ -726,8 +741,10 @@ async def init(
runtime: DistributedRuntime,
config: Config,
shutdown_event: asyncio.Event,
checkpoint_restore_engine=None,
):
checkpoint_restore_engine: Optional[
tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
"""
Instantiate and serve
"""
......@@ -886,7 +903,7 @@ async def init(
# for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting for them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint(
handler.generate,
handler.generate, # type: ignore
graceful_shutdown=True,
metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload,
......@@ -926,7 +943,7 @@ async def init(
handler.cleanup()
def get_engine_cache_info(engine: AsyncLLM):
def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]:
"""Retrieve cache configuration information from [`AsyncLLM`] engine."""
try:
......@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM):
async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
):
) -> None:
"""Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Supports text-to-text, text-to-image, and text-to-video generation
......@@ -1034,7 +1051,7 @@ async def init_omni(
handler.cleanup()
def main():
def main() -> None:
uvloop.run(worker())
......
......@@ -6,7 +6,7 @@ import logging
import os
import time
from dataclasses import dataclass
from typing import AsyncIterator
from typing import Any, AsyncIterator
import torch
from transformers import AutoImageProcessor
......@@ -80,7 +80,7 @@ class EncodeWorkerHandler:
self._connector: connect.Connector | None = None
self._accumulated_time = 0.0
self._processed_requests = 0
self.readables = []
self.readables: list[Any] = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_sender = LocalEmbeddingSender()
......@@ -93,7 +93,7 @@ class EncodeWorkerHandler:
f"Invalid embedding transfer mode: {embedding_transfer_mode}"
)
self.send_complete_queue = asyncio.Queue()
self.send_complete_queue: asyncio.Queue[tuple[Any, Any]] = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue)
)
......@@ -150,7 +150,9 @@ class EncodeWorkerHandler:
with _nvtx.annotate("mm:enc:cache_check", color="cyan"):
# Before batch process images, check cache first
need_encode_indexes = []
embedding_lists = [None] * len(request.multimodal_inputs)
embedding_lists: list[EmbeddingItem | None] = [None] * len(
request.multimodal_inputs
)
for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.")
......@@ -251,16 +253,16 @@ class EncodeWorkerHandler:
for split_idx, (list_idx, key) in enumerate(need_encode_indexes):
embedding_lists[list_idx] = EmbeddingItem(
key,
[image_grid_thw[split_idx]] if image_grid_thw else None,
[image_grid_thw[split_idx]] if image_grid_thw else [],
splitted_embeddings[split_idx].unsqueeze(0),
)
# Cache the computed value for future use
if self.embedding_cache is not None:
self.embedding_cache.set(
embedding_lists[list_idx].key,
embedding_lists[list_idx].key, # type: ignore
(
embedding_lists[list_idx].image_grid_thw,
embedding_lists[list_idx].embeddings,
embedding_lists[list_idx].image_grid_thw, # type: ignore
embedding_lists[list_idx].embeddings, # type: ignore
),
)
......@@ -275,6 +277,7 @@ class EncodeWorkerHandler:
)
)
for embedding_item in embedding_lists
if embedding_item is not None
]
transfer_requests = await asyncio.gather(*send_tasks)
......@@ -282,6 +285,7 @@ class EncodeWorkerHandler:
for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item
assert embedding_item is not None
logger.debug(
f"{embedding_item.embeddings.shape} prepared for transfer."
)
......
......@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
num_output_tokens_so_far = 0
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr]
) in await self.decode_worker_client.round_robin( # type: ignore
request.model_dump_json()
):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined]
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs:
num_output_tokens_so_far = len(output.outputs[0].token_ids)
......
......@@ -64,8 +64,8 @@ def get_qwen_image_features(
if grid_thw is None:
raise ValueError("grid_thw is not provided")
grid_thw = grid_thw.tolist()
image_embeds = vision_encoder(pixel_values, grid_thw=grid_thw)
return image_embeds
image_features = vision_encoder(pixel_values, grid_thw=grid_thw)
return image_features
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
......
......@@ -257,8 +257,10 @@ async def _fetch_embeddings(
)
# ── 3. Update cache (no-op when cache is None) ──────────────
for (idx, _url, key), group in zip(to_fetch, groups, strict=True):
if cache is not None and key is not None:
assert group.loaded_embedding is not None
cache.set(
key,
CachedEmbedding(
......@@ -301,6 +303,7 @@ async def load_multimodal_embeddings(
multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups:
assert group.loaded_embedding is not None
_accumulate_embeddings(
multi_modal_data,
model,
......
......@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id):
async for chunk in self._generate_openai_mode(request, context, request_id): # type: ignore
yield chunk
async def _generate_openai_mode(
......
......@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler):
output = NvImagesResponse(created=int(time.time()), data=image_data_list)
return output.model_dump()
else:
return None
async def _format_video_chunk(
self,
......
......@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
raise
# TODO: Remove this and pass as metadata through shared storage
def set_num_gpu_block(self, num_blocks):
def set_num_gpu_block(self, num_blocks: int) -> None:
self.num_gpu_block = num_blocks
def record(
......@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
*args,
**kwargs,
):
*args: object,
**kwargs: object,
) -> None:
active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage)
self.inner.publish(self.dp_rank, active_decode_blocks)
......@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
dp_rank_str, scheduler_stats.kv_cache_usage
)
def init_publish(self):
def init_publish(self) -> None:
self.inner.publish(self.dp_rank, 0)
dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0)
......@@ -112,10 +112,10 @@ class StatLoggerFactory:
return self.create_stat_logger(dp_rank=dp_rank)
# TODO Remove once we publish metadata to shared storage
def set_num_gpu_blocks_all(self, num_blocks):
def set_num_gpu_blocks_all(self, num_blocks: int) -> None:
if self.created_logger:
self.created_logger.set_num_gpu_block(num_blocks)
def init_publish(self):
def init_publish(self) -> None:
if self.created_logger:
self.created_logger.init_publish()
......@@ -180,7 +180,7 @@ class TestLoadMultimodalData:
mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client)
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)})
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) # type: ignore
with patch.object(
mod,
"load_multimodal_embeddings",
......
......@@ -29,7 +29,7 @@ def mock_handler():
pass
handler = MockHandler()
handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__(
handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( # type: ignore
handler
)
return handler
......
......@@ -981,7 +981,7 @@ class ModelType:
Audios: ModelType
Videos: ModelType
def __or__(self, other: "ModelType") -> "ModelType":
def __or__(self, other: ModelType) -> ModelType:
...
def supports_chat(self) -> bool:
......@@ -1091,6 +1091,8 @@ async def register_model(
runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None,
media_decoder: Optional[MediaDecoder] = None,
media_fetcher: Optional[MediaFetcher] = None,
lora_name: Optional[str] = None,
base_model_path: Optional[str] = None,
) -> None:
......@@ -1649,6 +1651,8 @@ class PlannerDecision:
-1 in any of those fields means not set, usually because planner hasn't decided anything yet.
Call VirtualConnectorClient.complete(event) when action is completed.
"""
num_prefill_workers: int
num_decode_workers: int
...
class VirtualConnectorCoordinator:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment