Unverified commit fb946a7f, authored by Harry Mellor, committed by GitHub
Browse files

Make `mypy` opt-out instead of opt-in (#33205)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent a650ad15
...@@ -23,39 +23,8 @@ import sys ...@@ -23,39 +23,8 @@ import sys
import regex as re import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/entrypoints",
"vllm/executor",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/renderers",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/utils",
"vllm/worker",
"vllm/v1/attention",
"vllm/v1/core",
"vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/structured_output",
"vllm/v1/worker",
]
# After fixing errors resulting from changing follow_imports # After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES # from "skip" to "silent", remove its directory from SEPARATE_GROUPS.
SEPARATE_GROUPS = [ SEPARATE_GROUPS = [
"tests", "tests",
# v0 related # v0 related
...@@ -74,6 +43,16 @@ EXCLUDE = [ ...@@ -74,6 +43,16 @@ EXCLUDE = [
"vllm/model_executor/layers/fla/ops", "vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops. # Ignore triton kernels in ops.
"vllm/v1/attention/ops", "vllm/v1/attention/ops",
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks",
"vllm/config",
"vllm/device_allocator",
"vllm/profiler",
"vllm/reasoning",
"vllm/tool_parser",
"vllm/v1/cudagraph_dispatcher.py",
"vllm/outputs.py",
"vllm/logger.py",
] ]
...@@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: ...@@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
A dictionary mapping file group names to lists of changed files. A dictionary mapping file group names to lists of changed files.
""" """
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*") exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []} file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS}) file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files: for changed_file in changed_files:
...@@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]: ...@@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
if exclude_pattern.match(changed_file): if exclude_pattern.match(changed_file):
continue continue
# Group files by mypy call # Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
else:
for directory in SEPARATE_GROUPS: for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file): if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file) file_groups[directory].append(changed_file)
break break
else:
if changed_file.startswith("vllm/"):
file_groups[""].append(changed_file)
return file_groups return file_groups
......
...@@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl( ...@@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None: ) -> None:
from aiter.mla import mla_decode_fwd from aiter.mla import mla_decode_fwd
kwargs = { kwargs: dict[str, float | torch.Tensor | None] = {
"sm_scale": sm_scale, "sm_scale": sm_scale,
"logit_cap": logit_cap, "logit_cap": logit_cap,
} }
......
...@@ -570,7 +570,7 @@ class CompilationConfig: ...@@ -570,7 +570,7 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig) pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details""" """Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size: int | None = field(default=None) max_cudagraph_capture_size: int = field(default=None)
"""The maximum cudagraph capture size. """The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest If cudagraph_capture_sizes is specified, this will be set to the largest
...@@ -743,6 +743,7 @@ class CompilationConfig: ...@@ -743,6 +743,7 @@ class CompilationConfig:
"level", "level",
"mode", "mode",
"cudagraph_mode", "cudagraph_mode",
"max_cudagraph_capture_size",
"use_inductor_graph_partition", "use_inductor_graph_partition",
mode="wrap", mode="wrap",
) )
......
...@@ -9,7 +9,7 @@ import inspect ...@@ -9,7 +9,7 @@ import inspect
import json import json
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Callable, Iterable, Mapping, Sequence, Set from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
...@@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter( def getattr_iter(
object: object, object: object,
names: Iterable[str], names: Sequence[str],
default: Any | None = None, default: Any | None = None,
default_factory: Callable[[], Any] | None = None, default_factory: Callable[[], Any] | None = None,
warn: bool = False, warn: bool = False,
......
...@@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed(): ...@@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed():
if hasattr(torch._C, "_cuda_getCurrentRawStream"): if hasattr(torch._C, "_cuda_getCurrentRawStream"):
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
builtins.get_raw_stream = _get_raw_stream builtins.get_raw_stream = _get_raw_stream # type: ignore[attr-defined]
_patch_get_raw_stream_if_needed() _patch_get_raw_stream_if_needed()
......
...@@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None: ...@@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None:
global __getattr__ global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer. # If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled(): if _is_envs_cache_enabled():
assert hasattr(__getattr__, "__wrapped__")
__getattr__ = __getattr__.__wrapped__ __getattr__ = __getattr__.__wrapped__
......
...@@ -270,7 +270,7 @@ def create_forward_context( ...@@ -270,7 +270,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None, batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
): ):
......
...@@ -157,7 +157,7 @@ _METHODS_TO_PATCH = { ...@@ -157,7 +157,7 @@ _METHODS_TO_PATCH = {
def _configure_vllm_root_logger() -> None: def _configure_vllm_root_logger() -> None:
logging_config = dict[str, Any]() logging_config = dict[str, dict[str, Any] | Any]()
if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH: if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError( raise RuntimeError(
......
...@@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob] ...@@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob]
@dataclass @dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]): class FlatLogprobs(MutableSequence[LogprobsOnePosition | None]):
""" """
Flat logprobs of a request into multiple primitive type lists. Flat logprobs of a request into multiple primitive type lists.
...@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]): ...@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
def __delitem__(self, item) -> None: def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlatLogprobs") raise TypeError("Cannot delete logprobs from FlatLogprobs")
def insert(self, item) -> None: def insert(self, index: int, value: dict[int, Logprob] | None) -> None:
raise TypeError("Cannot insert logprobs to FlatLogprobs") raise TypeError("Cannot insert logprobs to FlatLogprobs")
def __iter__(self) -> Iterator[LogprobsOnePosition]: def __iter__(self) -> Iterator[LogprobsOnePosition]:
...@@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition] ...@@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs: def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request""" """Creates a container to store prompt logprobs for a request"""
logprobs = FlatLogprobs() if flat_logprobs else [] logprobs: PromptLogprobs = FlatLogprobs() if flat_logprobs else []
# NOTE: logprob of first prompt token is None. # NOTE: logprob of first prompt token is None.
logprobs.append(None) logprobs.append(None)
return logprobs return logprobs
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import dataclasses import dataclasses
from collections.abc import Callable from collections.abc import Callable
from _typeshed import DataclassInstance
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
# #
...@@ -11,7 +12,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata ...@@ -11,7 +12,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
# #
def trim_string_front(string, width): def trim_string_front(string: str, width: int) -> str:
if len(string) > width: if len(string) > width:
offset = len(string) - width + 3 offset = len(string) - width + 3
string = string[offset:] string = string[offset:]
...@@ -20,7 +21,7 @@ def trim_string_front(string, width): ...@@ -20,7 +21,7 @@ def trim_string_front(string, width):
return string return string
def trim_string_back(string, width): def trim_string_back(string: str, width: int) -> str:
if len(string) > width: if len(string) > width:
offset = len(string) - width + 3 offset = len(string) - width + 3
string = string[:-offset] string = string[:-offset]
...@@ -30,15 +31,13 @@ def trim_string_back(string, width): ...@@ -30,15 +31,13 @@ def trim_string_back(string, width):
class TablePrinter: class TablePrinter:
def __init__( def __init__(self, row_cls: type[DataclassInstance], column_widths: dict[str, int]):
self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int]
):
self.row_cls = row_cls self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)] self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames) assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: list[dataclasses.dataclass]): def print_table(self, rows: list[DataclassInstance]):
self._print_header() self._print_header()
self._print_line() self._print_line()
for row in rows: for row in rows:
......
...@@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec): ...@@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec):
In this case, we use FullAttentionSpec and record the sliding window size. In this case, we use FullAttentionSpec and record the sliding window size.
""" """
head_size_v: int | None = None head_size_v: int = None # type: ignore[assignment]
sliding_window: int | None = None sliding_window: int | None = None
""" """
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment