"lib/bindings/vscode:/vscode.git/clone" did not exist on "c22280ccf40e9143ba85eacebd168a1d205eab06"
Unverified commit 869e7336, authored by jh-nv and committed by GitHub
Browse files

chore: add mypy typing for components/src/dynamo/common (#6721)

parent 9780bf3a
......@@ -156,7 +156,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) ->
return canonical_json_encoder.encode(error_info)
def add_config_dump_args(parser: argparse.ArgumentParser):
def add_config_dump_args(parser: argparse.ArgumentParser) -> None:
"""
Add arguments to the parser to dump the config to a file.
......@@ -206,7 +206,7 @@ def _preprocess_for_encode(obj: object) -> object:
return str(obj)
def register_encoder(type_class):
def register_encoder(type_class: type) -> Any:
"""
Decorator to register custom encoders for specific types.
......
......@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
"""Base ArgGroup interface."""
import argparse
from abc import ABC, abstractmethod
......@@ -13,7 +14,7 @@ class ArgGroup(ABC):
"""
@abstractmethod
def add_arguments(self, parser) -> None:
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"""
Register CLI arguments owned by this group.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Self
class ConfigBase:
"""Base configuration class that allows properties with and without defaults in arbitrary order."""
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
def from_cli_args(cls, args: argparse.Namespace) -> Self:
obj = cls.__new__(cls)
# 1) Set everything provided by argparse
......
......@@ -3,6 +3,7 @@
"""Dynamo runtime configuration ArgGroup."""
import argparse
from typing import List, Optional
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
......@@ -63,7 +64,7 @@ class DynamoRuntimeConfig(ConfigBase):
class DynamoRuntimeArgGroup(ArgGroup):
"""Dynamo runtime configuration parameters (common to all backends)."""
def add_arguments(self, parser) -> None:
def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"""Add Dynamo runtime arguments to parser."""
g = parser.add_argument_group("Dynamo Runtime Options")
......
......@@ -55,7 +55,7 @@ def env_or_default(
def add_argument(
parser,
parser: argparse.ArgumentParser | argparse._ArgumentGroup,
*,
flag_name: str,
env_var: str,
......@@ -109,7 +109,7 @@ def add_argument(
def add_negatable_bool_argument(
parser,
parser: Any,
*,
flag_name: str,
env_var: str,
......
......@@ -49,7 +49,7 @@ class LoRAManager:
# Extension point: custom sources
self._custom_sources: Dict[str, LoRASourceProtocol] = {}
def register_custom_source(self, scheme: str, source: LoRASourceProtocol):
def register_custom_source(self, scheme: str, source: LoRASourceProtocol) -> None:
"""
Register a custom Python source for a URI scheme.
......
......@@ -11,7 +11,7 @@ import time
import uuid
from abc import ABC, abstractmethod
from queue import Queue
from typing import Any, List, Optional
from typing import Any, Awaitable, List, Optional
import msgpack
import torch
......@@ -79,7 +79,7 @@ class AbstractEmbeddingReceiver(ABC):
pass
@abstractmethod
def release_tensor(self, tensor_id: int):
def release_tensor(self, tensor_id: int) -> None:
"""
Abstract method to indicate that the tensor associated with the ID is no longer in use.
Args:
......@@ -96,7 +96,7 @@ class AbstractEmbeddingSender(ABC):
@abstractmethod
async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]:
) -> tuple[TransferRequest, Awaitable[None]]:
"""
Abstract method to send precomputed embeddings for a given request ID.
......@@ -105,7 +105,7 @@ class AbstractEmbeddingSender(ABC):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
"""
pass
......@@ -145,7 +145,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
@_nvtx.annotate("mm:local:send_embeddings", color="magenta")
async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]:
) -> tuple[TransferRequest, Awaitable[None]]:
"""
Send precomputed embeddings for a given request ID.
......@@ -154,7 +154,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
"""
# Implementation to send embeddings to the downstream worker
# This could involve publishing to a message queue or making an API call
......@@ -209,7 +209,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
self.received_tensors[tensor_id] = tensor_path
return tensor_id, embedding_tensor
def release_tensor(self, tensor_id: int):
def release_tensor(self, tensor_id: int) -> None:
"""
Indicate that the tensor associated with the ID is no longer in use.
......@@ -400,7 +400,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender):
# Background transfer task..
# Create a queue hinting whether the sender is expecting future transfer
self.transfer_queue = asyncio.Queue()
self.transfer_queue: asyncio.Queue[str] = asyncio.Queue()
self._state_update_task = asyncio.create_task(self._state_update())
self.transfer_timeout = 60 # seconds, can be tuned based on expected transfer time and network condition
......@@ -571,7 +571,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
"""
tensor_id = self.id_counter.get_next_id()
fut = asyncio.get_event_loop().create_future()
......@@ -754,7 +754,7 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver):
self.to_buffer_id[tensor_id] = buffer_id
return tensor_id, embedding_tensor
def release_tensor(self, tensor_id: int):
def release_tensor(self, tensor_id: int) -> None:
"""
Indicate that the tensor associated with the ID is no longer in use.
......@@ -789,7 +789,7 @@ def remote_release_overwrite(self) -> None:
pass
nixl_connect.Remote._release = remote_release_overwrite
nixl_connect.Remote._release = remote_release_overwrite # type: ignore[method-assign]
class NixlReadEmbeddingSender(AbstractEmbeddingSender):
......@@ -809,7 +809,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
@_nvtx.annotate("mm:nixl:send_embeddings", color="magenta")
async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]:
) -> tuple[TransferRequest, Awaitable[None]]:
"""
Send precomputed embeddings.
......@@ -819,7 +819,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
if False, the sender will copy the embeddings.
Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed.
A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
"""
if stage_embeddings:
transfer_buf = embeddings
......@@ -851,15 +851,18 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
"""
def __init__(
self, embedding_hidden_size=8 * 1024, max_item_mm_token=1024, max_items=1024
):
self,
embedding_hidden_size: int = 8 * 1024,
max_item_mm_token: int = 1024,
max_items: int = 1024,
) -> None:
super().__init__()
self.connector = PersistentConnector()
self.tensor_id_counter = 0
self.aggregated_op_create_time = 0
self.aggregated_op_wait_time = 0
self.warmedup_descriptors = Queue()
self.inuse_descriptors = {}
self.warmedup_descriptors: Queue[nixl_connect.Descriptor] = Queue()
self.inuse_descriptors: dict[int, tuple[nixl_connect.Descriptor, bool]] = {}
# Handle both sync and async contexts
try:
asyncio.get_running_loop() # Check if we're in async context
......@@ -917,6 +920,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
original_descriptor_size = descriptor._data_size
tensor_size_bytes = embeddings_dtype.itemsize * math.prod(embeddings_shape)
descriptor._data_size = tensor_size_bytes
assert descriptor._data_ref is not None
encodings_tensor = (
descriptor._data_ref[:tensor_size_bytes]
.view(dtype=embeddings_dtype)
......@@ -940,7 +944,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
self.inuse_descriptors[tensor_id] = (descriptor, dynamic_descriptor)
return tensor_id, encodings_tensor
def release_tensor(self, tensor_id: int):
def release_tensor(self, tensor_id: int) -> None:
"""
Indicate that the tensor associated with the ID is no longer in use.
......
......@@ -36,7 +36,7 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType:
if not types:
raise ValueError("No valid endpoint types provided")
result = None
result: ModelType | None = None
for t in types:
if t == "chat":
flag = ModelType.Chat
......@@ -49,4 +49,6 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType:
result = flag if result is None else result | flag
# `types` is validated as non-empty above, so result is guaranteed to be set.
assert result is not None
return result
......@@ -5,7 +5,9 @@ import asyncio
import logging
import os
import signal
from typing import Iterable, Optional
from typing import Any, Iterable, Optional
from dynamo._core import DistributedRuntime
logger = logging.getLogger(__name__)
......@@ -62,7 +64,7 @@ async def _unregister_endpoints(endpoints: Iterable) -> None:
async def graceful_shutdown_with_discovery(
runtime,
runtime: DistributedRuntime,
endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None,
......@@ -90,7 +92,7 @@ async def graceful_shutdown_with_discovery(
def install_signal_handlers(
loop: asyncio.AbstractEventLoop,
runtime,
runtime: Any,
endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
class InputParamManager:
def __init__(self, tokenizer):
def __init__(self, tokenizer: Any) -> None:
self.tokenizer = tokenizer
def get_input_param(self, request: dict, use_tokenizer: bool):
def get_input_param(self, request: dict, use_tokenizer: bool) -> Optional[Any]:
"""
Get the input parameter for the request.
"""
......
......@@ -19,7 +19,7 @@ async def read_decoded_media_via_nixl(
connector: nixl_connect.Connector,
decoded_meta: Dict[str, Any],
return_metadata: bool = False,
) -> np.ndarray | Tuple[np.ndarray, Dict[str, Any]]:
) -> np.ndarray | Tuple[np.ndarray, Dict[str, Any] | None]:
"""
Read pre-decoded media data via NIXL RDMA transfer, into a CPU numpy array.
......
......@@ -185,7 +185,7 @@ def _compile_include_pattern(metric_prefixes: tuple[str, ...]) -> Pattern:
def get_prometheus_expfmt(
registry,
registry: "CollectorRegistry",
metric_prefix_filters: Optional[list[str]] = None,
exclude_prefixes: Optional[list[str]] = None,
inject_custom_labels: Optional[dict[str, str]] = None,
......@@ -310,7 +310,12 @@ class LLMBackendMetrics:
metrics.set_model_load_time(5.2)
"""
def __init__(self, registry=None, model_name: str = "", component_name: str = ""):
def __init__(
self,
registry: Optional["CollectorRegistry"] = None,
model_name: str = "",
component_name: str = "",
) -> None:
"""Create all Dynamo component gauges."""
from prometheus_client import Gauge
......
......@@ -62,7 +62,7 @@ def compute_num_frames(
return default_num_frames
def normalize_video_frames(images) -> list:
def normalize_video_frames(images: list) -> list:
"""Normalize stage_output.images into a frame list for export_to_video.
Args:
......@@ -140,7 +140,7 @@ def encode_to_mp4(
import imageio.v3 as iio
except ImportError:
try:
import imageio as iio
import imageio as iio # type: ignore[no-redef]
except ImportError:
raise ImportError(
"imageio is required for video encoding. "
......@@ -160,7 +160,7 @@ def encode_to_mp4(
iio.imwrite(output_path, frames, fps=fps, codec="libx264")
else:
# Fall back to v2 API
writer = iio.get_writer(output_path, fps=fps, codec="libx264")
writer = iio.get_writer(output_path, fps=fps, codec="libx264") # type: ignore[attr-defined]
try:
for frame in frames:
writer.append_data(frame)
......@@ -197,7 +197,7 @@ def encode_to_mp4_bytes(
import imageio.v3 as iio
except ImportError:
try:
import imageio as iio
import imageio as iio # type: ignore[no-redef]
except ImportError:
raise ImportError(
"imageio is required for video encoding. "
......@@ -216,7 +216,7 @@ def encode_to_mp4_bytes(
iio.imwrite(buffer, frames, extension=".mp4", fps=fps, codec="libx264")
else:
# v2 API
writer = iio.get_writer(
writer = iio.get_writer( # type: ignore[attr-defined]
buffer, format="FFMPEG", mode="I", fps=fps, codec="libx264"
)
try:
......
......@@ -8,6 +8,7 @@ import signal
from collections import defaultdict
from typing import Any, Awaitable, Callable, DefaultDict
from dynamo._core import DistributedRuntime
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
SignalCallback = Callable[..., Any]
......@@ -15,8 +16,8 @@ SignalCallback = Callable[..., Any]
def install_graceful_shutdown(
loop: asyncio.AbstractEventLoop,
runtime: Any,
endpoints: list,
runtime: DistributedRuntime,
endpoints: list[str],
shutdown_event: asyncio.Event,
*,
signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT),
......
......@@ -24,6 +24,14 @@ def log_message(level: str, message: str, module: str, file: str, line: int) ->
"""
...
def get_tool_parser_names() -> list[str]:
    """Return the names of the available tool parsers.

    Returns:
        A list of tool parser names, one string per parser.
    """
    ...
def get_reasoning_parser_names() -> list[str]:
    """Return the names of the available reasoning parsers.

    Returns:
        A list of reasoning parser names, one string per parser.
    """
    ...
class JsonLike:
"""
Any PyObject which can be serialized to JSON
......@@ -938,6 +946,9 @@ class ModelType:
Audios: ModelType
Videos: ModelType
def __or__(self, other: "ModelType") -> "ModelType":
    """Implement ``self | other`` for two ``ModelType`` values.

    NOTE(review): presumably the result supports the endpoint types of both
    operands (flag-style union) — confirm against the native implementation.
    """
    ...
def supports_chat(self) -> bool:
    """Return True if this model type supports chat, False otherwise."""
    ...
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment