"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "189860102539b54098cfa04b6381ee86c53a16c1"
Unverified Commit 2831bfec authored by jh-nv's avatar jh-nv Committed by GitHub
Browse files

chore: add mypy typing to vllm (#6858)

parent 0f01e724
...@@ -29,7 +29,7 @@ class VllmEngineMonitor: ...@@ -29,7 +29,7 @@ class VllmEngineMonitor:
self, self,
runtime: DistributedRuntime, runtime: DistributedRuntime,
engine_client: AsyncLLM, engine_client: AsyncLLM,
shutdown_event: asyncio.Event = None, shutdown_event: asyncio.Event | None = None,
): ):
if not isinstance(runtime, DistributedRuntime): if not isinstance(runtime, DistributedRuntime):
raise ValueError( raise ValueError(
......
...@@ -261,7 +261,7 @@ def build_sampling_params_openai( ...@@ -261,7 +261,7 @@ def build_sampling_params_openai(
return sampling_params return sampling_params
def get_dp_range_for_worker(vllm_config: VllmConfig) -> range: def get_dp_range_for_worker(vllm_config: VllmConfig) -> tuple[int, int]:
""" """
Get the global DP rank range that this worker is responsible for based on vLLM config. Get the global DP rank range that this worker is responsible for based on vLLM config.
Note that the 'vllm_config' is normalized so the load balancing flags are set properly. Note that the 'vllm_config' is normalized so the load balancing flags are set properly.
...@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC): ...@@ -318,7 +318,7 @@ class BaseWorkerHandler(ABC):
self.enable_multimodal = enable_multimodal self.enable_multimodal = enable_multimodal
self.enable_frontend_decoding = enable_frontend_decoding self.enable_frontend_decoding = enable_frontend_decoding
# NIXL connector for frontend decoding - lazy initialized # NIXL connector for frontend decoding - lazy initialized
self._nixl_connector = None self._nixl_connector: nixl_connect.Connector | None = None
self._nixl_connector_lock = asyncio.Lock() self._nixl_connector_lock = asyncio.Lock()
# LoRA tracking: name -> LoRAInfo(id, path) # LoRA tracking: name -> LoRAInfo(id, path)
self.loaded_loras: dict[str, LoRAInfo] = {} self.loaded_loras: dict[str, LoRAInfo] = {}
......
...@@ -7,7 +7,7 @@ import logging ...@@ -7,7 +7,7 @@ import logging
import os import os
import tempfile import tempfile
import time import time
from typing import Optional from typing import Any, Optional
import uvloop import uvloop
from prometheus_client import REGISTRY, CollectorRegistry, multiprocess from prometheus_client import REGISTRY, CollectorRegistry, multiprocess
...@@ -37,17 +37,6 @@ from dynamo.llm import ( ...@@ -37,17 +37,6 @@ from dynamo.llm import (
fetch_model, fetch_model,
register_model, register_model,
) )
# Optional imports for frontend decoding support
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
from dynamo.runtime import DistributedRuntime, Endpoint from dynamo.runtime import DistributedRuntime, Endpoint
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.vllm.worker_factory import WorkerFactory from dynamo.vllm.worker_factory import WorkerFactory
...@@ -63,6 +52,18 @@ from .health_check import ( ...@@ -63,6 +52,18 @@ from .health_check import (
) )
from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory from .publisher import DYNAMO_COMPONENT_REGISTRY, StatLoggerFactory
# Optional imports for frontend decoding support
MediaDecoder: type | None = None
MediaFetcher: type | None = None
try:
from dynamo.llm import MediaDecoder, MediaFetcher
MEDIA_DECODER_AVAILABLE = True
except ImportError:
MediaDecoder = None
MediaFetcher = None
MEDIA_DECODER_AVAILABLE = False
configure_dynamo_logging() configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
shutdown_endpoints: list = [] shutdown_endpoints: list = []
...@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None: ...@@ -93,7 +94,7 @@ def run_dynamo_headless(config: Config) -> None:
run_headless(args) run_headless(args)
async def worker(): async def worker() -> None:
config = parse_args() config = parse_args()
dump_config(config.dump_config_to, config) dump_config(config.dump_config_to, config)
...@@ -198,7 +199,9 @@ async def worker(): ...@@ -198,7 +199,9 @@ async def worker():
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
def setup_metrics_collection(config: Config, generate_endpoint, logger): def setup_metrics_collection(
config: Config, generate_endpoint: Endpoint, logger: logging.Logger
) -> None:
"""Set up metrics collection for vLLM and LMCache metrics. """Set up metrics collection for vLLM and LMCache metrics.
In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored: In multiprocess mode (PROMETHEUS_MULTIPROC_DIR set), metrics are stored:
...@@ -294,8 +297,9 @@ def setup_kv_event_publisher( ...@@ -294,8 +297,9 @@ def setup_kv_event_publisher(
vllm_config: VllmConfig, vllm_config: VllmConfig,
consolidator_enabled: bool = False, consolidator_enabled: bool = False,
consolidator_port: Optional[int] = 5558, consolidator_port: Optional[int] = 5558,
) -> Optional[KvEventPublisher]: ) -> Optional[list[KvEventPublisher]]:
""" """
list[KvEventPublisher] | None
Set up KV event publishers for prefix caching if enabled. Set up KV event publishers for prefix caching if enabled.
Creates one publisher per dp_rank since each dp_rank publishes to a different port. Creates one publisher per dp_rank since each dp_rank publishes to a different port.
Args: Args:
...@@ -365,7 +369,9 @@ def setup_kv_event_publisher( ...@@ -365,7 +369,9 @@ def setup_kv_event_publisher(
return kv_publishers if kv_publishers else None return kv_publishers if kv_publishers else None
def setup_vllm_engine(config, stat_logger=None): def setup_vllm_engine(
config: Config, stat_logger: Optional[StatLoggerFactory] = None
) -> tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]:
# vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object # vLLM v0.11.0 bug: vllm/v1.metrics/prometheus.py:79 passes TemporaryDirectory object
# instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR # instead of .name string, causing false error on exit. Set PROMETHEUS_MULTIPROC_DIR
# ourselves to avoid this and handle cleanup properly. # ourselves to avoid this and handle cleanup properly.
...@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -511,11 +517,11 @@ def setup_vllm_engine(config, stat_logger=None):
async def register_vllm_model( async def register_vllm_model(
model_input: ModelInput, model_input: ModelInput,
model_type: ModelType, model_type: ModelType,
generate_endpoint, generate_endpoint: Endpoint,
config: Config, config: Config,
engine_client: AsyncLLM, engine_client: AsyncLLM,
vllm_config: VllmConfig, vllm_config: VllmConfig,
): ) -> None:
""" """
Helper function to register a vLLM model with runtime configuration. Helper function to register a vLLM model with runtime configuration.
...@@ -563,6 +569,7 @@ async def register_vllm_model( ...@@ -563,6 +569,7 @@ async def register_vllm_model(
"--frontend-decoding requires MediaDecoder support. " "--frontend-decoding requires MediaDecoder support. "
"Ensure dynamo.llm module includes MediaDecoder and MediaFetcher." "Ensure dynamo.llm module includes MediaDecoder and MediaFetcher."
) )
assert MediaDecoder is not None and MediaFetcher is not None
media_decoder = MediaDecoder() media_decoder = MediaDecoder()
media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}}) media_decoder.enable_image({"limits": {"max_alloc": 128 * 1024 * 1024}})
# media_decoder.enable_video({}) # media_decoder.enable_video({})
...@@ -590,8 +597,10 @@ async def init_prefill( ...@@ -590,8 +597,10 @@ async def init_prefill(
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
checkpoint_restore_engine=None, checkpoint_restore_engine: Optional[
): tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
""" """
Instantiate and serve Instantiate and serve
""" """
...@@ -690,7 +699,7 @@ async def init_prefill( ...@@ -690,7 +699,7 @@ async def init_prefill(
# (long-term reason): prefill engine should pull from a global queue so there is # (long-term reason): prefill engine should pull from a global queue so there is
# only a few in-flight requests that can be quickly finished # only a few in-flight requests that can be quickly finished
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, handler.generate, # type: ignore
graceful_shutdown=True, graceful_shutdown=True,
# In practice config.served_model_name is always set, but mypy needs the "or" here. # In practice config.served_model_name is always set, but mypy needs the "or" here.
metrics_labels=[ metrics_labels=[
...@@ -706,10 +715,16 @@ async def init_prefill( ...@@ -706,10 +715,16 @@ async def init_prefill(
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
), ),
clear_endpoint.serve_endpoint( clear_endpoint.serve_endpoint(
handler.clear_kv_blocks, handler.clear_kv_blocks, # type: ignore
metrics_labels=[ metrics_labels=[
(prometheus_names.labels.MODEL, config.served_model_name), (
(prometheus_names.labels.MODEL_NAME, config.served_model_name), prometheus_names.labels.MODEL,
config.served_model_name or config.model,
),
(
prometheus_names.labels.MODEL_NAME,
config.served_model_name or config.model,
),
], ],
), ),
) )
...@@ -726,8 +741,10 @@ async def init( ...@@ -726,8 +741,10 @@ async def init(
runtime: DistributedRuntime, runtime: DistributedRuntime,
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
checkpoint_restore_engine=None, checkpoint_restore_engine: Optional[
): tuple[AsyncLLM, VllmConfig, Any, Any, LLMBackendMetrics]
] = None,
) -> None:
""" """
Instantiate and serve Instantiate and serve
""" """
...@@ -886,7 +903,7 @@ async def init( ...@@ -886,7 +903,7 @@ async def init(
# for decode, we want to transfer the in-flight requests to other decode engines, # for decode, we want to transfer the in-flight requests to other decode engines,
# because waiting them to finish can take a long time for long OSLs # because waiting them to finish can take a long time for long OSLs
generate_endpoint.serve_endpoint( generate_endpoint.serve_endpoint(
handler.generate, handler.generate, # type: ignore
graceful_shutdown=True, graceful_shutdown=True,
metrics_labels=model_metrics_labels, metrics_labels=model_metrics_labels,
health_check_payload=health_check_payload, health_check_payload=health_check_payload,
...@@ -926,7 +943,7 @@ async def init( ...@@ -926,7 +943,7 @@ async def init(
handler.cleanup() handler.cleanup()
def get_engine_cache_info(engine: AsyncLLM): def get_engine_cache_info(engine: AsyncLLM) -> dict[str, Any]:
"""Retrieve cache configuration information from [`AsyncLLM`] engine.""" """Retrieve cache configuration information from [`AsyncLLM`] engine."""
try: try:
...@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM): ...@@ -956,7 +973,7 @@ def get_engine_cache_info(engine: AsyncLLM):
async def init_omni( async def init_omni(
runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event runtime: DistributedRuntime, config: Config, shutdown_event: asyncio.Event
): ) -> None:
"""Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni. """Initialize Omni worker for multi-stage pipeline generation using vLLM-Omni.
Supports text-to-text, text-to-image, and text-to-video generation Supports text-to-text, text-to-image, and text-to-video generation
...@@ -1034,7 +1051,7 @@ async def init_omni( ...@@ -1034,7 +1051,7 @@ async def init_omni(
handler.cleanup() handler.cleanup()
def main(): def main() -> None:
uvloop.run(worker()) uvloop.run(worker())
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import AsyncIterator from typing import Any, AsyncIterator
import torch import torch
from transformers import AutoImageProcessor from transformers import AutoImageProcessor
...@@ -80,7 +80,7 @@ class EncodeWorkerHandler: ...@@ -80,7 +80,7 @@ class EncodeWorkerHandler:
self._connector: connect.Connector | None = None self._connector: connect.Connector | None = None
self._accumulated_time = 0.0 self._accumulated_time = 0.0
self._processed_requests = 0 self._processed_requests = 0
self.readables = [] self.readables: list[Any] = []
self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None self.embedding_cache = EmbeddingCache() if ENABLE_ENCODER_CACHE else None
if embedding_transfer_mode == EmbeddingTransferMode.LOCAL: if embedding_transfer_mode == EmbeddingTransferMode.LOCAL:
self.embedding_sender = LocalEmbeddingSender() self.embedding_sender = LocalEmbeddingSender()
...@@ -93,7 +93,7 @@ class EncodeWorkerHandler: ...@@ -93,7 +93,7 @@ class EncodeWorkerHandler:
f"Invalid embedding transfer mode: {embedding_transfer_mode}" f"Invalid embedding transfer mode: {embedding_transfer_mode}"
) )
self.send_complete_queue = asyncio.Queue() self.send_complete_queue: asyncio.Queue[tuple[Any, Any]] = asyncio.Queue()
self.send_complete_checker_task = asyncio.create_task( self.send_complete_checker_task = asyncio.create_task(
self.check_complete(self.send_complete_queue) self.check_complete(self.send_complete_queue)
) )
...@@ -150,7 +150,9 @@ class EncodeWorkerHandler: ...@@ -150,7 +150,9 @@ class EncodeWorkerHandler:
with _nvtx.annotate("mm:enc:cache_check", color="cyan"): with _nvtx.annotate("mm:enc:cache_check", color="cyan"):
# Before batch process images, check cache first # Before batch process images, check cache first
need_encode_indexes = [] need_encode_indexes = []
embedding_lists = [None] * len(request.multimodal_inputs) embedding_lists: list[EmbeddingItem | None] = [None] * len(
request.multimodal_inputs
)
for idx in range(len(request.multimodal_inputs)): for idx in range(len(request.multimodal_inputs)):
if not request.multimodal_inputs[idx].multimodal_input.image_url: if not request.multimodal_inputs[idx].multimodal_input.image_url:
raise ValueError("image_url is required for the encode worker.") raise ValueError("image_url is required for the encode worker.")
...@@ -251,16 +253,16 @@ class EncodeWorkerHandler: ...@@ -251,16 +253,16 @@ class EncodeWorkerHandler:
for split_idx, (list_idx, key) in enumerate(need_encode_indexes): for split_idx, (list_idx, key) in enumerate(need_encode_indexes):
embedding_lists[list_idx] = EmbeddingItem( embedding_lists[list_idx] = EmbeddingItem(
key, key,
[image_grid_thw[split_idx]] if image_grid_thw else None, [image_grid_thw[split_idx]] if image_grid_thw else [],
splitted_embeddings[split_idx].unsqueeze(0), splitted_embeddings[split_idx].unsqueeze(0),
) )
# Cache the computed value for future use # Cache the computed value for future use
if self.embedding_cache is not None: if self.embedding_cache is not None:
self.embedding_cache.set( self.embedding_cache.set(
embedding_lists[list_idx].key, embedding_lists[list_idx].key, # type: ignore
( (
embedding_lists[list_idx].image_grid_thw, embedding_lists[list_idx].image_grid_thw, # type: ignore
embedding_lists[list_idx].embeddings, embedding_lists[list_idx].embeddings, # type: ignore
), ),
) )
...@@ -275,6 +277,7 @@ class EncodeWorkerHandler: ...@@ -275,6 +277,7 @@ class EncodeWorkerHandler:
) )
) )
for embedding_item in embedding_lists for embedding_item in embedding_lists
if embedding_item is not None
] ]
transfer_requests = await asyncio.gather(*send_tasks) transfer_requests = await asyncio.gather(*send_tasks)
...@@ -282,6 +285,7 @@ class EncodeWorkerHandler: ...@@ -282,6 +285,7 @@ class EncodeWorkerHandler:
for idx, item in enumerate(zip(embedding_lists, transfer_requests)): for idx, item in enumerate(zip(embedding_lists, transfer_requests)):
embedding_item, transfer_request = item embedding_item, transfer_request = item
assert embedding_item is not None
logger.debug( logger.debug(
f"{embedding_item.embeddings.shape} prepared for transfer." f"{embedding_item.embeddings.shape} prepared for transfer."
) )
......
...@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler): ...@@ -371,10 +371,10 @@ class MultimodalPDWorkerHandler(BaseWorkerHandler):
num_output_tokens_so_far = 0 num_output_tokens_so_far = 0
async for ( async for (
decode_response decode_response
) in await self.decode_worker_client.round_robin( # type: ignore[union-attr] ) in await self.decode_worker_client.round_robin( # type: ignore
request.model_dump_json() request.model_dump_json()
): ):
output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore[attr-defined] output = MyRequestOutput.model_validate_json(decode_response.data()) # type: ignore
yield self._format_engine_output(output, num_output_tokens_so_far) yield self._format_engine_output(output, num_output_tokens_so_far)
if output.outputs: if output.outputs:
num_output_tokens_so_far = len(output.outputs[0].token_ids) num_output_tokens_so_far = len(output.outputs[0].token_ids)
......
...@@ -64,8 +64,8 @@ def get_qwen_image_features( ...@@ -64,8 +64,8 @@ def get_qwen_image_features(
if grid_thw is None: if grid_thw is None:
raise ValueError("grid_thw is not provided") raise ValueError("grid_thw is not provided")
grid_thw = grid_thw.tolist() grid_thw = grid_thw.tolist()
image_embeds = vision_encoder(pixel_values, grid_thw=grid_thw) image_features = vision_encoder(pixel_values, grid_thw=grid_thw)
return image_embeds return image_features
pixel_values = image_embeds["pixel_values"].to(vision_encoder.device) pixel_values = image_embeds["pixel_values"].to(vision_encoder.device)
......
...@@ -257,8 +257,10 @@ async def _fetch_embeddings( ...@@ -257,8 +257,10 @@ async def _fetch_embeddings(
) )
# ── 3. Update cache (no-op when cache is None) ────────────── # ── 3. Update cache (no-op when cache is None) ──────────────
for (idx, _url, key), group in zip(to_fetch, groups, strict=True): for (idx, _url, key), group in zip(to_fetch, groups, strict=True):
if cache is not None and key is not None: if cache is not None and key is not None:
assert group.loaded_embedding is not None
cache.set( cache.set(
key, key,
CachedEmbedding( CachedEmbedding(
...@@ -301,6 +303,7 @@ async def load_multimodal_embeddings( ...@@ -301,6 +303,7 @@ async def load_multimodal_embeddings(
multi_modal_data: Dict[str, Any] = defaultdict(list) multi_modal_data: Dict[str, Any] = defaultdict(list)
for group in groups: for group in groups:
assert group.loaded_embedding is not None
_accumulate_embeddings( _accumulate_embeddings(
multi_modal_data, multi_modal_data,
model, model,
......
...@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler): ...@@ -132,7 +132,7 @@ class BaseOmniHandler(BaseWorkerHandler):
request_id = context.id() request_id = context.id()
logger.debug(f"Omni Request ID: {request_id}") logger.debug(f"Omni Request ID: {request_id}")
async for chunk in self._generate_openai_mode(request, context, request_id): async for chunk in self._generate_openai_mode(request, context, request_id): # type: ignore
yield chunk yield chunk
async def _generate_openai_mode( async def _generate_openai_mode(
......
...@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler): ...@@ -413,6 +413,8 @@ class OmniHandler(BaseOmniHandler):
output = NvImagesResponse(created=int(time.time()), data=image_data_list) output = NvImagesResponse(created=int(time.time()), data=image_data_list)
return output.model_dump() return output.model_dump()
else:
return None
async def _format_video_chunk( async def _format_video_chunk(
self, self,
......
...@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -46,7 +46,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
raise raise
# TODO: Remove this and pass as metadata through shared storage # TODO: Remove this and pass as metadata through shared storage
def set_num_gpu_block(self, num_blocks): def set_num_gpu_block(self, num_blocks: int) -> None:
self.num_gpu_block = num_blocks self.num_gpu_block = num_blocks
def record( def record(
...@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -54,9 +54,9 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
scheduler_stats: SchedulerStats, scheduler_stats: SchedulerStats,
iteration_stats: Optional[IterationStats], iteration_stats: Optional[IterationStats],
engine_idx: int = 0, engine_idx: int = 0,
*args, *args: object,
**kwargs, **kwargs: object,
): ) -> None:
active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage) active_decode_blocks = int(self.num_gpu_block * scheduler_stats.kv_cache_usage)
self.inner.publish(self.dp_rank, active_decode_blocks) self.inner.publish(self.dp_rank, active_decode_blocks)
...@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase): ...@@ -71,7 +71,7 @@ class DynamoStatLoggerPublisher(StatLoggerBase):
dp_rank_str, scheduler_stats.kv_cache_usage dp_rank_str, scheduler_stats.kv_cache_usage
) )
def init_publish(self): def init_publish(self) -> None:
self.inner.publish(self.dp_rank, 0) self.inner.publish(self.dp_rank, 0)
dp_rank_str = str(self.dp_rank) dp_rank_str = str(self.dp_rank)
self.component_gauges.set_total_blocks(dp_rank_str, 0) self.component_gauges.set_total_blocks(dp_rank_str, 0)
...@@ -112,10 +112,10 @@ class StatLoggerFactory: ...@@ -112,10 +112,10 @@ class StatLoggerFactory:
return self.create_stat_logger(dp_rank=dp_rank) return self.create_stat_logger(dp_rank=dp_rank)
# TODO Remove once we publish metadata to shared storage # TODO Remove once we publish metadata to shared storage
def set_num_gpu_blocks_all(self, num_blocks): def set_num_gpu_blocks_all(self, num_blocks: int) -> None:
if self.created_logger: if self.created_logger:
self.created_logger.set_num_gpu_block(num_blocks) self.created_logger.set_num_gpu_block(num_blocks)
def init_publish(self): def init_publish(self) -> None:
if self.created_logger: if self.created_logger:
self.created_logger.init_publish() self.created_logger.init_publish()
...@@ -180,7 +180,7 @@ class TestLoadMultimodalData: ...@@ -180,7 +180,7 @@ class TestLoadMultimodalData:
mock_client = MagicMock() mock_client = MagicMock()
handler = _make_handler(encode_worker_client=mock_client) handler = _make_handler(encode_worker_client=mock_client)
fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) fake_mm_data = defaultdict(list, {"image": torch.randn(1, 10)}) # type: ignore
with patch.object( with patch.object(
mod, mod,
"load_multimodal_embeddings", "load_multimodal_embeddings",
......
...@@ -29,7 +29,7 @@ def mock_handler(): ...@@ -29,7 +29,7 @@ def mock_handler():
pass pass
handler = MockHandler() handler = MockHandler()
handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( handler._decode_prompt_embeds = BaseWorkerHandler._decode_prompt_embeds.__get__( # type: ignore
handler handler
) )
return handler return handler
......
...@@ -981,7 +981,7 @@ class ModelType: ...@@ -981,7 +981,7 @@ class ModelType:
Audios: ModelType Audios: ModelType
Videos: ModelType Videos: ModelType
def __or__(self, other: "ModelType") -> "ModelType": def __or__(self, other: ModelType) -> ModelType:
... ...
def supports_chat(self) -> bool: def supports_chat(self) -> bool:
...@@ -1091,6 +1091,8 @@ async def register_model( ...@@ -1091,6 +1091,8 @@ async def register_model(
runtime_config: Optional[ModelRuntimeConfig] = None, runtime_config: Optional[ModelRuntimeConfig] = None,
user_data: Optional[Dict[str, Any]] = None, user_data: Optional[Dict[str, Any]] = None,
custom_template_path: Optional[str] = None, custom_template_path: Optional[str] = None,
media_decoder: Optional[MediaDecoder] = None,
media_fetcher: Optional[MediaFetcher] = None,
lora_name: Optional[str] = None, lora_name: Optional[str] = None,
base_model_path: Optional[str] = None, base_model_path: Optional[str] = None,
) -> None: ) -> None:
...@@ -1649,6 +1651,8 @@ class PlannerDecision: ...@@ -1649,6 +1651,8 @@ class PlannerDecision:
-1 in any of those fields mean not set, usually because planner hasn't decided anything yet. -1 in any of those fields mean not set, usually because planner hasn't decided anything yet.
Call VirtualConnectorClient.complete(event) when action is completed. Call VirtualConnectorClient.complete(event) when action is completed.
""" """
num_prefill_workers: int
num_decode_workers: int
... ...
class VirtualConnectorCoordinator: class VirtualConnectorCoordinator:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment