# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.

This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.

Usage:
    python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>

Args:
    ci: "1" if running in CI, "0" otherwise. In CI, the main group of files
        is checked with mypy's configured follow_imports setting (expected to
        be "silent"); every other mypy call passes --follow-imports skip.
    python_version: Python version to use (e.g., "3.10") or "local" to use
        the local Python version.
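    changed_files: List of changed files to check. Files under an EXCLUDE
        prefix are skipped; of the rest, only files under a SEPARATE_GROUPS
        directory or under "vllm/" are checked.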
"""

import subprocess
import sys

import regex as re

# Directories whose changed files are checked in their own mypy call (always
# with --follow-imports skip) instead of in the main "vllm/" call.
# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", remove its directory from SEPARATE_GROUPS.
SEPARATE_GROUPS: list[str] = [
    "tests",
    # v0 related
    "vllm/lora",
    "vllm/model_executor",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
# Path prefixes that are never type-checked: changed files starting with any
# of these are dropped before they are grouped into mypy calls.
EXCLUDE: list[str] = [
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/v1/attention/ops",
    # TODO: Remove these entries after fixing mypy errors.
    "vllm/benchmarks",
    "vllm/config",
    "vllm/device_allocator",
    "vllm/reasoning",
    "vllm/tool_parser",
]


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Files starting with an EXCLUDE prefix are dropped. A file starting with a
    SEPARATE_GROUPS directory goes into that directory's group (the first
    matching directory wins); any other file under "vllm/" goes into the main
    group, keyed by "". All remaining files are ignored.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names to lists of changed files.
        Every group key is present, even when its list is empty.
    """
    # Escape so path characters such as "." match literally; .match()
    # anchors at the start of the string, which makes this a prefix test.
    # An empty EXCLUDE would compile to "" and match *every* file, so only
    # build the pattern when there is something to exclude.
    exclude_pattern = (
        re.compile("|".join(map(re.escape, EXCLUDE))) if EXCLUDE else None
    )
    file_groups: dict[str, list[str]] = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely
        if exclude_pattern is not None and exclude_pattern.match(changed_file):
            continue
        # Group files by mypy call: the first matching separate directory wins.
        for directory in SEPARATE_GROUPS:
            if changed_file.startswith(directory):
                file_groups[directory].append(changed_file)
                break
        else:
            # Only files under vllm/ go to the main group; others are ignored.
            if changed_file.startswith("vllm/"):
                file_groups[""].append(changed_file)
    return file_groups


def mypy(
    targets: list[str],
    python_version: str | None,
    follow_imports: str | None,
    file_group: str,
) -> int:
    """
    Run mypy on the given targets.

    Args:
        targets: List of files or directories to check.
        python_version: Python version to use (e.g., "3.10") or None to omit
            --python-version and leave it to mypy's configuration/defaults.
        follow_imports: Value for the --follow-imports option or None to omit
            the flag and leave it to mypy's configuration/defaults.
        file_group: The file group name, echoed in the logged command.

    Returns:
        The return code from mypy.
    """
    args = ["mypy"]
    if python_version is not None:
        args += ["--python-version", python_version]
    if follow_imports is not None:
        args += ["--follow-imports", follow_imports]
    # Log the group name rather than the (possibly long) target list. Flush so
    # the line is written before mypy's own output when stdout is a pipe
    # (e.g. under pre-commit); the child writes to the same fd directly.
    print(f"$ {' '.join(args)} {file_group}", flush=True)
    # Argument list, not a shell string: file names are passed verbatim.
    return subprocess.run(args + targets, check=False).returncode


def main():
    """
    Parse ``<ci> <python_version> <changed_files...>`` from sys.argv and run
    mypy once per non-empty file group.

    Returns:
        2 on a usage error; otherwise the bitwise OR of every mypy return
        code, which is 0 only if every mypy run succeeded.
    """
    if len(sys.argv) < 3:
        # Fail with a readable message instead of an IndexError traceback.
        print(
            "usage: mypy.py <ci> <python_version> <changed_files...>",
            file=sys.stderr,
        )
        return 2

    ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    file_groups = group_files(sys.argv[3:])

    if python_version == "local":
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    returncode = 0
    for file_group, changed_files in file_groups.items():
        if not changed_files:
            continue
        # In CI the main group ("") leaves follow_imports to mypy's
        # configuration; every other call skips following imports.
        follow_imports = None if ci and file_group == "" else "skip"
        returncode |= mypy(
            changed_files, python_version, follow_imports, file_group
        )
    return returncode


if __name__ == "__main__":
    # Propagate main()'s return code as the process exit status.
    raise SystemExit(main())