Unverified commit c01f6e52, authored by Wentao Ye, committed by GitHub
Browse files

[CI] Fix mypy for `vllm/v1/core` and `vllm/v1/engine` (#27108)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent c7d2a554
...@@ -36,12 +36,15 @@ FILES = [ ...@@ -36,12 +36,15 @@ FILES = [
"vllm/transformers_utils", "vllm/transformers_utils",
"vllm/triton_utils", "vllm/triton_utils",
"vllm/usage", "vllm/usage",
"vllm/v1/core",
"vllm/v1/engine",
] ]
# After fixing errors resulting from changing follow_imports # After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES # from "skip" to "silent", move the following directories to FILES
SEPARATE_GROUPS = [ SEPARATE_GROUPS = [
"tests", "tests",
# v0 related
"vllm/attention", "vllm/attention",
"vllm/compilation", "vllm/compilation",
"vllm/engine", "vllm/engine",
...@@ -50,7 +53,16 @@ SEPARATE_GROUPS = [ ...@@ -50,7 +53,16 @@ SEPARATE_GROUPS = [
"vllm/model_executor", "vllm/model_executor",
"vllm/plugins", "vllm/plugins",
"vllm/worker", "vllm/worker",
"vllm/v1", # v1 related
"vllm/v1/attention",
"vllm/v1/executor",
"vllm/v1/kv_offload",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/spec_decode",
"vllm/v1/structured_output",
"vllm/v1/worker",
] ]
# TODO(woosuk): Include the code from Megatron and HuggingFace. # TODO(woosuk): Include the code from Megatron and HuggingFace.
......
...@@ -84,7 +84,9 @@ class VllmConfig: ...@@ -84,7 +84,9 @@ class VllmConfig:
default_factory=StructuredOutputsConfig default_factory=StructuredOutputsConfig
) )
"""Structured outputs configuration.""" """Structured outputs configuration."""
observability_config: ObservabilityConfig | None = None observability_config: ObservabilityConfig = Field(
default_factory=ObservabilityConfig
)
"""Observability configuration.""" """Observability configuration."""
quant_config: QuantizationConfig | None = None quant_config: QuantizationConfig | None = None
"""Quantization configuration.""" """Quantization configuration."""
...@@ -170,10 +172,7 @@ class VllmConfig: ...@@ -170,10 +172,7 @@ class VllmConfig:
vllm_factors.append(self.structured_outputs_config.compute_hash()) vllm_factors.append(self.structured_outputs_config.compute_hash())
else: else:
vllm_factors.append("None") vllm_factors.append("None")
if self.observability_config:
vllm_factors.append(self.observability_config.compute_hash()) vllm_factors.append(self.observability_config.compute_hash())
else:
vllm_factors.append("None")
if self.quant_config: if self.quant_config:
pass # should be captured by model_config.quantization pass # should be captured by model_config.quantization
if self.compilation_config: if self.compilation_config:
......
...@@ -77,6 +77,7 @@ class EngineClient(ABC): ...@@ -77,6 +77,7 @@ class EngineClient(ABC):
lora_request: LoRARequest | None = None, lora_request: LoRARequest | None = None,
trace_headers: Mapping[str, str] | None = None, trace_headers: Mapping[str, str] | None = None,
priority: int = 0, priority: int = 0,
truncate_prompt_tokens: int | None = None,
tokenization_kwargs: dict[str, Any] | None = None, tokenization_kwargs: dict[str, Any] | None = None,
) -> AsyncGenerator[PoolingRequestOutput, None]: ) -> AsyncGenerator[PoolingRequestOutput, None]:
"""Generate outputs for a request from a pooling model.""" """Generate outputs for a request from a pooling model."""
......
...@@ -167,7 +167,7 @@ class Scheduler(SchedulerInterface): ...@@ -167,7 +167,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
enable_caching=self.cache_config.enable_prefix_caching, enable_caching=bool(self.cache_config.enable_prefix_caching),
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
...@@ -407,13 +407,13 @@ class Scheduler(SchedulerInterface): ...@@ -407,13 +407,13 @@ class Scheduler(SchedulerInterface):
# Get externally-cached tokens if using a KVConnector. # Get externally-cached tokens if using a KVConnector.
if self.connector is not None: if self.connector is not None:
num_external_computed_tokens, load_kv_async = ( ext_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens( self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens request, num_new_local_computed_tokens
) )
) )
if num_external_computed_tokens is None: if ext_tokens is None:
# The request cannot be scheduled because # The request cannot be scheduled because
# the KVConnector couldn't determine # the KVConnector couldn't determine
# the number of matched tokens. # the number of matched tokens.
...@@ -421,6 +421,8 @@ class Scheduler(SchedulerInterface): ...@@ -421,6 +421,8 @@ class Scheduler(SchedulerInterface):
skipped_waiting_requests.prepend_request(request) skipped_waiting_requests.prepend_request(request)
continue continue
num_external_computed_tokens = ext_tokens
# Total computed tokens (local + external). # Total computed tokens (local + external).
num_computed_tokens = ( num_computed_tokens = (
num_new_local_computed_tokens + num_external_computed_tokens num_new_local_computed_tokens + num_external_computed_tokens
...@@ -905,13 +907,13 @@ class Scheduler(SchedulerInterface): ...@@ -905,13 +907,13 @@ class Scheduler(SchedulerInterface):
outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list) outputs: dict[int, list[EngineCoreOutput]] = defaultdict(list)
spec_decoding_stats: SpecDecodingStats | None = None spec_decoding_stats: SpecDecodingStats | None = None
kv_connector_stats = ( kv_connector_stats: KVConnectorStats | None = (
kv_connector_output.kv_connector_stats if kv_connector_output else None kv_connector_output.kv_connector_stats if kv_connector_output else None
) )
if kv_connector_stats and self.connector: if kv_connector_stats and self.connector:
stats = self.connector.get_kv_connector_stats() kv_stats = self.connector.get_kv_connector_stats()
if stats: if kv_stats:
kv_connector_stats = kv_connector_stats.aggregate(stats) kv_connector_stats = kv_connector_stats.aggregate(kv_stats)
failed_kv_load_req_ids = None failed_kv_load_req_ids = None
if kv_connector_output and kv_connector_output.invalid_block_ids: if kv_connector_output and kv_connector_output.invalid_block_ids:
......
...@@ -6,7 +6,7 @@ import socket ...@@ -6,7 +6,7 @@ import socket
import time import time
from collections.abc import AsyncGenerator, Iterable, Mapping from collections.abc import AsyncGenerator, Iterable, Mapping
from copy import copy from copy import copy
from typing import Any from typing import Any, cast
import numpy as np import numpy as np
import torch import torch
...@@ -131,10 +131,9 @@ class AsyncLLM(EngineClient): ...@@ -131,10 +131,9 @@ class AsyncLLM(EngineClient):
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats self.tokenizer, log_stats=self.log_stats
) )
if self.observability_config.otlp_traces_endpoint is not None: endpoint = self.observability_config.otlp_traces_endpoint
tracer = init_tracer( if endpoint is not None:
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint tracer = init_tracer("vllm.llm_engine", endpoint)
)
self.output_processor.tracer = tracer self.output_processor.tracer = tracer
# EngineCore (starts the engine in background process). # EngineCore (starts the engine in background process).
...@@ -266,7 +265,9 @@ class AsyncLLM(EngineClient): ...@@ -266,7 +265,9 @@ class AsyncLLM(EngineClient):
if engine_core := getattr(self, "engine_core", None): if engine_core := getattr(self, "engine_core", None):
engine_core.shutdown() engine_core.shutdown()
cancel_task_threadsafe(getattr(self, "output_handler", None)) handler = getattr(self, "output_handler", None)
if handler is not None:
cancel_task_threadsafe(handler)
async def get_supported_tasks(self) -> tuple[SupportedTask, ...]: async def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return await self.engine_core.get_supported_tasks_async() return await self.engine_core.get_supported_tasks_async()
...@@ -314,7 +315,10 @@ class AsyncLLM(EngineClient): ...@@ -314,7 +315,10 @@ class AsyncLLM(EngineClient):
priority, priority,
data_parallel_rank, data_parallel_rank,
) )
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
if is_pooling or params.n == 1: if is_pooling or params.n == 1:
await self._add_request(request, prompt_text, None, 0, queue) await self._add_request(request, prompt_text, None, 0, queue)
...@@ -436,6 +440,7 @@ class AsyncLLM(EngineClient): ...@@ -436,6 +440,7 @@ class AsyncLLM(EngineClient):
# Note: both OutputProcessor and EngineCore handle their # Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished. # own request cleanup based on finished.
finished = out.finished finished = out.finished
assert isinstance(out, RequestOutput)
yield out yield out
# If the request is disconnected by the client, generate() # If the request is disconnected by the client, generate()
...@@ -653,7 +658,7 @@ class AsyncLLM(EngineClient): ...@@ -653,7 +658,7 @@ class AsyncLLM(EngineClient):
return self.tokenizer return self.tokenizer
async def is_tracing_enabled(self) -> bool: async def is_tracing_enabled(self) -> bool:
return self.observability_config.otlp_traces_endpoint is not None return self.observability_config.otlp_traces_endpoint is not None # type: ignore
async def do_log_stats(self) -> None: async def do_log_stats(self) -> None:
if self.logger_manager: if self.logger_manager:
......
...@@ -1075,6 +1075,7 @@ class DPEngineCoreProc(EngineCoreProc): ...@@ -1075,6 +1075,7 @@ class DPEngineCoreProc(EngineCoreProc):
local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
assert dp_size > 1 assert dp_size > 1
assert local_dp_rank is not None
assert 0 <= local_dp_rank <= dp_rank < dp_size assert 0 <= local_dp_rank <= dp_rank < dp_size
if vllm_config.kv_transfer_config is not None: if vllm_config.kv_transfer_config is not None:
......
...@@ -385,9 +385,10 @@ class BackgroundResources: ...@@ -385,9 +385,10 @@ class BackgroundResources:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
task.cancel() task.cancel()
if loop is not None:
if in_loop(loop): if in_loop(loop):
close_sockets_and_tasks() close_sockets_and_tasks()
elif loop and not loop.is_closed(): elif not loop.is_closed():
loop.call_soon_threadsafe(close_sockets_and_tasks) loop.call_soon_threadsafe(close_sockets_and_tasks)
else: else:
# Loop has been closed, try to clean up directly. # Loop has been closed, try to clean up directly.
...@@ -1044,6 +1045,7 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1044,6 +1045,7 @@ class DPAsyncMPClient(AsyncMPClient):
return return
assert self.stats_update_address is not None assert self.stats_update_address is not None
stats_addr: str = self.stats_update_address
assert len(self.engine_ranks_managed) > 0 assert len(self.engine_ranks_managed) > 0
# NOTE: running and waiting counts are all global from # NOTE: running and waiting counts are all global from
# the Coordinator include all global EngineCores. This # the Coordinator include all global EngineCores. This
...@@ -1054,9 +1056,7 @@ class DPAsyncMPClient(AsyncMPClient): ...@@ -1054,9 +1056,7 @@ class DPAsyncMPClient(AsyncMPClient):
async def run_engine_stats_update_task(): async def run_engine_stats_update_task():
with ( with (
make_zmq_socket( make_zmq_socket(self.ctx, stats_addr, zmq.XSUB, linger=0) as socket,
self.ctx, self.stats_update_address, zmq.XSUB, linger=0
) as socket,
make_zmq_socket( make_zmq_socket(
self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0 self.ctx, self.first_req_sock_addr, zmq.PAIR, bind=False, linger=0
) as first_req_rcv_socket, ) as first_req_rcv_socket,
......
...@@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC): ...@@ -69,14 +69,21 @@ class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
# Stop strings # Stop strings
params = request.sampling_params params = request.sampling_params
assert params is not None assert params is not None
self.stop = stop = params.stop stop_list: list[str]
if params.stop is None:
stop_list = []
elif isinstance(params.stop, str):
stop_list = [params.stop]
else:
stop_list = params.stop
self.stop = stop_list
self.min_tokens = params.min_tokens self.min_tokens = params.min_tokens
self.include_stop_str_in_output = params.include_stop_str_in_output self.include_stop_str_in_output = params.include_stop_str_in_output
# Number of chars to hold back when stop strings are to be excluded # Number of chars to hold back when stop strings are to be excluded
# from streamed output. # from streamed output.
if stop and not self.include_stop_str_in_output: if self.stop and not self.include_stop_str_in_output:
self.stop_buffer_length = max(len(s) for s in stop) - 1 self.stop_buffer_length = max(len(s) for s in self.stop) - 1
else: else:
self.stop_buffer_length = 0 self.stop_buffer_length = 0
self._last_output_text_offset: int = 0 self._last_output_text_offset: int = 0
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import time import time
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from copy import copy from copy import copy
from typing import Any from typing import Any, cast
import torch.nn as nn import torch.nn as nn
from typing_extensions import TypeVar from typing_extensions import TypeVar
...@@ -112,10 +112,9 @@ class LLMEngine: ...@@ -112,10 +112,9 @@ class LLMEngine:
self.output_processor = OutputProcessor( self.output_processor = OutputProcessor(
self.tokenizer, log_stats=self.log_stats self.tokenizer, log_stats=self.log_stats
) )
if self.observability_config.otlp_traces_endpoint is not None: endpoint = self.observability_config.otlp_traces_endpoint
tracer = init_tracer( if endpoint is not None:
"vllm.llm_engine", self.observability_config.otlp_traces_endpoint tracer = init_tracer("vllm.llm_engine", endpoint)
)
self.output_processor.tracer = tracer self.output_processor.tracer = tracer
# EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs) # EngineCore (gets EngineCoreRequests and gives EngineCoreOutputs)
...@@ -259,7 +258,10 @@ class LLMEngine: ...@@ -259,7 +258,10 @@ class LLMEngine:
trace_headers, trace_headers,
priority, priority,
) )
prompt_text = prompt if isinstance(prompt, str) else prompt.get("prompt") if isinstance(prompt, str):
prompt_text = prompt
elif isinstance(prompt, Mapping):
prompt_text = cast(str | None, prompt.get("prompt"))
n = params.n if isinstance(params, SamplingParams) else 1 n = params.n if isinstance(params, SamplingParams) else 1
...@@ -285,7 +287,7 @@ class LLMEngine: ...@@ -285,7 +287,7 @@ class LLMEngine:
# Add the request to EngineCore. # Add the request to EngineCore.
self.engine_core.add_request(child_request) self.engine_core.add_request(child_request)
def step(self) -> list[RequestOutput] | list[PoolingRequestOutput]: def step(self) -> list[RequestOutput | PoolingRequestOutput]:
if self.should_execute_dummy_batch: if self.should_execute_dummy_batch:
self.should_execute_dummy_batch = False self.should_execute_dummy_batch = False
self.engine_core.execute_dummy_batch() self.engine_core.execute_dummy_batch()
......
...@@ -44,10 +44,16 @@ class RequestOutputCollector: ...@@ -44,10 +44,16 @@ class RequestOutputCollector:
if self.output is None or isinstance(output, Exception): if self.output is None or isinstance(output, Exception):
self.output = output self.output = output
self.ready.set() self.ready.set()
elif isinstance(self.output, (RequestOutput, PoolingRequestOutput)): elif isinstance(self.output, RequestOutput) and isinstance(
output, RequestOutput
):
# This ensures that request outputs with different request indexes # This ensures that request outputs with different request indexes
# (if n > 1) do not override each other. # (if n > 1) do not override each other.
self.output.add(output, aggregate=self.aggregate) self.output.add(output, aggregate=self.aggregate)
elif isinstance(self.output, PoolingRequestOutput) and isinstance(
output, PoolingRequestOutput
):
self.output = output
async def get(self) -> RequestOutput | PoolingRequestOutput: async def get(self) -> RequestOutput | PoolingRequestOutput:
"""Get operation blocks on put event.""" """Get operation blocks on put event."""
...@@ -408,7 +414,7 @@ class OutputProcessor: ...@@ -408,7 +414,7 @@ class OutputProcessor:
within the loop below. within the loop below.
""" """
request_outputs: list[RequestOutput] | list[PoolingRequestOutput] = [] request_outputs: list[RequestOutput | PoolingRequestOutput] = []
reqs_to_abort: list[str] = [] reqs_to_abort: list[str] = []
for engine_core_output in engine_core_outputs: for engine_core_output in engine_core_outputs:
req_id = engine_core_output.request_id req_id = engine_core_output.request_id
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from copy import copy from copy import copy
from typing import Optional from typing import Optional, cast
from vllm.outputs import CompletionOutput from vllm.outputs import CompletionOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams from vllm.sampling_params import RequestOutputKind, SamplingParams
...@@ -37,7 +37,7 @@ class ParentRequest: ...@@ -37,7 +37,7 @@ class ParentRequest:
self.child_requests = set() self.child_requests = set()
self.output_aggregator = ( self.output_aggregator = (
[None] * sampling_params.n [cast(CompletionOutput, None)] * sampling_params.n
if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY) if (sampling_params.output_kind == RequestOutputKind.FINAL_ONLY)
else [] else []
) )
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import time import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal from typing import Any, Literal, cast
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
...@@ -208,9 +208,9 @@ class Processor: ...@@ -208,9 +208,9 @@ class Processor:
enc = prompt.get("encoder_prompt") enc = prompt.get("encoder_prompt")
dec = prompt.get("decoder_prompt") dec = prompt.get("decoder_prompt")
if enc is not None: if enc is not None:
_validate_single_prompt(enc) _validate_single_prompt(cast(dict | str, enc))
if dec is not None: if dec is not None:
_validate_single_prompt(dec) _validate_single_prompt(cast(dict | str, dec))
else: else:
_validate_single_prompt(prompt) # type: ignore[arg-type] _validate_single_prompt(prompt) # type: ignore[arg-type]
...@@ -332,7 +332,7 @@ class Processor: ...@@ -332,7 +332,7 @@ class Processor:
if not mm_data: if not mm_data:
return None return None
mm_uuids: MultiModalUUIDDict = {} mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items(): for modality, data in mm_data.items():
n = len(data) if isinstance(data, list) else 1 n = len(data) if isinstance(data, list) else 1
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)] mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
...@@ -384,7 +384,9 @@ class Processor: ...@@ -384,7 +384,9 @@ class Processor:
# if provided. # if provided.
self._validate_multi_modal_uuids(prompt) self._validate_multi_modal_uuids(prompt)
if isinstance(prompt, dict): if isinstance(prompt, dict):
mm_uuids = prompt.get("multi_modal_uuids") mm_uuids = cast(
MultiModalUUIDDict | None, prompt.get("multi_modal_uuids")
)
else: else:
mm_uuids = None mm_uuids = None
...@@ -410,20 +412,13 @@ class Processor: ...@@ -410,20 +412,13 @@ class Processor:
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs) encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._validate_model_inputs(encoder_inputs, decoder_inputs) self._validate_model_inputs(encoder_inputs, decoder_inputs)
# Mypy does not always properly infer the types of some elements of # Mypy can be conservative for TypedDict unions; normalize access.
# discriminated unions of TypedDicts, because of how it handles if decoder_inputs["type"] == "embeds":
# inheritance of TypedDict. If we explicitly extract the items we want prompt_token_ids = None
# we can avoid type errors from using `dict.get` later in the method. prompt_embeds = decoder_inputs["prompt_embeds"]
prompt_token_ids = ( else:
decoder_inputs["prompt_token_ids"] prompt_token_ids = decoder_inputs["prompt_token_ids"]
if decoder_inputs["type"] != "embeds" prompt_embeds = None
else None
)
prompt_embeds = (
decoder_inputs["prompt_embeds"]
if decoder_inputs["type"] == "embeds"
else None
)
sampling_params = None sampling_params = None
pooling_params = None pooling_params = None
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment