# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.

This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed, grouping them into separate mypy calls
based on their directory to avoid import-following issues.

Usage:
    python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>

Args:
    ci: "1" if running in CI, "0" otherwise. In CI, no --follow-imports
        override is passed for the main group of files, so the setting from
        the mypy configuration applies (expected to be "silent"). In every
        other case (all other groups, and the main group outside CI),
        --follow-imports=skip is used.
    python_version: Python version to use (e.g., "3.10") or "local" to use
        the local Python version.
    changed_files: List of changed files to check.
"""

import subprocess
import sys

import regex as re

# Path prefixes of changed files that are type-checked together in the main
# mypy invocation. "vllm/*.py" is intended to select the modules that sit
# directly inside vllm/.
FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/distributed",
    "vllm/engine",
    "vllm/entrypoints",
    "vllm/executor",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/plugins",
    "vllm/renderers",
    "vllm/tokenizers",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
    "vllm/utils",
    "vllm/worker",
    "vllm/v1/attention",
    "vllm/v1/core",
    "vllm/v1/engine",
    "vllm/v1/executor",
    "vllm/v1/metrics",
    "vllm/v1/pool",
    "vllm/v1/sample",
    "vllm/v1/structured_output",
    "vllm/v1/worker",
]

# Each prefix below is checked in its own mypy invocation with
# --follow-imports=skip. After fixing the errors that appear when changing
# follow_imports from "skip" to "silent", move the directory to FILES.
SEPARATE_GROUPS = [
    "tests",
    # v0 related
    "vllm/attention",
    "vllm/compilation",
    "vllm/lora",
    "vllm/model_executor",
    # v1 related
    "vllm/v1/kv_offload",
    "vllm/v1/spec_decode",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
# Changed files under these path prefixes are skipped entirely (no mypy run).
EXCLUDE = [
    "vllm/engine/arg_utils.py",
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/v1/attention/ops",
]

# Directories that should be checked with --strict. Changed files under these
# prefixes are checked in a separate mypy invocation.
STRICT_DIRS = [
    "vllm/compilation",
]


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names to lists of changed files.
    """
    exclude_pattern = re.compile(f"^{'|'.join(EXCLUDE)}.*")
    files_pattern = re.compile(f"^({'|'.join(FILES)}).*")
    file_groups = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely
        if exclude_pattern.match(changed_file):
            continue
        # Group files by mypy call
        if files_pattern.match(changed_file):
            file_groups[""].append(changed_file)
            continue
        else:
            for directory in SEPARATE_GROUPS:
                if re.match(f"^{directory}.*", changed_file):
                    file_groups[directory].append(changed_file)
                    break
    return file_groups


def is_strict_file(filepath: str) -> bool:
    """Return True when ``filepath`` starts with any prefix in STRICT_DIRS.

    Such files are type-checked with mypy's ``--strict`` flag.
    """
    strict_prefixes = tuple(STRICT_DIRS)
    return filepath.startswith(strict_prefixes)


def mypy(
    targets: list[str],
    python_version: str | None,
    follow_imports: str | None,
    file_group: str,
    strict: bool = False,
) -> int:
    """
    Run mypy on the given targets.

    Args:
        targets: List of files or directories to check.
        python_version: Python version to use (e.g., "3.10") or None to use
            the default mypy version.
        follow_imports: Value for the --follow-imports option or None to use
            the default mypy behavior.
        file_group: The file group name, used only in the logged command.
        strict: If True, run mypy with --strict flag.

    Returns:
        The return code from mypy.
    """
    args = ["mypy"]
    if python_version is not None:
        args += ["--python-version", python_version]
    if follow_imports is not None:
        args += ["--follow-imports", follow_imports]
    if strict:
        args.append("--strict")
    # Log the command with the group name in place of the (possibly long)
    # list of target files; the real invocation receives the targets.
    print(f"$ {' '.join(args)} {file_group}")
    return subprocess.run(args + targets, check=False).returncode


def main() -> int:
    """
    Parse the command line, group the changed files and run mypy per group.

    Command line: ``<ci> <python_version> <changed_files...>``.

    Returns:
        0 if every mypy invocation succeeded; otherwise a non-zero value (the
        bitwise OR of the mypy return codes, or 2 for a usage error).
    """
    if len(sys.argv) < 3:
        print(
            "usage: python tools/pre_commit/mypy.py "
            "<ci> <python_version> <changed_files...>",
            file=sys.stderr,
        )
        return 2

    ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    file_groups = group_files(sys.argv[3:])

    if python_version == "local":
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    returncode = 0
    for file_group, changed_files in file_groups.items():
        if not changed_files:
            continue
        # In CI the main group ("") passes no --follow-imports override so
        # the mypy configuration decides; every other case uses "skip".
        follow_imports = None if ci and file_group == "" else "skip"

        # Separate files into strict and non-strict groups in a single pass.
        strict_files: list[str] = []
        non_strict_files: list[str] = []
        for path in changed_files:
            if is_strict_file(path):
                strict_files.append(path)
            else:
                non_strict_files.append(path)

        # Run mypy on non-strict files
        if non_strict_files:
            returncode |= mypy(
                non_strict_files,
                python_version,
                follow_imports,
                file_group,
                strict=False,
            )

        # Run mypy on strict files with --strict flag
        if strict_files:
            returncode |= mypy(
                strict_files,
                python_version,
                follow_imports,
                f"{file_group} (strict)",
                strict=True,
            )
    return returncode


if __name__ == "__main__":
    sys.exit(main())