Unverified commit 1c0c6820, authored by Harry Mellor, committed by GitHub
Browse files

Fix per file ruff ignores related to typing (#26254)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 5f317530
......@@ -115,6 +115,7 @@ include = ["vllm*"]
"vllm/distributed/parallel_state.py" = ["SIM108"]
"vllm/entrypoints/chat_utils.py" = ["SIM108"]
"vllm/entrypoints/llm.py" = ["SIM108"]
"vllm/executor/ray_distributed_executor.py" = ["SIM108", "SIM112"]
"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"]
"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"]
"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"]
......@@ -134,23 +135,6 @@ include = ["vllm*"]
"tools/profiler/print_layerwise_table.py" = ["SIM118"]
## Loop variable binding issues
"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"]
## Type annotation modernization and other rules
"vllm/attention/backends/abstract.py" = ["UP035", "UP006"]
"vllm/attention/layer.py" = ["UP035", "UP006"]
"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"]
"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"]
"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"]
"vllm/engine/arg_utils.py" = ["UP035", "UP006"]
"vllm/engine/metrics.py" = ["UP035", "UP006"]
"vllm/engine/metrics_types.py" = ["UP035", "UP006"]
"vllm/executor/executor_base.py" = ["UP035", "UP006"]
"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"]
"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"]
"vllm/executor/ray_utils.py" = ["UP035", "UP006"]
"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"]
"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"]
## Type comparison issues
"vllm/multimodal/inputs.py" = ["E721"]
# End of temporary ignores
[tool.ruff.lint]
......
......@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
import tempfile
from typing import Any, Optional, Union
from typing import Any, Union
import pytest
import torch
......@@ -21,7 +21,7 @@ from vllm.utils import is_torch_equal_or_newer
from ..utils import create_new_process_for_each_test
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
def models_list(*, all: bool = True, keywords: list[str] | None = None):
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
("facebook/opt-125m", {}),
(
......
......@@ -6,7 +6,7 @@ from __future__ import annotations
import asyncio
from contextlib import suppress
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any
from unittest.mock import AsyncMock, MagicMock
import pytest
......@@ -233,9 +233,9 @@ class MockModelConfig:
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
logits_processor_pattern = None
diff_sampling_param: Optional[dict] = None
diff_sampling_param: dict | None = None
allowed_local_media_path: str = ""
allowed_media_domains: Optional[list[str]] = None
allowed_media_domains: list[str] | None = None
encoder_config = None
generation_config: str = "auto"
media_io_kwargs: dict[str, dict[str, Any]] = field(default_factory=dict)
......
......@@ -9,7 +9,7 @@ import os
import tempfile
import urllib.request
from collections.abc import Sequence
from typing import Any, Optional, Union
from typing import Any, Union
import albumentations
import numpy as np
......@@ -98,9 +98,9 @@ def _convert_np_uint8(float_image: torch.Tensor):
def read_geotiff(
file_path: Optional[str] = None,
path_type: Optional[str] = None,
file_data: Optional[bytes] = None,
file_path: str | None = None,
path_type: str | None = None,
file_data: bytes | None = None,
) -> tuple[torch.Tensor, dict, tuple[float, float] | None]:
"""Read all bands from *file_path* and return image + meta info.
......@@ -114,8 +114,8 @@ def read_geotiff(
if all([x is None for x in [file_path, path_type, file_data]]):
raise Exception("All input fields to read_geotiff are None")
write_to_file: Optional[bytes] = None
path: Optional[str] = None
write_to_file: bytes | None = None
path: str | None = None
if file_data is not None:
# with tempfile.NamedTemporaryFile() as tmpfile:
# tmpfile.write(file_data)
......@@ -162,9 +162,9 @@ def read_geotiff(
def load_image(
data: Union[list[str]],
path_type: str,
mean: Optional[list[float]] = None,
std: Optional[list[float]] = None,
indices: Optional[Union[list[int], None]] = None,
mean: list[float] | None = None,
std: list[float] | None = None,
indices: Union[list[int], None] | None = None,
):
"""Build an input example by loading images in *file_paths*.
......@@ -278,7 +278,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
def pre_process(
self,
prompt: IOProcessorInput,
request_id: Optional[str] = None,
request_id: str | None = None,
**kwargs,
) -> Union[PromptType, Sequence[PromptType]]:
image_data = dict(prompt)
......@@ -359,7 +359,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
def post_process(
self,
model_output: Sequence[PoolingRequestOutput],
request_id: Optional[str] = None,
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
pred_imgs_list = []
......
......@@ -3,7 +3,7 @@
from __future__ import annotations
import random
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
import pytest
......@@ -78,7 +78,7 @@ def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
def _get_test_sampling_params(
prompt_list: list[str],
seed: Optional[int] = 42,
seed: int | None = 42,
structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch."""
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Protocol, Tuple, Type, TypeVar
from typing import Generic, Optional, Protocol, TypeVar
import torch
......@@ -48,12 +48,12 @@ class AttentionBackend(ABC):
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
def get_impl_cls() -> type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
def get_metadata_cls() -> type["AttentionMetadata"]:
raise NotImplementedError
@classmethod
......@@ -73,11 +73,11 @@ class AttentionBackend(ABC):
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
def get_kv_cache_stride_order() -> tuple[int, ...]:
raise NotImplementedError
@classmethod
......@@ -147,7 +147,7 @@ class AttentionImpl(ABC, Generic[T]):
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
alibi_slopes: Optional[list[float]] = None,
sliding_window: Optional[int] = None,
kv_cache_dtype: str = "auto",
logits_soft_cap: Optional[float] = None,
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer."""
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch
import torch.nn as nn
......@@ -126,7 +126,7 @@ class Attention(nn.Module, AttentionLayerBase):
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
alibi_slopes: Optional[list[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
......@@ -586,7 +586,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
def maybe_save_kv_layer_to_connector(
layer_name: str,
kv_cache_layer: List[torch.Tensor],
kv_cache_layer: list[torch.Tensor],
):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from typing import ClassVar, List, Optional
from typing import ClassVar, Optional
import torch
......@@ -61,7 +61,7 @@ class ChunkedLocalAttention(Attention):
scale: float,
attention_chunk_size: int,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
alibi_slopes: Optional[list[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
kv_sharing_target_layer_name: Optional[str] = None,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/flash_mla/flash_mla_interface.py
from typing import Optional, Tuple
from typing import Optional
import torch
......@@ -31,7 +31,7 @@ else:
_flashmla_extension_C_AVAILABLE = False
def is_flashmla_supported() -> Tuple[bool, Optional[str]]:
def is_flashmla_supported() -> tuple[bool, Optional[str]]:
"""
Return: is_supported_flag, unsupported_reason (optional).
"""
......@@ -57,7 +57,7 @@ def get_mla_metadata(
num_heads_q: Optional[int] = None,
is_fp8_kvcache: bool = False,
topk: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- cache_seqlens: (batch_size), dtype torch.int32.
......@@ -101,7 +101,7 @@ def flash_mla_with_kvcache(
descale_k: Optional[torch.Tensor] = None,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Arguments:
- q: (batch_size, seq_len_q, num_heads_q, head_dim).
......@@ -183,7 +183,7 @@ def flash_mla_sparse_prefill(
indices: torch.Tensor,
sm_scale: float,
d_v: int = 512,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sparse attention prefill kernel
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional
import torch
......@@ -41,7 +41,7 @@ class PagedAttentionMetadata:
class PagedAttention:
@staticmethod
def get_supported_head_sizes() -> List[int]:
def get_supported_head_sizes() -> list[int]:
return [32, 64, 80, 96, 112, 120, 128, 192, 256]
@staticmethod
......@@ -51,7 +51,7 @@ class PagedAttention:
num_kv_heads: int,
head_size: int,
cache_dtype_str: str = "auto",
) -> Tuple[int, ...]:
) -> tuple[int, ...]:
return (2, num_blocks, block_size * num_kv_heads * head_size)
@staticmethod
......@@ -59,7 +59,7 @@ class PagedAttention:
kv_cache: torch.Tensor,
num_kv_heads: int,
head_size: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
......@@ -255,7 +255,7 @@ class PagedAttention:
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
kv_caches: list[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
key_caches = [kv_cache[0] for kv_cache in kv_caches]
......
......@@ -14,11 +14,8 @@ from typing import (
Annotated,
Any,
Callable,
Dict,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
cast,
......@@ -325,7 +322,7 @@ class EngineArgs:
"""Arguments for vLLM engine."""
model: str = ModelConfig.model
served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name
served_model_name: Optional[Union[str, list[str]]] = ModelConfig.served_model_name
tokenizer: Optional[str] = ModelConfig.tokenizer
hf_config_path: Optional[str] = ModelConfig.hf_config_path
runner: RunnerOption = ModelConfig.runner
......@@ -350,7 +347,7 @@ class EngineArgs:
# is intended for expert use only. The API may change without
# notice.
distributed_executor_backend: Optional[
Union[str, DistributedExecutorBackend, Type[ExecutorBase]]
Union[str, DistributedExecutorBackend, type[ExecutorBase]]
] = ParallelConfig.distributed_executor_backend
# number of P/D disaggregation (or other disaggregation) workers
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
......@@ -418,7 +415,7 @@ class EngineArgs:
media_io_kwargs: dict[str, dict[str, Any]] = get_field(
MultiModalConfig, "media_io_kwargs"
)
mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
mm_processor_kwargs: Optional[dict[str, Any]] = MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_processor_cache_type: Optional[MMCacheType] = (
......@@ -436,7 +433,7 @@ class EngineArgs:
enable_lora_bias: bool = LoRAConfig.bias_enabled
max_loras: int = LoRAConfig.max_loras
max_lora_rank: int = LoRAConfig.max_lora_rank
default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras
default_mm_loras: Optional[dict[str, str]] = LoRAConfig.default_mm_loras
fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras
lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype
......@@ -446,7 +443,7 @@ class EngineArgs:
num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns
ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
......@@ -467,7 +464,7 @@ class EngineArgs:
logits_processor_pattern: Optional[str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None
speculative_config: Optional[dict[str, Any]] = None
show_hidden_metrics_for_version: Optional[str] = (
ObservabilityConfig.show_hidden_metrics_for_version
......@@ -477,7 +474,7 @@ class EngineArgs:
ObservabilityConfig.collect_detailed_traces
)
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
scheduler_cls: Union[str, type[object]] = SchedulerConfig.scheduler_cls
pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
override_pooler_config: Optional[Union[dict, PoolerConfig]] = (
......
......@@ -3,7 +3,7 @@
import time
from collections import Counter as CollectionsCounter
from typing import Dict, List, Optional, Type, Union, cast
from typing import Optional, Union, cast
import numpy as np
import prometheus_client
......@@ -43,7 +43,7 @@ class Metrics:
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
# Unregister any existing vLLM collectors (for CI/CD)
self._unregister_vllm_metrics()
......@@ -304,7 +304,7 @@ class _RayGaugeWrapper:
self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
labelnames: Optional[list[str]] = None,
multiprocess_mode: str = "",
):
del multiprocess_mode
......@@ -330,7 +330,7 @@ class _RayCounterWrapper:
prometheus_client.Counter"""
def __init__(
self, name: str, documentation: str = "", labelnames: Optional[List[str]] = None
self, name: str, documentation: str = "", labelnames: Optional[list[str]] = None
):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._counter = ray_metrics.Counter(
......@@ -355,8 +355,8 @@ class _RayHistogramWrapper:
self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None,
labelnames: Optional[list[str]] = None,
buckets: Optional[list[float]] = None,
):
labelnames_tuple = tuple(labelnames) if labelnames else None
boundaries = buckets if buckets else []
......@@ -381,17 +381,17 @@ class RayMetrics(Metrics):
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_gauge_cls: Type[prometheus_client.Gauge] = cast(
Type[prometheus_client.Gauge], _RayGaugeWrapper
_gauge_cls: type[prometheus_client.Gauge] = cast(
type[prometheus_client.Gauge], _RayGaugeWrapper
)
_counter_cls: Type[prometheus_client.Counter] = cast(
Type[prometheus_client.Counter], _RayCounterWrapper
_counter_cls: type[prometheus_client.Counter] = cast(
type[prometheus_client.Counter], _RayCounterWrapper
)
_histogram_cls: Type[prometheus_client.Histogram] = cast(
Type[prometheus_client.Histogram], _RayHistogramWrapper
_histogram_cls: type[prometheus_client.Histogram] = cast(
type[prometheus_client.Histogram], _RayHistogramWrapper
)
def __init__(self, labelnames: List[str], vllm_config: VllmConfig):
def __init__(self, labelnames: list[str], vllm_config: VllmConfig):
if ray_metrics is None:
raise ImportError("RayMetrics requires Ray to be installed.")
super().__init__(labelnames, vllm_config)
......@@ -401,14 +401,14 @@ class RayMetrics(Metrics):
pass
def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
def build_buckets(mantissa_lst: list[int], max_value: int) -> list[int]:
"""
Builds a list of buckets with increasing powers of 10 multiplied by
mantissa values until the value exceeds the specified maximum.
"""
exponent = 0
buckets: List[int] = []
buckets: list[int] = []
while True:
for m in mantissa_lst:
value = m * 10**exponent
......@@ -419,7 +419,7 @@ def build_buckets(mantissa_lst: List[int], max_value: int) -> List[int]:
exponent += 1
def build_1_2_5_buckets(max_value: int) -> List[int]:
def build_1_2_5_buckets(max_value: int) -> list[int]:
"""
Example:
>>> build_1_2_5_buckets(100)
......@@ -428,7 +428,7 @@ def build_1_2_5_buckets(max_value: int) -> List[int]:
return build_buckets([1, 2, 5], max_value)
def build_1_2_3_5_8_buckets(max_value: int) -> List[int]:
def build_1_2_3_5_8_buckets(max_value: int) -> list[int]:
"""
Example:
>>> build_1_2_3_5_8_buckets(100)
......@@ -442,7 +442,7 @@ def local_interval_elapsed(now: float, last_log: float, local_interval: float) -
return elapsed_time > local_interval
def get_throughput(tracked_stats: List[int], now: float, last_log: float) -> float:
def get_throughput(tracked_stats: list[int], now: float, last_log: float) -> float:
return float(np.sum(tracked_stats) / (now - last_log))
......@@ -530,7 +530,7 @@ class PrometheusStatLogger(StatLoggerBase):
_gauge_cls = prometheus_client.Gauge
def __init__(
self, local_interval: float, labels: Dict[str, str], vllm_config: VllmConfig
self, local_interval: float, labels: dict[str, str], vllm_config: VllmConfig
) -> None:
super().__init__(local_interval, vllm_config)
# Prometheus metrics
......@@ -558,12 +558,12 @@ class PrometheusStatLogger(StatLoggerBase):
for label, count in data.items():
counter.labels(**{**self.labels, label_key: label}).inc(count)
def _log_histogram(self, histogram, data: Union[List[int], List[float]]) -> None:
def _log_histogram(self, histogram, data: Union[list[int], list[float]]) -> None:
# Convenience function for logging list to histogram.
for datum in data:
histogram.labels(**self.labels).observe(datum)
def _log_gauge_string(self, gauge, data: Dict[str, str]) -> None:
def _log_gauge_string(self, gauge, data: dict[str, str]) -> None:
gauge.labels(**data).set_to_current_time()
def _log_prometheus(self, stats: Stats) -> None:
......
......@@ -16,7 +16,6 @@ do this in Python code and lazily import prometheus_client.
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List
from vllm.config import SupportsMetricsInfo, VllmConfig
......@@ -43,26 +42,26 @@ class Stats:
num_prompt_tokens_iter: int
num_generation_tokens_iter: int
num_tokens_iter: int
time_to_first_tokens_iter: List[float]
inter_token_latencies_iter: List[float]
time_to_first_tokens_iter: list[float]
inter_token_latencies_iter: list[float]
num_preemption_iter: int
# Request stats (should have _requests suffix)
# Latency
time_e2e_requests: List[float]
time_queue_requests: List[float]
time_inference_requests: List[float]
time_prefill_requests: List[float]
time_decode_requests: List[float]
time_e2e_requests: list[float]
time_queue_requests: list[float]
time_inference_requests: list[float]
time_prefill_requests: list[float]
time_decode_requests: list[float]
# Metadata
num_prompt_tokens_requests: List[int]
num_generation_tokens_requests: List[int]
n_requests: List[int]
max_num_generation_tokens_requests: List[int]
max_tokens_requests: List[int]
finished_reason_requests: List[str]
waiting_lora_adapters: List[str]
running_lora_adapters: List[str]
num_prompt_tokens_requests: list[int]
num_generation_tokens_requests: list[int]
n_requests: list[int]
max_num_generation_tokens_requests: list[int]
max_tokens_requests: list[int]
finished_reason_requests: list[str]
waiting_lora_adapters: list[str]
running_lora_adapters: list[str]
max_lora: str
......@@ -71,8 +70,8 @@ class StatLoggerBase(ABC):
def __init__(self, local_interval: float, vllm_config: VllmConfig) -> None:
# Tracked stats over current local logging interval.
self.num_prompt_tokens: List[int] = []
self.num_generation_tokens: List[int] = []
self.num_prompt_tokens: list[int] = []
self.num_generation_tokens: list[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
......
......@@ -6,7 +6,7 @@ from __future__ import annotations
import datetime
import json
from collections.abc import Iterable, Sequence
from typing import Literal, Optional, Union
from typing import Literal, Union
from openai.types.responses import (
ResponseFunctionToolCall,
......@@ -79,13 +79,13 @@ def get_encoding():
def get_system_message(
model_identity: Optional[str] = None,
reasoning_effort: Optional[Literal["high", "medium", "low"]] = None,
start_date: Optional[str] = None,
browser_description: Optional[str] = None,
python_description: Optional[str] = None,
container_description: Optional[str] = None,
instructions: Optional[str] = None,
model_identity: str | None = None,
reasoning_effort: Literal["high", "medium", "low"] | None = None,
start_date: str | None = None,
browser_description: str | None = None,
python_description: str | None = None,
container_description: str | None = None,
instructions: str | None = None,
with_custom_tools: bool = False,
) -> Message:
sys_msg_content = SystemContent.new()
......@@ -137,8 +137,8 @@ def create_tool_definition(tool: Union[ChatCompletionToolsParam, Tool]):
def get_developer_message(
instructions: Optional[str] = None,
tools: Optional[list[Union[Tool, ChatCompletionToolsParam]]] = None,
instructions: str | None = None,
tools: list[Union[Tool, ChatCompletionToolsParam]] | None = None,
) -> Message:
dev_msg_content = DeveloperContent.new()
if instructions is not None and not envs.VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS:
......@@ -202,7 +202,7 @@ def parse_response_input(
msg = msg.with_channel("final")
elif response_msg["type"] == "function_call_output":
call_id = response_msg["call_id"]
call_response: Optional[ResponseFunctionToolCall] = None
call_response: ResponseFunctionToolCall | None = None
for prev_response in reversed(prev_responses):
if (
isinstance(prev_response, ResponseFunctionToolCall)
......@@ -450,7 +450,7 @@ def parse_output_into_messages(token_ids: Iterable[int]) -> StreamableParser:
def parse_chat_output(
token_ids: Sequence[int],
) -> tuple[Optional[str], Optional[str], bool]:
) -> tuple[str | None, str | None, bool]:
parser = parse_output_into_messages(token_ids)
output_msgs = parser.messages
is_tool_call = False # TODO: update this when tool call is supported
......
......@@ -6,7 +6,7 @@ import time
from abc import ABC, abstractmethod
from collections.abc import Awaitable
from functools import cached_property
from typing import Any, Callable, List, Optional, Set, Union
from typing import Any, Callable, Optional, Union
from typing_extensions import TypeVar
......@@ -143,7 +143,7 @@ class ExecutorBase(ABC):
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
) -> Optional[list[Union[SamplerOutput, PoolerOutput]]]:
output = self.collective_rpc("execute_model", args=(execute_model_req,))
return output[0]
......@@ -163,7 +163,7 @@ class ExecutorBase(ABC):
assert lora_id > 0, "lora_id must be greater than 0."
return all(self.collective_rpc("pin_lora", args=(lora_id,)))
def list_loras(self) -> Set[int]:
def list_loras(self) -> set[int]:
sets = self.collective_rpc("list_loras")
for s in sets:
assert s == sets[0], "All workers should have the same LORAs."
......@@ -238,7 +238,7 @@ class ExecutorBase(ABC):
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
"""Executes one model step on the given sequences."""
output = await make_async(self.execute_model)(execute_model_req)
return output
......@@ -272,7 +272,7 @@ class DistributedExecutorBase(ExecutorBase):
def execute_model(
self,
execute_model_req: ExecuteModelRequest,
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
# TODO: unify into collective_rpc
if self.parallel_worker_tasks is None:
self.parallel_worker_tasks = self._run_workers(
......@@ -299,7 +299,7 @@ class DistributedExecutorBase(ExecutorBase):
@abstractmethod
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
) -> Optional[list[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution loop
......@@ -346,7 +346,7 @@ class DistributedExecutorBase(ExecutorBase):
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
if self.parallel_worker_tasks is None:
# Start model execution loop running in the parallel workers
self.parallel_worker_tasks = asyncio.create_task(
......@@ -371,7 +371,7 @@ class DistributedExecutorBase(ExecutorBase):
async def _driver_execute_model_async(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
"""Execute the model asynchronously in the driver worker.
Passing None will cause the driver to stop the model execution
......
......@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from array import array
from typing import Any, Type
from typing import Any
from vllm.multimodal.inputs import MultiModalKwargs
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE
......@@ -23,7 +23,7 @@ def encode_hook(obj: Any) -> Any:
return dict(obj)
def decode_hook(type: Type, obj: Any) -> Any:
def decode_hook(type: type, obj: Any) -> Any:
"""Custom msgspec dec hook that supports array types and MultiModalKwargs.
See https://jcristharif.com/msgspec/api.html#msgspec.msgpack.Encoder
......
......@@ -5,7 +5,7 @@ import asyncio
import os
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import cloudpickle
import msgspec
......@@ -114,10 +114,10 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._init_workers_ray(placement_group)
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
self.output_decoder = msgspec.msgpack.Decoder(Optional[List[SamplerOutput]])
self.output_decoder = msgspec.msgpack.Decoder(Optional[list[SamplerOutput]])
self.use_v1 = envs.VLLM_USE_V1
self.pp_locks: Optional[List[asyncio.Lock]] = None
self.pp_locks: Optional[list[asyncio.Lock]] = None
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(self.driver_worker.execute_method)
......@@ -137,7 +137,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
ray.kill(worker)
self.forward_dag = None
def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> Dict[str, Any]:
def _configure_ray_workers_use_nsight(self, ray_remote_kwargs) -> dict[str, Any]:
# If nsight profiling is enabled, we need to set the profiling
# configuration for the ray workers as runtime env.
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
......@@ -164,12 +164,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
# It holds the resource for the driver worker.
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
# The remaining workers are the actual ray actors.
self.workers: List[RayWorkerWrapper] = []
self.workers: list[RayWorkerWrapper] = []
# Used in ray compiled DAG: indexed first by PP rank,
# and then TP rank. In other words, the inner list is
# the TP group of workers for a PP rank.
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
self.pp_tp_workers: list[list[RayWorkerWrapper]] = []
if self.parallel_config.ray_workers_use_nsight:
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
......@@ -179,7 +179,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
# Create the workers.
bundle_indices: List[int]
bundle_indices: list[int]
if envs.VLLM_RAY_BUNDLE_INDICES:
# Use the bundle indices specified by the user.
bundle_indices = list(map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
......@@ -200,7 +200,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
bundle_indices.append(bundle_id)
bundle_indices = bundle_indices[: self.parallel_config.world_size]
worker_metadata: List[RayWorkerMetaData] = []
worker_metadata: list[RayWorkerMetaData] = []
driver_ip = get_ip()
for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
......@@ -262,7 +262,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
"the driver on a GPU node."
)
ip_counts: Dict[str, int] = {}
ip_counts: dict[str, int] = {}
for ip in worker_ips:
ip_counts[ip] = ip_counts.get(ip, 0) + 1
......@@ -416,11 +416,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
# This is the list of workers that are rank 0 of each TP group EXCEPT
# global rank 0. These are the workers that will broadcast to the
# rest of the workers.
self.tp_driver_workers: List[RayWorkerWrapper] = []
self.tp_driver_workers: list[RayWorkerWrapper] = []
# This is the list of workers that are not drivers and not the first
# worker in a TP group. These are the workers that will be
# broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = []
self.non_driver_workers: list[RayWorkerWrapper] = []
# Enforce rank order for correct rank to return final output.
for index, worker in enumerate(self.workers):
......@@ -433,7 +433,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest]
) -> Optional[List[SamplerOutput]]:
) -> Optional[list[SamplerOutput]]:
"""Run execute_model in the driver worker.
Passing None will cause the driver to stop the model execution
......@@ -446,7 +446,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
......@@ -675,7 +675,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
async def execute_model_async(
self, execute_model_req: ExecuteModelRequest
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
......@@ -689,7 +689,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
async def _driver_execute_model_async(
self, execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]:
) -> list[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
......
......@@ -4,7 +4,7 @@
import os
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Optional, Union
import msgspec
......@@ -59,7 +59,7 @@ try:
def get_node_ip(self) -> str:
return get_ip()
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
def get_node_and_gpu_ids(self) -> tuple[str, list[int]]:
node_id = ray.get_runtime_context().get_node_id()
device_key = vllm.platforms.current_platform.ray_device_key
if not device_key:
......@@ -72,7 +72,7 @@ try:
def execute_model_spmd(
self,
req_or_tuple: Union[bytes, Tuple[bytes, Optional[IntermediateTensors]]],
req_or_tuple: Union[bytes, tuple[bytes, Optional[IntermediateTensors]]],
) -> bytes:
"""Execute model in SPMD fashion: used only when SPMD worker and
compiled DAG are both enabled.
......@@ -126,10 +126,10 @@ try:
def execute_model_ray(
self,
scheduler_output: Union[
"SchedulerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
],
) -> Union[
"ModelRunnerOutput", Tuple["SchedulerOutput", "IntermediateTensors"]
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
......@@ -156,7 +156,7 @@ try:
output = output.get_output()
return output
def override_env_vars(self, vars: Dict[str, str]):
def override_env_vars(self, vars: dict[str, str]):
os.environ.update(vars)
ray_import_err = None
......@@ -201,7 +201,7 @@ def _verify_bundles(
# bundle_idx -> bundle (e.g., {"GPU": 1})
bundles = pg_data["bundles"]
# node_id -> List of bundle (e.g., {"GPU": 1})
node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
node_id_to_bundle: dict[str, list[dict[str, float]]] = defaultdict(list)
for bundle_idx, node_id in bundle_to_node_ids.items():
node_id_to_bundle[node_id].append(bundles[bundle_idx])
......@@ -383,7 +383,7 @@ def initialize_ray_cluster(
device_str,
)
# Create a new placement group
placement_group_specs: List[Dict[str, float]] = [
placement_group_specs: list[dict[str, float]] = [
{device_str: 1.0} for _ in range(parallel_config.world_size)
]
......
......@@ -4,7 +4,7 @@ import os
from concurrent.futures import Future, ThreadPoolExecutor
from functools import cached_property
from multiprocessing import Lock
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Union
import torch
import torch.distributed as dist
......@@ -68,10 +68,10 @@ class UniProcExecutor(ExecutorBase):
self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None,
args: tuple = (),
kwargs: Optional[dict] = None,
non_block: bool = False,
) -> List[Any]:
) -> list[Any]:
if kwargs is None:
kwargs = {}
if self.mm_receiver_cache is not None and method == "execute_model":
......@@ -158,7 +158,7 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
local_rank = int(os.environ["LOCAL_RANK"])
return distributed_init_method, rank, local_rank
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_num_available_blocks(self) -> tuple[int, int]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (
List, # noqa: UP035
Optional,
)
from typing import Optional
import torch
......@@ -32,7 +29,7 @@ def flashinfer_fused_moe_blockscale_fp8(
intermediate_size: int,
expert_offset: int,
local_num_experts: int,
block_shape: List[int], # noqa: UP006
block_shape: list[int],
routed_scaling: float = 1.0,
) -> torch.Tensor:
from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment