Unverified commit 2712426f, authored by jh-nv, committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
......@@ -34,6 +34,7 @@ from dynamo.llm import (
)
from dynamo.runtime import Endpoint
from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.omni.args import OmniConfig
from dynamo.vllm.worker_factory import WorkerFactory
from . import envs
......@@ -168,7 +169,7 @@ async def worker() -> None:
def setup_metrics_collection(
config: Config, generate_endpoint: Endpoint, logger: logging.Logger
config: Config | OmniConfig, generate_endpoint: Endpoint, logger: logging.Logger
) -> None:
"""Set up metrics collection for vLLM and LMCache metrics.
......
......@@ -18,6 +18,7 @@ from dynamo.common.multimodal import (
NixlReadEmbeddingSender,
NixlWriteEmbeddingSender,
)
from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingSender
from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime
......@@ -85,6 +86,7 @@ class EncodeWorkerHandler:
self._processed_requests = 0
self.readables: list[Any] = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender: AbstractEmbeddingSender
if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_sender = LocalEmbeddingSender()
elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
......@@ -136,6 +138,9 @@ class EncodeWorkerHandler:
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id
assert (
request.multimodal_inputs is not None
), "multimodal_inputs must not be None for encode worker"
# The following steps encode the requested image and provide useful embeddings.
# 1. Open the image from the provided URL.
......@@ -157,12 +162,11 @@ class EncodeWorkerHandler:
request.multimodal_inputs
)
for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url:
group_input = request.multimodal_inputs[idx].multimodal_input
if group_input is None or not group_input.image_url:
raise ValueError("image_url is required for the encode worker.")
image_url = request.multimodal_inputs[
idx
].multimodal_input.image_url
image_url = group_input.image_url
# see if we have local cache
embedding_key = EmbeddingCache.generate_hash_key(image_url)
if (
......@@ -189,7 +193,10 @@ class EncodeWorkerHandler:
image_tasks = []
image_to_load = []
for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url
group_mm_input = request.multimodal_inputs[idx].multimodal_input
assert group_mm_input is not None
assert group_mm_input.image_url is not None
url: str = group_mm_input.image_url
image_tasks.append(
asyncio.create_task(self.image_loader.load_image(url))
)
......@@ -305,16 +312,12 @@ class EncodeWorkerHandler:
f"{embedding_item.embeddings.shape} prepared for transfer."
)
# Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None
request.multimodal_inputs[
idx
].image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple(
embedding_item.embeddings.shape
)
request.multimodal_inputs[
idx
].serialized_request = transfer_request[0]
group = request.multimodal_inputs[idx]
assert group.multimodal_input is not None
group.multimodal_input.image_url = None
group.image_grid_thw = embedding_item.image_grid_thw
group.embeddings_shape = tuple(embedding_item.embeddings.shape) # type: ignore[assignment]
group.serialized_request = transfer_request[0]
# Keep a reference to the embedding and drop it only when the transfer is done
self.send_complete_queue.put_nowait(
......
......@@ -14,6 +14,7 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager,
)
from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver,
......@@ -39,7 +40,7 @@ logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url"
class MultimodalPDWorkerHandler(BaseWorkerHandler):
class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
"""Prefill/Decode or Prefill-only worker for multimodal serving"""
def __init__(
......@@ -88,7 +89,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# and used to determine whether remote encode is necessary for a given mm data.
self.encode_worker_client = encode_worker_client
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver = LocalEmbeddingReceiver()
self.embedding_receiver: AbstractEmbeddingReceiver = (
LocalEmbeddingReceiver()
)
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_receiver = NixlWriteEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
......@@ -381,12 +384,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
) as decode_timer,
):
num_output_tokens_so_far = 0
async for (
decode_response
) in await self.decode_worker_client.round_robin( # type: ignore
if self.decode_worker_client is None:
raise RuntimeError("Decode worker client is not configured.")
async for (decode_response) in await self.decode_worker_client.round_robin(
request.model_dump_json(), context=context
):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
output = MyRequestOutput.model_validate_json(decode_response.data())
yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs:
if num_output_tokens_so_far == 0:
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import AsyncIterator
from vllm.inputs.data import TokensPrompt
......@@ -20,7 +21,7 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m
logger = logging.getLogger(__name__)
class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
class MultimodalDecodeWorkerHandler(BaseWorkerHandler[vLLMMultimodalRequest, str]):
"""Decode worker for disaggregated multimodal serving"""
def __init__(
......@@ -55,7 +56,9 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
self._connector = connect.Connector()
logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context):
async def generate(
self, request: vLLMMultimodalRequest, context
) -> AsyncIterator[str]:
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest):
......
......@@ -279,14 +279,14 @@ def construct_qwen_decode_mm_data(
# that happen to have the same dimensions (same image_grid_thw).
# Use an incrementing counter (stored on the function) as a somewhat unique fill value that fits in the dtype range
if not hasattr(construct_qwen_decode_mm_data, "_counter"):
construct_qwen_decode_mm_data._counter = 0
fill_value = construct_qwen_decode_mm_data._counter
construct_qwen_decode_mm_data._counter += 1
construct_qwen_decode_mm_data._counter = 0 # type: ignore[attr-defined]
fill_value = construct_qwen_decode_mm_data._counter # type: ignore[attr-defined]
construct_qwen_decode_mm_data._counter += 1 # type: ignore[attr-defined]
max_val = (
torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
)
if construct_qwen_decode_mm_data._counter > max_val:
construct_qwen_decode_mm_data._counter = 0
if construct_qwen_decode_mm_data._counter > max_val: # type: ignore[attr-defined]
construct_qwen_decode_mm_data._counter = 0 # type: ignore[attr-defined]
image_embeds = torch.full(
embeddings_shape, fill_value=fill_value, dtype=dtype, device="cpu"
)
......
......@@ -204,6 +204,7 @@ async def _fetch_from_encode_workers(
tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups
if group.serialized_request is not None
]
loaded = await asyncio.gather(*tasks)
......
......@@ -16,12 +16,13 @@ try:
except ImportError:
DiffusionParallelConfig = None # type: ignore[assignment, misc]
from dynamo._core import Context
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__)
class BaseOmniHandler(BaseWorkerHandler):
class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
"""Base handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator."""
def __init__(
......@@ -107,8 +108,8 @@ class BaseOmniHandler(BaseWorkerHandler):
return omni_kwargs
async def generate(
self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]:
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Subclasses should override ``_generate_openai_mode`` for custom output handling.
......@@ -116,7 +117,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id): # type: ignore
async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk
async def _generate_openai_mode(
......@@ -130,6 +131,8 @@ class BaseOmniHandler(BaseWorkerHandler):
raise NotImplementedError(
f"{self.__class__.__name__} must implement _generate_openai_mode"
)
# Make this a proper async generator so the return type is correct.
yield # pragma: no cover
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
"""Extract text prompt from OpenAI messages format.
......
......@@ -8,13 +8,14 @@ import time
import uuid
from dataclasses import dataclass
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Optional, Union
from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
import PIL.Image
from diffusers.utils import export_to_video
from fsspec.implementations.dirfs import DirFileSystem
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.image_protocol import (
ImageData,
......@@ -99,8 +100,8 @@ class OmniHandler(BaseOmniHandler):
self._image_loader = ImageLoader()
async def generate(
self, request: Dict[str, Any], context
) -> AsyncGenerator[Dict, None]:
self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate outputs via the unified OpenAI mode.
Args:
......@@ -111,19 +112,24 @@ class OmniHandler(BaseOmniHandler):
Response dictionaries.
"""
request_id = context.id()
assert request_id is not None, "Request ID is required"
logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk
async def _generate_openai_mode(
self, request: Dict[str, Any], context, request_id: str
self, request: Dict[str, Any], context: Context, request_id: str
) -> AsyncGenerator[Dict[str, Any], None]:
"""Single generation path for all request protocols and output modalities."""
parsed_request, request_type = parse_request_type(
parsed_request_raw, request_type = parse_request_type(
request, self.config.output_modalities
)
parsed_request = cast(
Union[NvCreateImageRequest, NvCreateVideoRequest, Dict[str, Any]],
parsed_request_raw,
)
# Pre-load input image for I2V requests (async I/O before sync build)
image = None
......@@ -227,10 +233,13 @@ class OmniHandler(BaseOmniHandler):
EngineInputs ready for engine_client.generate().
"""
if request_type == RequestType.CHAT_COMPLETION:
assert isinstance(parsed_request, dict)
return self._engine_inputs_from_chat(parsed_request)
elif request_type == RequestType.IMAGE_GENERATION:
assert isinstance(parsed_request, NvCreateImageRequest)
return self._engine_inputs_from_image(parsed_request)
elif request_type == RequestType.VIDEO_GENERATION:
assert isinstance(parsed_request, NvCreateVideoRequest)
return self._engine_inputs_from_video(parsed_request, image=image)
elif request_type == RequestType.AUDIO_GENERATION:
......
......@@ -262,6 +262,7 @@ class WorkerFactory:
logger.info("Connected to decode worker for disaggregated mode")
# Choose handler based on worker type
handler: MultimodalDecodeWorkerHandler | MultimodalPDWorkerHandler
if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler(
runtime,
......@@ -289,7 +290,7 @@ class WorkerFactory:
config, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
handler.kv_publisher = kv_publisher # type: ignore[attr-defined, union-attr]
if not config.multimodal_decode_worker:
model_type = parse_endpoint_types(config.endpoint_types)
......@@ -357,7 +358,7 @@ class WorkerFactory:
shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler(
config.engine_args, config.embedding_transfer_mode
config.engine_args, config.embedding_transfer_mode # type: ignore[arg-type]
)
await handler.async_init(runtime)
logger.info("Starting to serve the encode worker endpoint...")
......
......@@ -13,7 +13,9 @@ filelock==3.25.1
kr8s==0.20.13
kubernetes_asyncio==32.0.0
matplotlib==3.10.7
matplotlib-stubs
mistral-common==1.9.1
mypy==1.18.2
# For NATS object store verification in router tests
nats-py==2.12.0
psutil<=7.0.0 # System package, may vary by platform (was >=5.0.0)
......@@ -27,7 +29,6 @@ pytest-cov==7.0.0
pytest-forked==1.6.0
pytest-httpserver==1.1.3
pytest-md-report==0.7.0
pytest-mypy==1.0.1
pytest-order==1.3.0
pytest-timeout==2.4.0
pytest-xdist==3.8.0
......
......@@ -2,17 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
)
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
# Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
......@@ -31,14 +21,14 @@ def get_reasoning_parser_names() -> list[str]:
"""Get list of available reasoning parser names."""
...
class JsonLike:
"""
Any PyObject which can be serialized to JSON
"""
def run_kv_indexer(args: List[str]) -> None:
"""Run the KV indexer with the given arguments."""
...
RequestHandler = Callable[[JsonLike], AsyncGenerator[JsonLike, None]]
# Any Python object that can be serialized to JSON (dict, list, str, int, etc.)
JsonLike = Any
RequestHandler = Callable[..., AsyncIterator[JsonLike]]
class DistributedRuntime:
"""
......@@ -473,8 +463,6 @@ class ModelRuntimeConfig:
enable_local_indexer: bool
runtime_data: dict[str, Any]
tensor_model_config: Any | None
data_parallel_size: int
data_parallel_start_rank: int
bootstrap_host: str | None
bootstrap_port: int | None
......
......@@ -149,8 +149,8 @@ minversion = "8.0"
tmp_path_retention_policy = "failed"
# NOTE
# We ignore model.py explicitly here to avoid mypy errors with duplicate modules
# pytest overrides the default mypy exclude configuration and so we exclude here as well
# Keep these ignores in pytest collection to avoid duplicate-module collection
# errors (for example, backend trees that include multiple model.py files).
addopts = [
"-ra",
"--showlocals",
......@@ -292,17 +292,33 @@ venv = ".venv"
# tensorrt_llm and vllm are both named common.
explicit_package_bases = true
# --ignore-missing-imports: workaround (WAR) for too many errors when developing outside
# of container environment with PYTHONPATH set and packages installed.
# NOTE: Can possibly move mypy from pre-commit to a github action run only in
# a container with the expected environment and PYTHONPATH setup.
check_untyped_defs = true
[[tool.mypy.overrides]]
# _version.py is generated at build time and does not exist in the source tree.
module = ["dynamo.*._version"]
ignore_missing_imports = true
check_untyped_defs = true
[[tool.mypy.overrides]]
# Skip type checking for test files.
module = ["dynamo.*.tests.*", "dynamo.*.tests"]
ignore_errors = true
[[tool.mypy.overrides]]
# Skip mypy analysis on backend framework internals.
# ignore_missing_imports silences import-not-found only when the backend
# is not installed (e.g. sglang/trtllm missing in the vllm container).
module = ["vllm", "vllm.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = ["sglang", "sglang.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]]
# Skip mypy analysis on internal dependencies of vllm
module = ["vllm.*"]
module = ["tensorrt_llm", "tensorrt_llm.*"]
follow_imports = "skip"
ignore_missing_imports = true
......@@ -312,6 +328,95 @@ ignore_missing_imports = true
module = ["numpy", "numpy.*"]
follow_imports = "skip"
[[tool.mypy.overrides]]
# Third-party libs without type stubs, or optional internal deps.
# TODO: fix the ones that do have a stub package.
module = [
"nvtx",
"fsspec",
"fsspec.*",
"kubernetes",
"kubernetes.*",
"scipy",
"scipy.*",
"sklearn",
"sklearn.*",
"pandas",
"pandas.*",
"pmdarima",
"pmdarima.*",
"filterpy",
"filterpy.*",
"prophet",
"prophet.*",
"msgpack",
"nixl",
"nixl.*",
"imageio",
"imageio.*",
"yaml",
"prometheus_api_client",
"prometheus_api_client.*",
"aiohttp",
"aiohttp.*",
"vllm_omni",
"vllm_omni.*",
"modelexpress",
"modelexpress.*",
"kvbm",
"kvbm.*",
"diffusers",
"diffusers.*",
"PIL",
"PIL.*",
"torch",
"torch.*",
"transformers",
"transformers.*",
"cupy",
"cupy.*",
"gpu_memory_service",
"gpu_memory_service.*",
"pydantic",
"pydantic.*",
"uvloop",
"prometheus_client",
"prometheus_client.*",
"pybase64",
"blake3",
"cupy_backends",
"cupy_backends.*",
"huggingface_hub",
"huggingface_hub.*",
"httpx",
"httpx.*",
"zmq",
"zmq.*",
"safetensors",
"safetensors.*",
"gradio",
"gradio.*",
"kubernetes_asyncio",
"kubernetes_asyncio.*",
"pydantic_core",
"aiconfigurator",
"aiconfigurator.*",
]
ignore_missing_imports = true
[[tool.mypy.overrides]]
# msgspec.Struct uses custom __init_subclass__ kwargs (frozen, gc) that mypy
# cannot resolve without the msgspec package installed.
module = ["msgspec", "msgspec.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]]
# Profiler module was never previously type-checked and has many
# union-attr / attr-defined issues. Skip errors for now.
module = ["dynamo.profiler.*"]
ignore_errors = true
[tool.sphinx]
# extra-content-head
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment