Unverified Commit 2712426f authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

feat: enable mypy in pre-merge (#6732)

parent e5e118a1
...@@ -34,6 +34,7 @@ from dynamo.llm import ( ...@@ -34,6 +34,7 @@ from dynamo.llm import (
) )
from dynamo.runtime import Endpoint from dynamo.runtime import Endpoint
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.omni.args import OmniConfig
from dynamo.vllm.worker_factory import WorkerFactory from dynamo.vllm.worker_factory import WorkerFactory
from . import envs from . import envs
...@@ -168,7 +169,7 @@ async def worker() -> None: ...@@ -168,7 +169,7 @@ async def worker() -> None:
def setup_metrics_collection( def setup_metrics_collection(
config: Config, generate_endpoint: Endpoint, logger: logging.Logger config: Config | OmniConfig, generate_endpoint: Endpoint, logger: logging.Logger
) -> None: ) -> None:
"""Set up metrics collection for vLLM and LMCache metrics. """Set up metrics collection for vLLM and LMCache metrics.
......
...@@ -18,6 +18,7 @@ from dynamo.common.multimodal import ( ...@@ -18,6 +18,7 @@ from dynamo.common.multimodal import (
NixlReadEmbeddingSender, NixlReadEmbeddingSender,
NixlWriteEmbeddingSender, NixlWriteEmbeddingSender,
) )
from dynamo.common.multimodal.embedding_transfer import AbstractEmbeddingSender
from dynamo.common.utils import nvtx_utils as _nvtx from dynamo.common.utils import nvtx_utils as _nvtx
from dynamo.common.utils.time_section import time_and_log_code_section from dynamo.common.utils.time_section import time_and_log_code_section
from dynamo.runtime import DistributedRuntime from dynamo.runtime import DistributedRuntime
...@@ -85,6 +86,7 @@ class EncodeWorkerHandler: ...@@ -85,6 +86,7 @@ class EncodeWorkerHandler:
self._processed_requests = 0 self._processed_requests = 0
self.readables: list[Any] = [] self.readables: list[Any] = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
self.embedding_sender: AbstractEmbeddingSender
if embedding_transfer_mode == EmbeddingTransferMode.LOCAL: if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_sender = LocalEmbeddingSender() self.embedding_sender = LocalEmbeddingSender()
elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE: elif embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
...@@ -136,6 +138,9 @@ class EncodeWorkerHandler: ...@@ -136,6 +138,9 @@ class EncodeWorkerHandler:
logger.debug(f"Received encode request: {{ id: {request.request_id} }}.") logger.debug(f"Received encode request: {{ id: {request.request_id} }}.")
request_id = request.request_id request_id = request.request_id
assert (
request.multimodal_inputs is not None
), "multimodal_inputs must not be None for encode worker"
# The following steps encode the requested image and provide useful embeddings. # The following steps encode the requested image and provide useful embeddings.
# 1. Open the image from the provided URL. # 1. Open the image from the provided URL.
...@@ -157,12 +162,11 @@ class EncodeWorkerHandler: ...@@ -157,12 +162,11 @@ class EncodeWorkerHandler:
request.multimodal_inputs request.multimodal_inputs
) )
for idx in range(len(request.multimodal_inputs)): for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url: group_input = request.multimodal_inputs[idx].multimodal_input
if group_input is None or not group_input.image_url:
raise ValueError("image_url is required for the encode worker.") raise ValueError("image_url is required for the encode worker.")
image_url = request.multimodal_inputs[ image_url = group_input.image_url
idx
].multimodal_input.image_url
# see if we have local cache # see if we have local cache
embedding_key = EmbeddingCache.generate_hash_key(image_url) embedding_key = EmbeddingCache.generate_hash_key(image_url)
if ( if (
...@@ -189,7 +193,10 @@ class EncodeWorkerHandler: ...@@ -189,7 +193,10 @@ class EncodeWorkerHandler:
image_tasks = [] image_tasks = []
image_to_load = [] image_to_load = []
for idx, _ in need_encode_indexes: for idx, _ in need_encode_indexes:
url = request.multimodal_inputs[idx].multimodal_input.image_url group_mm_input = request.multimodal_inputs[idx].multimodal_input
assert group_mm_input is not None
assert group_mm_input.image_url is not None
url: str = group_mm_input.image_url
image_tasks.append( image_tasks.append(
asyncio.create_task(self.image_loader.load_image(url)) asyncio.create_task(self.image_loader.load_image(url))
) )
...@@ -305,16 +312,12 @@ class EncodeWorkerHandler: ...@@ -305,16 +312,12 @@ class EncodeWorkerHandler:
f"{embedding_item.embeddings.shape} prepared for transfer." f"{embedding_item.embeddings.shape} prepared for transfer."
) )
# Update request for transfer metadata # Update request for transfer metadata
request.multimodal_inputs[idx].multimodal_input.image_url = None group = request.multimodal_inputs[idx]
request.multimodal_inputs[ assert group.multimodal_input is not None
idx group.multimodal_input.image_url = None
].image_grid_thw = embedding_item.image_grid_thw group.image_grid_thw = embedding_item.image_grid_thw
request.multimodal_inputs[idx].embeddings_shape = tuple( group.embeddings_shape = tuple(embedding_item.embeddings.shape) # type: ignore[assignment]
embedding_item.embeddings.shape group.serialized_request = transfer_request[0]
)
request.multimodal_inputs[
idx
].serialized_request = transfer_request[0]
# Keep a reference of the embedding and only drop reference when the transfer is done # Keep a reference of the embedding and only drop reference when the transfer is done
self.send_complete_queue.put_nowait( self.send_complete_queue.put_nowait(
......
...@@ -14,6 +14,7 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import ( ...@@ -14,6 +14,7 @@ from dynamo.common.memory.multimodal_embedding_cache_manager import (
MultimodalEmbeddingCacheManager, MultimodalEmbeddingCacheManager,
) )
from dynamo.common.multimodal.embedding_transfer import ( from dynamo.common.multimodal.embedding_transfer import (
AbstractEmbeddingReceiver,
LocalEmbeddingReceiver, LocalEmbeddingReceiver,
NixlReadEmbeddingReceiver, NixlReadEmbeddingReceiver,
NixlWriteEmbeddingReceiver, NixlWriteEmbeddingReceiver,
...@@ -39,7 +40,7 @@ logger = logging.getLogger(__name__) ...@@ -39,7 +40,7 @@ logger = logging.getLogger(__name__)
IMAGE_URL_KEY = "image_url" IMAGE_URL_KEY = "image_url"
class MultimodalPDWorkerHandler(BaseWorkerHandler): class MultimodalPDWorkerHandler(BaseWorkerHandler[dict, dict]):
"""Prefill/Decode or Prefill-only worker for multimodal serving""" """Prefill/Decode or Prefill-only worker for multimodal serving"""
def __init__( def __init__(
...@@ -88,7 +89,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -88,7 +89,9 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
# and used to determine whether remote encode is necessary for a given mm data. # and used to determine whether remote encode is necessary for a given mm data.
self.encode_worker_client = encode_worker_client self.encode_worker_client = encode_worker_client
if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL: if config.embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_receiver = LocalEmbeddingReceiver() self.embedding_receiver: AbstractEmbeddingReceiver = (
LocalEmbeddingReceiver()
)
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE: elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_WRITE:
self.embedding_receiver = NixlWriteEmbeddingReceiver() self.embedding_receiver = NixlWriteEmbeddingReceiver()
elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ: elif config.embedding_transfer_mode == EmbeddingTransferMode.NIXL_READ:
...@@ -381,12 +384,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -381,12 +384,12 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
) as decode_timer, ) as decode_timer,
): ):
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for ( if self.decode_worker_client is None:
decode_response raise RuntimeError("Decode worker client is not configured.")
) in await self.decode_worker_client.round_robin( # type: ignore async for (decode_response) in await self.decode_worker_client.round_robin(
request.model_dump_json(), context=context request.model_dump_json(), context=context
): ):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore output = MyRequestOutput.model_validate_json(decode_response.data())
yield self._format_engine_output(output, num_output_tokens_so_far) yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs: if output.outputs:
if num_output_tokens_so_far == 0: if num_output_tokens_so_far == 0:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import logging import logging
from typing import AsyncIterator
from vllm.inputs.data import TokensPrompt from vllm.inputs.data import TokensPrompt
...@@ -20,7 +21,7 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m ...@@ -20,7 +21,7 @@ from ..multimodal_utils.model import construct_qwen_decode_mm_data, is_qwen_vl_m
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class MultimodalDecodeWorkerHandler(BaseWorkerHandler): class MultimodalDecodeWorkerHandler(BaseWorkerHandler[vLLMMultimodalRequest, str]):
"""Decode worker for disaggregated multimodal serving""" """Decode worker for disaggregated multimodal serving"""
def __init__( def __init__(
...@@ -55,7 +56,9 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler): ...@@ -55,7 +56,9 @@ class MultimodalDecodeWorkerHandler(BaseWorkerHandler):
self._connector = connect.Connector() self._connector = connect.Connector()
logger.info("Multimodal Decode Worker async initialization completed.") logger.info("Multimodal Decode Worker async initialization completed.")
async def generate(self, request: vLLMMultimodalRequest, context): async def generate(
self, request: vLLMMultimodalRequest, context
) -> AsyncIterator[str]:
rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue") rng_decode = _nvtx.start_range("mm:decode_worker_generate", color="blue")
logger.debug(f"Got raw request: {request}") logger.debug(f"Got raw request: {request}")
if not isinstance(request, vLLMMultimodalRequest): if not isinstance(request, vLLMMultimodalRequest):
......
...@@ -279,14 +279,14 @@ def construct_qwen_decode_mm_data( ...@@ -279,14 +279,14 @@ def construct_qwen_decode_mm_data(
# that happen to have the same dimensions (same image_grid_thw). # that happen to have the same dimensions (same image_grid_thw).
# bit ops to convert request ID to somewhat unique value that fits in the dtype range # bit ops to convert request ID to somewhat unique value that fits in the dtype range
if not hasattr(construct_qwen_decode_mm_data, "_counter"): if not hasattr(construct_qwen_decode_mm_data, "_counter"):
construct_qwen_decode_mm_data._counter = 0 construct_qwen_decode_mm_data._counter = 0 # type: ignore[attr-defined]
fill_value = construct_qwen_decode_mm_data._counter fill_value = construct_qwen_decode_mm_data._counter # type: ignore[attr-defined]
construct_qwen_decode_mm_data._counter += 1 construct_qwen_decode_mm_data._counter += 1 # type: ignore[attr-defined]
max_val = ( max_val = (
torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max torch.finfo(dtype).max if dtype.is_floating_point else torch.iinfo(dtype).max
) )
if construct_qwen_decode_mm_data._counter > max_val: if construct_qwen_decode_mm_data._counter > max_val: # type: ignore[attr-defined]
construct_qwen_decode_mm_data._counter = 0 construct_qwen_decode_mm_data._counter = 0 # type: ignore[attr-defined]
image_embeds = torch.full( image_embeds = torch.full(
embeddings_shape, fill_value=fill_value, dtype=dtype, device="cpu" embeddings_shape, fill_value=fill_value, dtype=dtype, device="cpu"
) )
......
...@@ -204,6 +204,7 @@ async def _fetch_from_encode_workers( ...@@ -204,6 +204,7 @@ async def _fetch_from_encode_workers(
tasks = [ tasks = [
asyncio.create_task(receiver.receive_embeddings(group.serialized_request)) asyncio.create_task(receiver.receive_embeddings(group.serialized_request))
for group in multimodal_groups for group in multimodal_groups
if group.serialized_request is not None
] ]
loaded = await asyncio.gather(*tasks) loaded = await asyncio.gather(*tasks)
......
...@@ -16,12 +16,13 @@ try: ...@@ -16,12 +16,13 @@ try:
except ImportError: except ImportError:
DiffusionParallelConfig = None # type: ignore[assignment, misc] DiffusionParallelConfig = None # type: ignore[assignment, misc]
from dynamo._core import Context
from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params from dynamo.vllm.handlers import BaseWorkerHandler, build_sampling_params
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseOmniHandler(BaseWorkerHandler): class BaseOmniHandler(BaseWorkerHandler[Dict[str, Any], Dict[str, Any]]):
"""Base handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator.""" """Base handler for multi-stage pipelines using vLLM-Omni's AsyncOmni orchestrator."""
def __init__( def __init__(
...@@ -107,8 +108,8 @@ class BaseOmniHandler(BaseWorkerHandler): ...@@ -107,8 +108,8 @@ class BaseOmniHandler(BaseWorkerHandler):
return omni_kwargs return omni_kwargs
async def generate( async def generate(
self, request: Dict[str, Any], context self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict, None]: ) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format. """Generate outputs using AsyncOmni orchestrator with OpenAI-compatible format.
Subclasses should override ``_generate_openai_mode`` for custom output handling. Subclasses should override ``_generate_openai_mode`` for custom output handling.
...@@ -116,7 +117,7 @@ class BaseOmniHandler(BaseWorkerHandler): ...@@ -116,7 +117,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id = context.id() request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}") logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id): # type: ignore async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk yield chunk
async def _generate_openai_mode( async def _generate_openai_mode(
...@@ -130,6 +131,8 @@ class BaseOmniHandler(BaseWorkerHandler): ...@@ -130,6 +131,8 @@ class BaseOmniHandler(BaseWorkerHandler):
raise NotImplementedError( raise NotImplementedError(
f"{self.__class__.__name__} must implement _generate_openai_mode" f"{self.__class__.__name__} must implement _generate_openai_mode"
) )
# Make this a proper async generator so the return type is correct.
yield # pragma: no cover
def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None: def _extract_text_prompt(self, request: Dict[str, Any]) -> str | None:
"""Extract text prompt from OpenAI messages format. """Extract text prompt from OpenAI messages format.
......
...@@ -8,13 +8,14 @@ import time ...@@ -8,13 +8,14 @@ import time
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from typing import Any, AsyncGenerator, Dict, Optional, Union from typing import Any, AsyncGenerator, Dict, Optional, Union, cast
import PIL.Image import PIL.Image
from diffusers.utils import export_to_video from diffusers.utils import export_to_video
from fsspec.implementations.dirfs import DirFileSystem from fsspec.implementations.dirfs import DirFileSystem
from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniTextPrompt
from dynamo._core import Context
from dynamo.common.multimodal import ImageLoader from dynamo.common.multimodal import ImageLoader
from dynamo.common.protocols.image_protocol import ( from dynamo.common.protocols.image_protocol import (
ImageData, ImageData,
...@@ -99,8 +100,8 @@ class OmniHandler(BaseOmniHandler): ...@@ -99,8 +100,8 @@ class OmniHandler(BaseOmniHandler):
self._image_loader = ImageLoader() self._image_loader = ImageLoader()
async def generate( async def generate(
self, request: Dict[str, Any], context self, request: Dict[str, Any], context: Context
) -> AsyncGenerator[Dict, None]: ) -> AsyncGenerator[Dict[str, Any], None]:
"""Generate outputs via the unified OpenAI mode. """Generate outputs via the unified OpenAI mode.
Args: Args:
...@@ -111,19 +112,24 @@ class OmniHandler(BaseOmniHandler): ...@@ -111,19 +112,24 @@ class OmniHandler(BaseOmniHandler):
Response dictionaries. Response dictionaries.
""" """
request_id = context.id() request_id = context.id()
assert request_id is not None, "Request ID is required"
logger.debug(f"Omni Request ID: {request_id}") logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id): async for chunk in self._generate_openai_mode(request, context, request_id):
yield chunk yield chunk
async def _generate_openai_mode( async def _generate_openai_mode(
self, request: Dict[str, Any], context, request_id: str self, request: Dict[str, Any], context: Context, request_id: str
) -> AsyncGenerator[Dict[str, Any], None]: ) -> AsyncGenerator[Dict[str, Any], None]:
"""Single generation path for all request protocols and output modalities.""" """Single generation path for all request protocols and output modalities."""
parsed_request, request_type = parse_request_type( parsed_request_raw, request_type = parse_request_type(
request, self.config.output_modalities request, self.config.output_modalities
) )
parsed_request = cast(
Union[NvCreateImageRequest, NvCreateVideoRequest, Dict[str, Any]],
parsed_request_raw,
)
# Pre-load input image for I2V requests (async I/O before sync build) # Pre-load input image for I2V requests (async I/O before sync build)
image = None image = None
...@@ -227,10 +233,13 @@ class OmniHandler(BaseOmniHandler): ...@@ -227,10 +233,13 @@ class OmniHandler(BaseOmniHandler):
EngineInputs ready for engine_client.generate(). EngineInputs ready for engine_client.generate().
""" """
if request_type == RequestType.CHAT_COMPLETION: if request_type == RequestType.CHAT_COMPLETION:
assert isinstance(parsed_request, dict)
return self._engine_inputs_from_chat(parsed_request) return self._engine_inputs_from_chat(parsed_request)
elif request_type == RequestType.IMAGE_GENERATION: elif request_type == RequestType.IMAGE_GENERATION:
assert isinstance(parsed_request, NvCreateImageRequest)
return self._engine_inputs_from_image(parsed_request) return self._engine_inputs_from_image(parsed_request)
elif request_type == RequestType.VIDEO_GENERATION: elif request_type == RequestType.VIDEO_GENERATION:
assert isinstance(parsed_request, NvCreateVideoRequest)
return self._engine_inputs_from_video(parsed_request, image=image) return self._engine_inputs_from_video(parsed_request, image=image)
elif request_type == RequestType.AUDIO_GENERATION: elif request_type == RequestType.AUDIO_GENERATION:
......
...@@ -262,6 +262,7 @@ class WorkerFactory: ...@@ -262,6 +262,7 @@ class WorkerFactory:
logger.info("Connected to decode worker for disaggregated mode") logger.info("Connected to decode worker for disaggregated mode")
# Choose handler based on worker type # Choose handler based on worker type
handler: MultimodalDecodeWorkerHandler | MultimodalPDWorkerHandler
if config.multimodal_decode_worker: if config.multimodal_decode_worker:
handler = MultimodalDecodeWorkerHandler( handler = MultimodalDecodeWorkerHandler(
runtime, runtime,
...@@ -289,7 +290,7 @@ class WorkerFactory: ...@@ -289,7 +290,7 @@ class WorkerFactory:
config, generate_endpoint, vllm_config config, generate_endpoint, vllm_config
) )
if kv_publisher: if kv_publisher:
handler.kv_publisher = kv_publisher handler.kv_publisher = kv_publisher # type: ignore[attr-defined, union-attr]
if not config.multimodal_decode_worker: if not config.multimodal_decode_worker:
model_type = parse_endpoint_types(config.endpoint_types) model_type = parse_endpoint_types(config.endpoint_types)
...@@ -357,7 +358,7 @@ class WorkerFactory: ...@@ -357,7 +358,7 @@ class WorkerFactory:
shutdown_endpoints[:] = [generate_endpoint] shutdown_endpoints[:] = [generate_endpoint]
handler = EncodeWorkerHandler( handler = EncodeWorkerHandler(
config.engine_args, config.embedding_transfer_mode config.engine_args, config.embedding_transfer_mode # type: ignore[arg-type]
) )
await handler.async_init(runtime) await handler.async_init(runtime)
logger.info("Starting to serve the encode worker endpoint...") logger.info("Starting to serve the encode worker endpoint...")
......
...@@ -13,7 +13,9 @@ filelock==3.25.1 ...@@ -13,7 +13,9 @@ filelock==3.25.1
kr8s==0.20.13 kr8s==0.20.13
kubernetes_asyncio==32.0.0 kubernetes_asyncio==32.0.0
matplotlib==3.10.7 matplotlib==3.10.7
matplotlib-stubs
mistral-common==1.9.1 mistral-common==1.9.1
mypy==1.18.2
# For NATS object store verification in router tests # For NATS object store verification in router tests
nats-py==2.12.0 nats-py==2.12.0
psutil<=7.0.0 # System package, may vary by platform (was >=5.0.0) psutil<=7.0.0 # System package, may vary by platform (was >=5.0.0)
...@@ -27,7 +29,6 @@ pytest-cov==7.0.0 ...@@ -27,7 +29,6 @@ pytest-cov==7.0.0
pytest-forked==1.6.0 pytest-forked==1.6.0
pytest-httpserver==1.1.3 pytest-httpserver==1.1.3
pytest-md-report==0.7.0 pytest-md-report==0.7.0
pytest-mypy==1.0.1
pytest-order==1.3.0 pytest-order==1.3.0
pytest-timeout==2.4.0 pytest-timeout==2.4.0
pytest-xdist==3.8.0 pytest-xdist==3.8.0
......
...@@ -2,17 +2,7 @@ ...@@ -2,17 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import asyncio import asyncio
from typing import ( from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple
Any,
AsyncGenerator,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Optional,
Tuple,
)
# Import from specialized modules # Import from specialized modules
from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics from .prometheus_metrics import RuntimeMetrics as PyRuntimeMetrics
...@@ -31,14 +21,14 @@ def get_reasoning_parser_names() -> list[str]: ...@@ -31,14 +21,14 @@ def get_reasoning_parser_names() -> list[str]:
"""Get list of available reasoning parser names.""" """Get list of available reasoning parser names."""
... ...
class JsonLike: def run_kv_indexer(args: List[str]) -> None:
""" """Run the KV indexer with the given arguments."""
Any PyObject which can be serialized to JSON
"""
... ...
RequestHandler = Callable[[JsonLike], AsyncGenerator[JsonLike, None]] # Any Python object that can be serialized to JSON (dict, list, str, int, etc.)
JsonLike = Any
RequestHandler = Callable[..., AsyncIterator[JsonLike]]
class DistributedRuntime: class DistributedRuntime:
""" """
...@@ -473,8 +463,6 @@ class ModelRuntimeConfig: ...@@ -473,8 +463,6 @@ class ModelRuntimeConfig:
enable_local_indexer: bool enable_local_indexer: bool
runtime_data: dict[str, Any] runtime_data: dict[str, Any]
tensor_model_config: Any | None tensor_model_config: Any | None
data_parallel_size: int
data_parallel_start_rank: int
bootstrap_host: str | None bootstrap_host: str | None
bootstrap_port: int | None bootstrap_port: int | None
......
...@@ -149,8 +149,8 @@ minversion = "8.0" ...@@ -149,8 +149,8 @@ minversion = "8.0"
tmp_path_retention_policy = "failed" tmp_path_retention_policy = "failed"
# NOTE # NOTE
# We ignore model.py explicitly here to avoid mypy errors with duplicate modules # Keep these ignores in pytest collection to avoid duplicate-module collection
# pytest overrides the default mypy exclude configuration and so we exclude here as well # errors (for example, backend trees that include multiple model.py files).
addopts = [ addopts = [
"-ra", "-ra",
"--showlocals", "--showlocals",
...@@ -292,17 +292,33 @@ venv = ".venv" ...@@ -292,17 +292,33 @@ venv = ".venv"
# tensorrt_llm and vllm are both named common. # tensorrt_llm and vllm are both named common.
explicit_package_bases = true explicit_package_bases = true
# --ignore-missing-imports: WAR too many errors when developing outside check_untyped_defs = true
# of container environment with PYTHONPATH set and packages installed.
# NOTE: Can possibly move mypy from pre-commit to a github action run only in [[tool.mypy.overrides]]
# a container with the expected environment and PYTHONPATH setup. # _version.py is generated at build time and does not exist in the source tree.
module = ["dynamo.*._version"]
ignore_missing_imports = true ignore_missing_imports = true
check_untyped_defs = true [[tool.mypy.overrides]]
# Skip type checking for test files.
module = ["dynamo.*.tests.*", "dynamo.*.tests"]
ignore_errors = true
[[tool.mypy.overrides]]
# Skip mypy analysis on backend framework internals.
# ignore_missing_imports silences import-not-found only when the backend
# is not installed (e.g. sglang/trtllm missing in the vllm container).
module = ["vllm", "vllm.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = ["sglang", "sglang.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
# Skip mypy analysis on internal dependencies of vllm module = ["tensorrt_llm", "tensorrt_llm.*"]
module = ["vllm.*"]
follow_imports = "skip" follow_imports = "skip"
ignore_missing_imports = true ignore_missing_imports = true
...@@ -312,6 +328,95 @@ ignore_missing_imports = true ...@@ -312,6 +328,95 @@ ignore_missing_imports = true
module = ["numpy", "numpy.*"] module = ["numpy", "numpy.*"]
follow_imports = "skip" follow_imports = "skip"
[[tool.mypy.overrides]]
# Third-party libs without type stubs or optional internal deps
# TODO: fix the ones that do have stub packages
module = [
"nvtx",
"fsspec",
"fsspec.*",
"kubernetes",
"kubernetes.*",
"scipy",
"scipy.*",
"sklearn",
"sklearn.*",
"pandas",
"pandas.*",
"pmdarima",
"pmdarima.*",
"filterpy",
"filterpy.*",
"prophet",
"prophet.*",
"msgpack",
"nixl",
"nixl.*",
"imageio",
"imageio.*",
"yaml",
"prometheus_api_client",
"prometheus_api_client.*",
"aiohttp",
"aiohttp.*",
"vllm_omni",
"vllm_omni.*",
"modelexpress",
"modelexpress.*",
"kvbm",
"kvbm.*",
"diffusers",
"diffusers.*",
"PIL",
"PIL.*",
"torch",
"torch.*",
"transformers",
"transformers.*",
"cupy",
"cupy.*",
"gpu_memory_service",
"gpu_memory_service.*",
"pydantic",
"pydantic.*",
"uvloop",
"prometheus_client",
"prometheus_client.*",
"pybase64",
"blake3",
"cupy_backends",
"cupy_backends.*",
"huggingface_hub",
"huggingface_hub.*",
"httpx",
"httpx.*",
"zmq",
"zmq.*",
"safetensors",
"safetensors.*",
"gradio",
"gradio.*",
"kubernetes_asyncio",
"kubernetes_asyncio.*",
"pydantic_core",
"aiconfigurator",
"aiconfigurator.*",
]
ignore_missing_imports = true
[[tool.mypy.overrides]]
# msgspec.Struct uses custom __init_subclass__ kwargs (frozen, gc) that mypy
# cannot resolve without the msgspec package installed.
module = ["msgspec", "msgspec.*"]
follow_imports = "skip"
ignore_missing_imports = true
[[tool.mypy.overrides]]
# Profiler module was never previously type-checked and has many
# union-attr / attr-defined issues. Skip errors for now.
module = ["dynamo.profiler.*"]
ignore_errors = true
[tool.sphinx] [tool.sphinx]
# extra-content-head # extra-content-head
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment