Unverified commit 869e7336, authored by jh-nv and committed by GitHub
Browse files

chore: add mypy typing for components/src/dynamo/common (#6721)

parent 9780bf3a
...@@ -156,7 +156,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) -> ...@@ -156,7 +156,7 @@ def get_config_dump(config: Any, extra_info: Optional[Dict[str, Any]] = None) ->
return canonical_json_encoder.encode(error_info) return canonical_json_encoder.encode(error_info)
def add_config_dump_args(parser: argparse.ArgumentParser): def add_config_dump_args(parser: argparse.ArgumentParser) -> None:
""" """
Add arguments to the parser to dump the config to a file. Add arguments to the parser to dump the config to a file.
...@@ -206,7 +206,7 @@ def _preprocess_for_encode(obj: object) -> object: ...@@ -206,7 +206,7 @@ def _preprocess_for_encode(obj: object) -> object:
return str(obj) return str(obj)
def register_encoder(type_class): def register_encoder(type_class: type) -> Any:
""" """
Decorator to register custom encoders for specific types. Decorator to register custom encoders for specific types.
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Base ArgGroup interface.""" """Base ArgGroup interface."""
import argparse
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
...@@ -13,7 +14,7 @@ class ArgGroup(ABC): ...@@ -13,7 +14,7 @@ class ArgGroup(ABC):
""" """
@abstractmethod @abstractmethod
def add_arguments(self, parser) -> None: def add_arguments(self, parser: argparse.ArgumentParser) -> None:
""" """
Register CLI arguments owned by this group. Register CLI arguments owned by this group.
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import argparse import argparse
from typing import Self
class ConfigBase: class ConfigBase:
"""Base configuration class that allows properties with and without defaults in arbitrary order.""" """Base configuration class that allows properties with and without defaults in arbitrary order."""
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace) -> Self:
obj = cls.__new__(cls) obj = cls.__new__(cls)
# 1) Set everything provided by argparse # 1) Set everything provided by argparse
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
"""Dynamo runtime configuration ArgGroup.""" """Dynamo runtime configuration ArgGroup."""
import argparse
from typing import List, Optional from typing import List, Optional
from dynamo._core import get_reasoning_parser_names, get_tool_parser_names from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
...@@ -63,7 +64,7 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -63,7 +64,7 @@ class DynamoRuntimeConfig(ConfigBase):
class DynamoRuntimeArgGroup(ArgGroup): class DynamoRuntimeArgGroup(ArgGroup):
"""Dynamo runtime configuration parameters (common to all backends).""" """Dynamo runtime configuration parameters (common to all backends)."""
def add_arguments(self, parser) -> None: def add_arguments(self, parser: argparse.ArgumentParser) -> None:
"""Add Dynamo runtime arguments to parser.""" """Add Dynamo runtime arguments to parser."""
g = parser.add_argument_group("Dynamo Runtime Options") g = parser.add_argument_group("Dynamo Runtime Options")
......
...@@ -55,7 +55,7 @@ def env_or_default( ...@@ -55,7 +55,7 @@ def env_or_default(
def add_argument( def add_argument(
parser, parser: argparse.ArgumentParser | argparse._ArgumentGroup,
*, *,
flag_name: str, flag_name: str,
env_var: str, env_var: str,
...@@ -109,7 +109,7 @@ def add_argument( ...@@ -109,7 +109,7 @@ def add_argument(
def add_negatable_bool_argument( def add_negatable_bool_argument(
parser, parser: Any,
*, *,
flag_name: str, flag_name: str,
env_var: str, env_var: str,
......
...@@ -49,7 +49,7 @@ class LoRAManager: ...@@ -49,7 +49,7 @@ class LoRAManager:
# Extension point: custom sources # Extension point: custom sources
self._custom_sources: Dict[str, LoRASourceProtocol] = {} self._custom_sources: Dict[str, LoRASourceProtocol] = {}
def register_custom_source(self, scheme: str, source: LoRASourceProtocol): def register_custom_source(self, scheme: str, source: LoRASourceProtocol) -> None:
""" """
Register a custom Python source for a URI scheme. Register a custom Python source for a URI scheme.
......
...@@ -11,7 +11,7 @@ import time ...@@ -11,7 +11,7 @@ import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from queue import Queue from queue import Queue
from typing import Any, List, Optional from typing import Any, Awaitable, List, Optional
import msgpack import msgpack
import torch import torch
...@@ -79,7 +79,7 @@ class AbstractEmbeddingReceiver(ABC): ...@@ -79,7 +79,7 @@ class AbstractEmbeddingReceiver(ABC):
pass pass
@abstractmethod @abstractmethod
def release_tensor(self, tensor_id: int): def release_tensor(self, tensor_id: int) -> None:
""" """
Abstract method to indicate that the tensor associated with the ID is no longer in use. Abstract method to indicate that the tensor associated with the ID is no longer in use.
Args: Args:
...@@ -96,7 +96,7 @@ class AbstractEmbeddingSender(ABC): ...@@ -96,7 +96,7 @@ class AbstractEmbeddingSender(ABC):
@abstractmethod @abstractmethod
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, Awaitable[None]]:
""" """
Abstract method to send precomputed embeddings for a given request ID. Abstract method to send precomputed embeddings for a given request ID.
...@@ -105,7 +105,7 @@ class AbstractEmbeddingSender(ABC): ...@@ -105,7 +105,7 @@ class AbstractEmbeddingSender(ABC):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer, stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed. if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
""" """
pass pass
...@@ -145,7 +145,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender): ...@@ -145,7 +145,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
@_nvtx.annotate("mm:local:send_embeddings", color="magenta") @_nvtx.annotate("mm:local:send_embeddings", color="magenta")
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, Awaitable[None]]:
""" """
Send precomputed embeddings for a given request ID. Send precomputed embeddings for a given request ID.
...@@ -154,7 +154,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender): ...@@ -154,7 +154,7 @@ class LocalEmbeddingSender(AbstractEmbeddingSender):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer, stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed. if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
""" """
# Implementation to send embeddings to the downstream worker # Implementation to send embeddings to the downstream worker
# This could involve publishing to a message queue or making an API call # This could involve publishing to a message queue or making an API call
...@@ -209,7 +209,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -209,7 +209,7 @@ class LocalEmbeddingReceiver(AbstractEmbeddingReceiver):
self.received_tensors[tensor_id] = tensor_path self.received_tensors[tensor_id] = tensor_path
return tensor_id, embedding_tensor return tensor_id, embedding_tensor
def release_tensor(self, tensor_id: int): def release_tensor(self, tensor_id: int) -> None:
""" """
Indicate that the tensor associated with the ID is no longer in use. Indicate that the tensor associated with the ID is no longer in use.
...@@ -400,7 +400,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender): ...@@ -400,7 +400,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender):
# Background transfer task.. # Background transfer task..
# Create a queue hinting whether the sender is expecting future transfer # Create a queue hinting whether the sender is expecting future transfer
self.transfer_queue = asyncio.Queue() self.transfer_queue: asyncio.Queue[str] = asyncio.Queue()
self._state_update_task = asyncio.create_task(self._state_update()) self._state_update_task = asyncio.create_task(self._state_update())
self.transfer_timeout = 60 # seconds, can be tuned based on expected transfer time and network condition self.transfer_timeout = 60 # seconds, can be tuned based on expected transfer time and network condition
...@@ -571,7 +571,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender): ...@@ -571,7 +571,7 @@ class NixlWriteEmbeddingSender(AbstractEmbeddingSender):
stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer, stage_embeddings: A boolean indicating whether the embeddings should be staged for the transfer,
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed. if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
""" """
tensor_id = self.id_counter.get_next_id() tensor_id = self.id_counter.get_next_id()
fut = asyncio.get_event_loop().create_future() fut = asyncio.get_event_loop().create_future()
...@@ -754,7 +754,7 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -754,7 +754,7 @@ class NixlWriteEmbeddingReceiver(AbstractEmbeddingReceiver):
self.to_buffer_id[tensor_id] = buffer_id self.to_buffer_id[tensor_id] = buffer_id
return tensor_id, embedding_tensor return tensor_id, embedding_tensor
def release_tensor(self, tensor_id: int): def release_tensor(self, tensor_id: int) -> None:
""" """
Indicate that the tensor associated with the ID is no longer in use. Indicate that the tensor associated with the ID is no longer in use.
...@@ -789,7 +789,7 @@ def remote_release_overwrite(self) -> None: ...@@ -789,7 +789,7 @@ def remote_release_overwrite(self) -> None:
pass pass
nixl_connect.Remote._release = remote_release_overwrite nixl_connect.Remote._release = remote_release_overwrite # type: ignore[method-assign]
class NixlReadEmbeddingSender(AbstractEmbeddingSender): class NixlReadEmbeddingSender(AbstractEmbeddingSender):
...@@ -809,7 +809,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender): ...@@ -809,7 +809,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
@_nvtx.annotate("mm:nixl:send_embeddings", color="magenta") @_nvtx.annotate("mm:nixl:send_embeddings", color="magenta")
async def send_embeddings( async def send_embeddings(
self, embeddings: torch.Tensor, stage_embeddings: bool = False self, embeddings: torch.Tensor, stage_embeddings: bool = False
) -> tuple[TransferRequest, asyncio.Future]: ) -> tuple[TransferRequest, Awaitable[None]]:
""" """
Send precomputed embeddings. Send precomputed embeddings.
...@@ -819,7 +819,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender): ...@@ -819,7 +819,7 @@ class NixlReadEmbeddingSender(AbstractEmbeddingSender):
if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed. if True, the embeddings may be used as transfer buffer and must not be released until the return future is completed.
if False, the sender will copy the embeddings. if False, the sender will copy the embeddings.
Returns: Returns:
A tuple containing the TransferRequest object and a future that can be awaited to indicate the send is completed. A tuple containing the TransferRequest object and an awaitable that can be awaited to indicate the send is completed.
""" """
if stage_embeddings: if stage_embeddings:
transfer_buf = embeddings transfer_buf = embeddings
...@@ -851,15 +851,18 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -851,15 +851,18 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
""" """
def __init__( def __init__(
self, embedding_hidden_size=8 * 1024, max_item_mm_token=1024, max_items=1024 self,
): embedding_hidden_size: int = 8 * 1024,
max_item_mm_token: int = 1024,
max_items: int = 1024,
) -> None:
super().__init__() super().__init__()
self.connector = PersistentConnector() self.connector = PersistentConnector()
self.tensor_id_counter = 0 self.tensor_id_counter = 0
self.aggregated_op_create_time = 0 self.aggregated_op_create_time = 0
self.aggregated_op_wait_time = 0 self.aggregated_op_wait_time = 0
self.warmedup_descriptors = Queue() self.warmedup_descriptors: Queue[nixl_connect.Descriptor] = Queue()
self.inuse_descriptors = {} self.inuse_descriptors: dict[int, tuple[nixl_connect.Descriptor, bool]] = {}
# Handle both sync and async contexts # Handle both sync and async contexts
try: try:
asyncio.get_running_loop() # Check if we're in async context asyncio.get_running_loop() # Check if we're in async context
...@@ -917,6 +920,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -917,6 +920,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
original_descriptor_size = descriptor._data_size original_descriptor_size = descriptor._data_size
tensor_size_bytes = embeddings_dtype.itemsize * math.prod(embeddings_shape) tensor_size_bytes = embeddings_dtype.itemsize * math.prod(embeddings_shape)
descriptor._data_size = tensor_size_bytes descriptor._data_size = tensor_size_bytes
assert descriptor._data_ref is not None
encodings_tensor = ( encodings_tensor = (
descriptor._data_ref[:tensor_size_bytes] descriptor._data_ref[:tensor_size_bytes]
.view(dtype=embeddings_dtype) .view(dtype=embeddings_dtype)
...@@ -940,7 +944,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver): ...@@ -940,7 +944,7 @@ class NixlReadEmbeddingReceiver(AbstractEmbeddingReceiver):
self.inuse_descriptors[tensor_id] = (descriptor, dynamic_descriptor) self.inuse_descriptors[tensor_id] = (descriptor, dynamic_descriptor)
return tensor_id, encodings_tensor return tensor_id, encodings_tensor
def release_tensor(self, tensor_id: int): def release_tensor(self, tensor_id: int) -> None:
""" """
Indicate that the tensor associated with the ID is no longer in use. Indicate that the tensor associated with the ID is no longer in use.
......
...@@ -36,7 +36,7 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType: ...@@ -36,7 +36,7 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType:
if not types: if not types:
raise ValueError("No valid endpoint types provided") raise ValueError("No valid endpoint types provided")
result = None result: ModelType | None = None
for t in types: for t in types:
if t == "chat": if t == "chat":
flag = ModelType.Chat flag = ModelType.Chat
...@@ -49,4 +49,6 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType: ...@@ -49,4 +49,6 @@ def parse_endpoint_types(endpoint_types_str: str) -> ModelType:
result = flag if result is None else result | flag result = flag if result is None else result | flag
# `types` is validated as non-empty above, so result is guaranteed to be set.
assert result is not None
return result return result
...@@ -5,7 +5,9 @@ import asyncio ...@@ -5,7 +5,9 @@ import asyncio
import logging import logging
import os import os
import signal import signal
from typing import Iterable, Optional from typing import Any, Iterable, Optional
from dynamo._core import DistributedRuntime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -62,7 +64,7 @@ async def _unregister_endpoints(endpoints: Iterable) -> None: ...@@ -62,7 +64,7 @@ async def _unregister_endpoints(endpoints: Iterable) -> None:
async def graceful_shutdown_with_discovery( async def graceful_shutdown_with_discovery(
runtime, runtime: DistributedRuntime,
endpoints: Iterable, endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None, grace_period_s: Optional[float] = None,
...@@ -90,7 +92,7 @@ async def graceful_shutdown_with_discovery( ...@@ -90,7 +92,7 @@ async def graceful_shutdown_with_discovery(
def install_signal_handlers( def install_signal_handlers(
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
runtime, runtime: Any,
endpoints: Iterable, endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None, grace_period_s: Optional[float] = None,
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
class InputParamManager: class InputParamManager:
def __init__(self, tokenizer): def __init__(self, tokenizer: Any) -> None:
self.tokenizer = tokenizer self.tokenizer = tokenizer
def get_input_param(self, request: dict, use_tokenizer: bool): def get_input_param(self, request: dict, use_tokenizer: bool) -> Optional[Any]:
""" """
Get the input parameter for the request. Get the input parameter for the request.
""" """
......
...@@ -19,7 +19,7 @@ async def read_decoded_media_via_nixl( ...@@ -19,7 +19,7 @@ async def read_decoded_media_via_nixl(
connector: nixl_connect.Connector, connector: nixl_connect.Connector,
decoded_meta: Dict[str, Any], decoded_meta: Dict[str, Any],
return_metadata: bool = False, return_metadata: bool = False,
) -> np.ndarray | Tuple[np.ndarray, Dict[str, Any]]: ) -> np.ndarray | Tuple[np.ndarray, Dict[str, Any] | None]:
""" """
Read pre-decoded media data via NIXL RDMA transfer, into a CPU numpy array. Read pre-decoded media data via NIXL RDMA transfer, into a CPU numpy array.
......
...@@ -185,7 +185,7 @@ def _compile_include_pattern(metric_prefixes: tuple[str, ...]) -> Pattern: ...@@ -185,7 +185,7 @@ def _compile_include_pattern(metric_prefixes: tuple[str, ...]) -> Pattern:
def get_prometheus_expfmt( def get_prometheus_expfmt(
registry, registry: "CollectorRegistry",
metric_prefix_filters: Optional[list[str]] = None, metric_prefix_filters: Optional[list[str]] = None,
exclude_prefixes: Optional[list[str]] = None, exclude_prefixes: Optional[list[str]] = None,
inject_custom_labels: Optional[dict[str, str]] = None, inject_custom_labels: Optional[dict[str, str]] = None,
...@@ -310,7 +310,12 @@ class LLMBackendMetrics: ...@@ -310,7 +310,12 @@ class LLMBackendMetrics:
metrics.set_model_load_time(5.2) metrics.set_model_load_time(5.2)
""" """
def __init__(self, registry=None, model_name: str = "", component_name: str = ""): def __init__(
self,
registry: Optional["CollectorRegistry"] = None,
model_name: str = "",
component_name: str = "",
) -> None:
"""Create all Dynamo component gauges.""" """Create all Dynamo component gauges."""
from prometheus_client import Gauge from prometheus_client import Gauge
......
...@@ -62,7 +62,7 @@ def compute_num_frames( ...@@ -62,7 +62,7 @@ def compute_num_frames(
return default_num_frames return default_num_frames
def normalize_video_frames(images) -> list: def normalize_video_frames(images: list) -> list:
"""Normalize stage_output.images into a frame list for export_to_video. """Normalize stage_output.images into a frame list for export_to_video.
Args: Args:
...@@ -140,7 +140,7 @@ def encode_to_mp4( ...@@ -140,7 +140,7 @@ def encode_to_mp4(
import imageio.v3 as iio import imageio.v3 as iio
except ImportError: except ImportError:
try: try:
import imageio as iio import imageio as iio # type: ignore[no-redef]
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"imageio is required for video encoding. " "imageio is required for video encoding. "
...@@ -160,7 +160,7 @@ def encode_to_mp4( ...@@ -160,7 +160,7 @@ def encode_to_mp4(
iio.imwrite(output_path, frames, fps=fps, codec="libx264") iio.imwrite(output_path, frames, fps=fps, codec="libx264")
else: else:
# Fall back to v2 API # Fall back to v2 API
writer = iio.get_writer(output_path, fps=fps, codec="libx264") writer = iio.get_writer(output_path, fps=fps, codec="libx264") # type: ignore[attr-defined]
try: try:
for frame in frames: for frame in frames:
writer.append_data(frame) writer.append_data(frame)
...@@ -197,7 +197,7 @@ def encode_to_mp4_bytes( ...@@ -197,7 +197,7 @@ def encode_to_mp4_bytes(
import imageio.v3 as iio import imageio.v3 as iio
except ImportError: except ImportError:
try: try:
import imageio as iio import imageio as iio # type: ignore[no-redef]
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"imageio is required for video encoding. " "imageio is required for video encoding. "
...@@ -216,7 +216,7 @@ def encode_to_mp4_bytes( ...@@ -216,7 +216,7 @@ def encode_to_mp4_bytes(
iio.imwrite(buffer, frames, extension=".mp4", fps=fps, codec="libx264") iio.imwrite(buffer, frames, extension=".mp4", fps=fps, codec="libx264")
else: else:
# v2 API # v2 API
writer = iio.get_writer( writer = iio.get_writer( # type: ignore[attr-defined]
buffer, format="FFMPEG", mode="I", fps=fps, codec="libx264" buffer, format="FFMPEG", mode="I", fps=fps, codec="libx264"
) )
try: try:
......
...@@ -8,6 +8,7 @@ import signal ...@@ -8,6 +8,7 @@ import signal
from collections import defaultdict from collections import defaultdict
from typing import Any, Awaitable, Callable, DefaultDict from typing import Any, Awaitable, Callable, DefaultDict
from dynamo._core import DistributedRuntime
from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery from dynamo.common.utils.graceful_shutdown import graceful_shutdown_with_discovery
SignalCallback = Callable[..., Any] SignalCallback = Callable[..., Any]
...@@ -15,8 +16,8 @@ SignalCallback = Callable[..., Any] ...@@ -15,8 +16,8 @@ SignalCallback = Callable[..., Any]
def install_graceful_shutdown( def install_graceful_shutdown(
loop: asyncio.AbstractEventLoop, loop: asyncio.AbstractEventLoop,
runtime: Any, runtime: DistributedRuntime,
endpoints: list, endpoints: list[str],
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
*, *,
signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT), signals: tuple[int, ...] = (signal.SIGTERM, signal.SIGINT),
......
...@@ -24,6 +24,14 @@ def log_message(level: str, message: str, module: str, file: str, line: int) -> ...@@ -24,6 +24,14 @@ def log_message(level: str, message: str, module: str, file: str, line: int) ->
""" """
... ...
def get_tool_parser_names() -> list[str]:
"""Get list of available tool parser names."""
...
def get_reasoning_parser_names() -> list[str]:
"""Get list of available reasoning parser names."""
...
class JsonLike: class JsonLike:
""" """
Any PyObject which can be serialized to JSON Any PyObject which can be serialized to JSON
...@@ -938,6 +946,9 @@ class ModelType: ...@@ -938,6 +946,9 @@ class ModelType:
Audios: ModelType Audios: ModelType
Videos: ModelType Videos: ModelType
def __or__(self, other: "ModelType") -> "ModelType":
...
def supports_chat(self) -> bool: def supports_chat(self) -> bool:
"""Return True if this model type supports chat.""" """Return True if this model type supports chat."""
... ...
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment