Unverified Commit f1acbd68 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

[CI] Enable mypy import following for `vllm/compilation` (#33199)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 9581185d
......@@ -103,13 +103,6 @@ ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "silent"
[[tool.mypy.overrides]]
module = "vllm.compilation.*"
disallow_untyped_defs = true
disallow_incomplete_defs = true
warn_return_any = true
follow_imports = "silent"
[tool.pytest.ini_options]
markers = [
"slow_test",
......
......@@ -26,6 +26,7 @@ import regex as re
FILES = [
"vllm/*.py",
"vllm/assets",
"vllm/compilation",
"vllm/distributed",
"vllm/engine",
"vllm/entrypoints",
......@@ -58,7 +59,6 @@ FILES = [
SEPARATE_GROUPS = [
"tests",
# v0 related
"vllm/compilation",
"vllm/lora",
"vllm/model_executor",
# v1 related
......@@ -76,11 +76,6 @@ EXCLUDE = [
"vllm/v1/attention/ops",
]
# Directories that should be checked with --strict
STRICT_DIRS = [
"vllm/compilation",
]
def group_files(changed_files: list[str]) -> dict[str, list[str]]:
"""
......@@ -112,17 +107,11 @@ def group_files(changed_files: list[str]) -> dict[str, list[str]]:
return file_groups
def is_strict_file(filepath: str) -> bool:
"""Check if a file should be checked with strict mode."""
return any(filepath.startswith(strict_dir) for strict_dir in STRICT_DIRS)
def mypy(
targets: list[str],
python_version: str | None,
follow_imports: str | None,
file_group: str,
strict: bool = False,
) -> int:
"""
Run mypy on the given targets.
......@@ -134,7 +123,6 @@ def mypy(
follow_imports: Value for the --follow-imports option or None to use
the default mypy behavior.
file_group: The file group name for logging purposes.
strict: If True, run mypy with --strict flag.
Returns:
The return code from mypy.
......@@ -144,8 +132,6 @@ def mypy(
args += ["--python-version", python_version]
if follow_imports is not None:
args += ["--follow-imports", follow_imports]
if strict:
args += ["--strict"]
print(f"$ {' '.join(args)} {file_group}")
return subprocess.run(args + targets, check=False).returncode
......@@ -162,28 +148,8 @@ def main():
for file_group, changed_files in file_groups.items():
follow_imports = None if ci and file_group == "" else "skip"
if changed_files:
# Separate files into strict and non-strict groups
strict_files = [f for f in changed_files if is_strict_file(f)]
non_strict_files = [f for f in changed_files if not is_strict_file(f)]
# Run mypy on non-strict files
if non_strict_files:
returncode |= mypy(
non_strict_files,
python_version,
follow_imports,
file_group,
strict=False,
)
# Run mypy on strict files with --strict flag
if strict_files:
returncode |= mypy(
strict_files,
python_version,
follow_imports,
f"{file_group} (strict)",
strict=True,
changed_files, python_version, follow_imports, file_group
)
return returncode
......
......@@ -29,9 +29,6 @@ else:
Torch25CustomGraphPass as CustomGraphPass,
)
# Re-export CustomGraphPass for external usage
__all__ = ["CustomGraphPass"]
_pass_context = None
P = ParamSpec("P")
R = TypeVar("R")
......
......@@ -32,9 +32,6 @@ else:
logger = init_logger(__name__)
# Explicitly exports Range
__all__ = ["Range"]
class CompilationMode(enum.IntEnum):
"""The compilation approach used for torch.compile-based compilation of the
......
......@@ -22,8 +22,6 @@ from .yarn_scaling_rope import YaRNScalingRotaryEmbedding
_ROPE_DICT: dict[tuple[Any, ...], RotaryEmbedding] = {}
__all__ = ["RotaryEmbedding"]
def get_rope(
head_size: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment