Unverified Commit fb946a7f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Make `mypy` opt-out instead of opt-in (#33205)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent a650ad15
......@@ -23,39 +23,8 @@ import sys
import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/entrypoints",
"vllm/executor",
"vllm/inputs",
"vllm/logging_utils",
"vllm/multimodal",
"vllm/platforms",
"vllm/plugins",
"vllm/renderers",
"vllm/tokenizers",
"vllm/transformers_utils",
"vllm/triton_utils",
"vllm/usage",
"vllm/utils",
"vllm/worker",
"vllm/v1/attention",
"vllm/v1/core",
"vllm/v1/engine",
"vllm/v1/executor",
"vllm/v1/metrics",
"vllm/v1/pool",
"vllm/v1/sample",
"vllm/v1/structured_output",
"vllm/v1/worker",
]
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
# from "skip" to "silent", remove its directory from SEPARATE_GROUPS.
SEPARATE_GROUPS = [
"tests",
# v0 related
......@@ -74,6 +43,16 @@ EXCLUDE = [
"vllm/model_executor/layers/fla/ops",
# Ignore triton kernels in ops.
"vllm/v1/attention/ops",
# TODO: Remove these entries after fixing mypy errors.
"vllm/benchmarks",
"vllm/config",
"vllm/device_allocator",
"vllm/profiler",
"vllm/reasoning",
"vllm/tool_parser",
"vllm/v1/cudagraph_dispatcher.py",
"vllm/outputs.py",
"vllm/logger.py",
]
......@@ -88,7 +67,6 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
A dictionary mapping file group names to lists of changed files.
"""
exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
file_groups = {"": []}
file_groups.update({k: [] for k in SEPARATE_GROUPS})
for changed_file in changed_files:
......@@ -96,14 +74,13 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
if exclude_pattern.match(changed_file):
continue
# Group files by mypy call
if files_pattern.match(changed_file):
file_groups[""].append(changed_file)
continue
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
else:
for directory in SEPARATE_GROUPS:
if re.match(f"^{directory}.*", changed_file):
file_groups[directory].append(changed_file)
break
if changed_file.startswith("vllm/"):
file_groups[""].append(changed_file)
return file_groups
......
......@@ -349,7 +349,7 @@ def _rocm_aiter_mla_decode_fwd_impl(
) -> None:
from aiter.mla import mla_decode_fwd
kwargs = {
kwargs: dict[str, float | torch.Tensor | None] = {
"sm_scale": sm_scale,
"logit_cap": logit_cap,
}
......
......@@ -570,7 +570,7 @@ class CompilationConfig:
pass_config: PassConfig = field(default_factory=PassConfig)
"""Custom inductor passes, see PassConfig for more details"""
max_cudagraph_capture_size: int | None = field(default=None)
max_cudagraph_capture_size: int = field(default=None)
"""The maximum cudagraph capture size.
If cudagraph_capture_sizes is specified, this will be set to the largest
......@@ -743,6 +743,7 @@ class CompilationConfig:
"level",
"mode",
"cudagraph_mode",
"max_cudagraph_capture_size",
"use_inductor_graph_partition",
mode="wrap",
)
......
......@@ -9,7 +9,7 @@ import inspect
import json
import pathlib
import textwrap
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from collections.abc import Callable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
......@@ -75,7 +75,7 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter(
object: object,
names: Iterable[str],
names: Sequence[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
......
......@@ -382,7 +382,7 @@ def _patch_get_raw_stream_if_needed():
if hasattr(torch._C, "_cuda_getCurrentRawStream"):
from torch._C import _cuda_getCurrentRawStream as _get_raw_stream
builtins.get_raw_stream = _get_raw_stream
builtins.get_raw_stream = _get_raw_stream # type: ignore[attr-defined]
_patch_get_raw_stream_if_needed()
......
......@@ -1680,6 +1680,7 @@ def disable_envs_cache() -> None:
global __getattr__
# If __getattr__ is wrapped by functions.cache, unwrap the caching layer.
if _is_envs_cache_enabled():
assert hasattr(__getattr__, "__wrapped__")
__getattr__ = __getattr__.__wrapped__
......
......@@ -270,7 +270,7 @@ def create_forward_context(
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: BatchDescriptor | None = None,
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
):
......
......@@ -157,7 +157,7 @@ _METHODS_TO_PATCH = {
def _configure_vllm_root_logger() -> None:
logging_config = dict[str, Any]()
logging_config = dict[str, dict[str, Any] | Any]()
if not envs.VLLM_CONFIGURE_LOGGING and envs.VLLM_LOGGING_CONFIG_PATH:
raise RuntimeError(
......
......@@ -28,7 +28,7 @@ LogprobsOnePosition = dict[int, Logprob]
@dataclass
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
class FlatLogprobs(MutableSequence[LogprobsOnePosition | None]):
"""
Flat logprobs of a request into multiple primitive type lists.
......@@ -140,7 +140,7 @@ class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
def __delitem__(self, item) -> None:
raise TypeError("Cannot delete logprobs from FlatLogprobs")
def insert(self, item) -> None:
def insert(self, index: int, value: dict[int, Logprob] | None) -> None:
raise TypeError("Cannot insert logprobs to FlatLogprobs")
def __iter__(self) -> Iterator[LogprobsOnePosition]:
......@@ -161,7 +161,7 @@ SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs(flat_logprobs: bool) -> PromptLogprobs:
"""Creates a container to store prompt logprobs for a request"""
logprobs = FlatLogprobs() if flat_logprobs else []
logprobs: PromptLogprobs = FlatLogprobs() if flat_logprobs else []
# NOTE: logprob of first prompt token is None.
logprobs.append(None)
return logprobs
......
......@@ -4,6 +4,7 @@
import dataclasses
from collections.abc import Callable
from _typeshed import DataclassInstance
from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
#
......@@ -11,7 +12,7 @@ from torch._C._profiler import _EventType, _ProfilerEvent, _TensorMetadata
#
def trim_string_front(string, width):
def trim_string_front(string: str, width: int) -> str:
if len(string) > width:
offset = len(string) - width + 3
string = string[offset:]
......@@ -20,7 +21,7 @@ def trim_string_front(string, width):
return string
def trim_string_back(string, width):
def trim_string_back(string: str, width: int) -> str:
if len(string) > width:
offset = len(string) - width + 3
string = string[:-offset]
......@@ -30,15 +31,13 @@ def trim_string_back(string, width):
class TablePrinter:
def __init__(
self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int]
):
def __init__(self, row_cls: type[DataclassInstance], column_widths: dict[str, int]):
self.row_cls = row_cls
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
self.column_widths = column_widths
assert set(self.column_widths.keys()) == set(self.fieldnames)
def print_table(self, rows: list[dataclasses.dataclass]):
def print_table(self, rows: list[DataclassInstance]):
self._print_header()
self._print_line()
for row in rows:
......
......@@ -98,7 +98,7 @@ class FullAttentionSpec(AttentionSpec):
In this case, we use FullAttentionSpec and record the sliding window size.
"""
head_size_v: int | None = None
head_size_v: int = None # type: ignore[assignment]
sliding_window: int | None = None
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment