# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import subprocess
import sys

import regex as re

# Pattern for a (whitespace-stripped) line that imports the ``triton`` package
# or one of its submodules directly, e.g. ``import triton``,
# ``import triton.language as tl`` or ``from triton import jit``.
# The trailing group stops ``import tritonx`` from matching.
_FORBIDDEN_IMPORT_PATTERN = r"^(from|import)\s+triton(\s|\.|$)"
FORBIDDEN_IMPORT_RE = re.compile(_FORBIDDEN_IMPORT_PATTERN)

# Stripped import statements that are always accepted: the sanctioned way to
# obtain triton is through the vllm.triton_utils package.
_ALLOWED_TRITON_IMPORTS = (
    "from vllm.triton_utils import triton",
    "from vllm.triton_utils import tl",
    "from vllm.triton_utils import tl, triton",
)
ALLOWED_LINES = set(_ALLOWED_TRITON_IMPORTS)

# Repo-relative paths exempt from the check. The listed file lives inside
# vllm/triton_utils, the package other code is told to import triton from.
ALLOWED_FILES = {"vllm/triton_utils/importing.py"}


def is_allowed_file(current_file: str) -> bool:
    """Return True if ``current_file`` is exempt from the triton-import check.

    Args:
        current_file: Repo-relative path as it appears in the git diff header.
    """
    if current_file in ALLOWED_FILES:
        return True
    return False


def is_forbidden_import(line: str) -> bool:
    """Return True if ``line`` is a direct ``import triton``-style statement.

    Surrounding whitespace is ignored, so indented imports (e.g. inside a
    ``try:`` block) are detected too. Lines listed in ``ALLOWED_LINES`` are
    never reported.

    Args:
        line: One line of source code, without the diff ``+`` marker.
    """
    candidate = line.strip()
    if candidate in ALLOWED_LINES:
        return False
    return FORBIDDEN_IMPORT_RE.match(candidate) is not None


# Unified-diff hunk header: "@@ -<old>[,<count>] +<new>[,<count>] @@".
# An omitted count means 1. Groups: old count, new start, new count.
_HUNK_HEADER_RE = re.compile(r"^@@ -\d+(?:,(\d+))? \+(\d+)(?:,(\d+))? @@")


def _new_file_path(target: str) -> str:
    """Return the repo-relative path from the text following ``"+++ "``.

    Git appends a TAB after paths that contain a space, and C-quotes paths
    with unusual characters (``"b/..."``). Both are undone here, and the
    ``b/`` prefix is removed. Escape sequences inside a quoted path are left
    as-is. A deleted file yields ``"/dev/null"``.
    """
    target = target.rstrip("\t")
    if len(target) >= 2 and target[0] == target[-1] == '"':
        target = target[1:-1]
    return target[2:] if target.startswith("b/") else target


def parse_diff(diff: str) -> list[str]:
    """Return forbidden-import violations found on added lines of ``diff``.

    Args:
        diff: Output of ``git diff`` (any context size).

    Returns:
        One ``"<file>:<new line number>: <stripped line>"`` string per added
        line for which ``is_forbidden_import`` is true. Files accepted by
        ``is_allowed_file`` are not reported.

    Hunk bodies are consumed using the line counts in their ``@@`` header, so
    ``+++`` file headers are only recognised between hunks. As a result, an
    added line whose own text begins with ``+`` (shown as ``++...``) is still
    checked and counted. Combined ``@@@`` hunks (unmerged paths) are not
    parsed.
    """
    violations = []
    current_file = None
    skip_allowed_file = False
    lineno = 0  # new-side number of the last line consumed in this hunk
    old_left = new_left = 0  # lines still belonging to the current hunk

    for line in diff.splitlines():
        if old_left > 0 or new_left > 0:
            tag, code_line = line[:1], line[1:]
            if tag == "+":
                new_left -= 1
                lineno += 1
                if not skip_allowed_file and is_forbidden_import(code_line):
                    violations.append(
                        f"{current_file}:{lineno}: {code_line.strip()}")
                continue
            if tag == "-":
                old_left -= 1
                continue
            if tag == " ":  # context line: exists on both sides
                old_left -= 1
                new_left -= 1
                lineno += 1
                continue
            if tag == "\\":  # "\ No newline at end of file"
                continue
            # Unexpected line: the hunk counts were off, so resynchronise by
            # treating this line as a header below.
            old_left = new_left = 0

        if line.startswith("+++ "):
            # Reset per-file state for every new-file header, including
            # quoted paths and /dev/null, so nothing leaks across files.
            current_file = _new_file_path(line[4:])
            skip_allowed_file = is_allowed_file(current_file)
            continue
        match = _HUNK_HEADER_RE.match(line)
        if match:
            old_count, new_start, new_count = match.groups()
            old_left = 1 if old_count is None else int(old_count)
            new_left = 1 if new_count is None else int(new_count)
            lineno = int(new_start) - 1  # next new-side line is new_start
    return violations


def get_diff(diff_type: str) -> str:
    """Return zero-context ``git diff`` output for the requested change set.

    Args:
        diff_type: ``"staged"`` for changes recorded in the index
            (``git diff --cached``) or ``"unstaged"`` for working-tree
            changes not yet staged (plain ``git diff``).

    Returns:
        The diff text decoded as UTF-8. Undecodable bytes are replaced
        instead of raising.

    Raises:
        ValueError: If ``diff_type`` is not ``"staged"`` or ``"unstaged"``.
        subprocess.CalledProcessError: If ``git diff`` exits non-zero.
    """
    if diff_type == "staged":
        scope = ["--cached"]
    elif diff_type == "unstaged":
        scope = []
    else:
        raise ValueError(f"Unknown diff_type: {diff_type}")

    # Pin the output format so user/repo git config cannot break
    # parse_diff(): color.diff=always would inject ANSI escape codes,
    # diff.external would substitute another diff tool, and diff.noprefix /
    # diff.mnemonicPrefix would change the "+++ b/<path>" file headers.
    cmd = [
        "git", "diff", *scope, "--unified=0", "--no-color", "--no-ext-diff",
        "--src-prefix=a/", "--dst-prefix=b/"
    ]
    return subprocess.check_output(cmd,
                                   text=True,
                                   encoding="utf-8",
                                   errors="replace")


def main() -> int:
    """Scan staged and unstaged changes for direct triton imports.

    A failing ``git diff`` is reported on stderr and that change set is
    skipped.

    Returns:
        1 if any violation was found (after printing them all), else 0.
    """
    found: list[str] = []
    for kind in ("staged", "unstaged"):
        try:
            found += parse_diff(get_diff(kind))
        except subprocess.CalledProcessError as err:
            print(f"[{kind}] Git diff failed: {err}", file=sys.stderr)

    if not found:
        return 0

    print("❌ Forbidden direct `import triton` detected."
          " ➤ Use `from vllm.triton_utils import triton` instead.\n")
    for violation in found:
        print(f"❌ {violation}")
    return 1


if __name__ == "__main__":
    raise SystemExit(main())