# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Run mypy on changed files.

This script is designed to be used as a pre-commit hook. It runs mypy
on files that have been changed. It groups files into different mypy calls
based on their directory to avoid import following issues.

Usage:
    python tools/pre_commit/mypy.py <ci> <python_version> <changed_files...>

Args:
    ci: "1" if running in CI, "0" otherwise. In CI, the main group of files
        is run without an explicit --follow-imports flag, so mypy's own
        configured setting applies; every other run passes
        --follow-imports=skip.
    python_version: Python version to use (e.g., "3.10") or "local" to use
        the local Python version.
    changed_files: List of changed files to check.
"""

import subprocess
import sys

import regex as re

# Path prefixes whose changed files are checked together in the main mypy
# call (the "" group built by group_files). Each entry is matched against
# the start of a changed file's path.
FILES = [
    "vllm/*.py",
    "vllm/assets",
    "vllm/distributed",
    "vllm/engine",
    "vllm/entrypoints",
    "vllm/executor",
    "vllm/inputs",
    "vllm/logging_utils",
    "vllm/multimodal",
    "vllm/platforms",
    "vllm/plugins",
    "vllm/tokenizers",
    "vllm/transformers_utils",
    "vllm/triton_utils",
    "vllm/usage",
    "vllm/utils",
    "vllm/worker",
    "vllm/v1/attention",
    "vllm/v1/core",
    "vllm/v1/engine",
    "vllm/v1/executor",
    "vllm/v1/metrics",
    "vllm/v1/pool",
    "vllm/v1/sample",
    "vllm/v1/worker",
]

# After fixing errors resulting from changing follow_imports
# from "skip" to "silent", move the following directories to FILES
# Each directory listed here gets its own mypy call; group_files uses the
# directory string itself as the group name.
SEPARATE_GROUPS = [
    "tests",
    # v0 related
    "vllm/attention",
    "vllm/compilation",
    "vllm/lora",
    "vllm/model_executor",
    # v1 related
    "vllm/v1/kv_offload",
    "vllm/v1/spec_decode",
    "vllm/v1/structured_output",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
# Path prefixes that are never passed to mypy; group_files drops any changed
# file whose path starts with one of these before assigning it to a group.
EXCLUDE = [
    "vllm/engine/arg_utils.py",
    "vllm/model_executor/parallel_utils",
    "vllm/model_executor/models",
    "vllm/model_executor/layers/fla/ops",
    # Ignore triton kernels in ops.
    "vllm/attention/ops",
]


def _prefix_pattern(entry: str) -> str:
    """Translate a path-prefix entry (optionally containing ``*``) to a regex."""
    # Escape first so "." and other metacharacters in paths are literal, then
    # turn the glob "*" into "any characters within one path component".
    return re.escape(entry).replace(r"\*", "[^/]*")


def group_files(changed_files: list[str]) -> dict[str, list[str]]:
    """
    Group changed files into different mypy calls.

    Entries of EXCLUDE, FILES and SEPARATE_GROUPS are treated as path
    prefixes; a ``*`` inside an entry matches any run of characters that does
    not cross a ``/``. Files matching EXCLUDE are dropped, files matching
    FILES go to the main group (keyed by the empty string), and remaining
    files go to the first SEPARATE_GROUPS directory they fall under. Files
    matching none of these are left out of the result.

    Args:
        changed_files: List of changed files.

    Returns:
        A dictionary mapping file group names to lists of changed files.
        Every group is present as a key, even when its list is empty.
    """
    # ``match`` anchors at the start of the string, so each alternative acts
    # as a prefix test without needing explicit "^" or a trailing ".*".
    exclude_pattern = re.compile("|".join(map(_prefix_pattern, EXCLUDE)))
    files_pattern = re.compile("|".join(map(_prefix_pattern, FILES)))
    # Compile the per-directory patterns once instead of once per file.
    separate_patterns = [
        (directory, re.compile(_prefix_pattern(directory)))
        for directory in SEPARATE_GROUPS
    ]

    file_groups: dict[str, list[str]] = {"": []}
    file_groups.update({k: [] for k in SEPARATE_GROUPS})
    for changed_file in changed_files:
        # Skip files which should be ignored completely
        if exclude_pattern.match(changed_file):
            continue
        # Group files by mypy call
        if files_pattern.match(changed_file):
            file_groups[""].append(changed_file)
            continue
        for directory, pattern in separate_patterns:
            if pattern.match(changed_file):
                file_groups[directory].append(changed_file)
                break
    return file_groups


def mypy(
    targets: list[str],
    python_version: str | None,
    follow_imports: str | None,
    file_group: str,
) -> int:
    """
    Run mypy on the given targets.

    Args:
        targets: List of files or directories to check.
        python_version: Python version to use (e.g., "3.10") or None to use
            the default mypy version.
        follow_imports: Value for the --follow-imports option or None to use
            the default mypy behavior.
        file_group: The file group name for logging purposes.

    Returns:
        The return code from mypy.
    """
    args = ["mypy"]
    if python_version is not None:
        args += ["--python-version", python_version]
    if follow_imports is not None:
        args += ["--follow-imports", follow_imports]
    # Log the options and the group name rather than every target path.
    print(f"$ {' '.join(args)} {file_group}")
    # check=False: a non-zero mypy exit is reported through the return code
    # instead of raising CalledProcessError.
    return subprocess.run(args + targets, check=False).returncode


def main() -> int:
    """
    Parse the command line, group the changed files and run mypy per group.

    Reads ``sys.argv`` as ``<ci> <python_version> <changed_files...>``.

    Returns:
        The bitwise OR of the return codes of all mypy calls (0 when every
        call succeeded or nothing needed checking), or 2 when the required
        arguments are missing.
    """
    if len(sys.argv) < 3:
        print(
            f"Usage: {sys.argv[0]} <ci> <python_version> <changed_files...>",
            file=sys.stderr,
        )
        return 2

    ci = sys.argv[1] == "1"
    python_version = sys.argv[2]
    file_groups = group_files(sys.argv[3:])

    if python_version == "local":
        python_version = f"{sys.version_info.major}.{sys.version_info.minor}"

    returncode = 0
    for file_group, changed_files in file_groups.items():
        # Only the main group in CI runs without a --follow-imports override;
        # every other call passes "skip".
        follow_imports = None if ci and file_group == "" else "skip"
        if changed_files:
            returncode |= mypy(
                changed_files, python_version, follow_imports, file_group
            )
    return returncode


if __name__ == "__main__":
    # Use main()'s combined mypy return code as the process exit status.
    raise SystemExit(main())