"vllm/vscode:/vscode.git/clone" did not exist on "dc6de33c3d5e9026cef7b27791dfe0f98e64bbde"
check_torch_cuda.py 2.12 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys

import regex as re

# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.xxx` but allow `torch.accelerator.xxx`
# --------------------------------------------------------------------------- #
# Each pattern flags one family of forbidden usages. Only the full match text
# is consumed, so all groups are non-capturing.
_TORCH_CUDA_PATTERNS = [
    # Device / memory-management helpers that have torch.accelerator
    # equivalents. The ``\b`` applies only to the plain names: ``device\(``
    # already ends on a delimiter, and a ``\b`` after ``(`` would wrongly
    # require the argument to start with a word character (missing e.g.
    # ``torch.cuda.device(`` followed by a line break).
    (
        r"\btorch\.cuda\.(?:(?:empty_cache|synchronize|device_count"
        r"|current_device|memory_reserved|memory_allocated"
        r"|max_memory_allocated|max_memory_reserved|reset_peak_memory_stats"
        r"|memory_stats|set_device)\b|device\()"
    ),
    # CUDA-specific seeding helpers.
    r"\btorch\.cuda\.(?:manual_seed|manual_seed_all)\b",
    # ``with torch.cuda.device...`` context managers, any whitespace after
    # ``with``.
    r"\bwith\s+torch\.cuda\.device\b",
    # Calls torch.cuda.{_is_compiled/_device_count_amdsmi/_device_count_nvml}
    # internally. No trailing ``\b``: after ``)`` it would demand a following
    # word character, so ordinary calls like ``n = ...()`` never matched.
    r"\bcuda_device_count_stateless\(\)",
]

# Path prefixes compared with ``str.startswith`` against the filenames given
# on the command line; files under them may use torch.cuda directly.
ALLOWED_FILES = {"vllm/platforms/", "vllm/device_allocator/"}


def scan_file(path: str) -> int:
    """Print every forbidden ``torch.cuda`` usage found in *path*.

    Each match of each entry in ``_TORCH_CUDA_PATTERNS`` is reported as
    ``<path>:<line>: error: <message>``. Findings are deduplicated per
    (line, message) and printed in line order, so one run surfaces all
    violations in the file instead of stopping at the first hit of whichever
    pattern happens to be checked first.

    Args:
        path: File to scan; read as UTF-8 text.

    Returns:
        1 if at least one forbidden usage was found, otherwise 0.
    """
    with open(path, encoding="utf-8") as f:
        content = f.read()

    findings: set[tuple[int, str]] = set()
    for pattern in _TORCH_CUDA_PATTERNS:
        for match in re.finditer(pattern, content, re.MULTILINE):
            # 1-based line number: count the newlines before the match start.
            line_num = content.count("\n", 0, match.start()) + 1
            matched_text = match.group(0)
            if "manual_seed" in matched_text:
                message = (
                    f"Found {matched_text} API call. Use set_random_seed instead."
                )
            else:
                message = (
                    "Found torch.cuda API call. Please refer RFC "
                    "https://github.com/vllm-project/vllm/issues/30679, use "
                    "torch.accelerator API instead."
                )
            # The set collapses duplicates: e.g. ``with torch.cuda.device(x)``
            # can be hit by more than one pattern on the same line.
            findings.add((line_num, message))

    for line_num, message in sorted(findings):
        # \033[91m ... \033[0m renders "error:" in red.
        print(f"{path}:{line_num}: \033[91merror:\033[0m {message}")
    return 1 if findings else 0


def main():
    """Scan each file named on the command line, skipping allow-listed paths.

    Returns:
        The bitwise OR of the ``scan_file`` results for every scanned file:
        0 when nothing forbidden was found, non-zero otherwise.
    """
    # str.startswith accepts a tuple and succeeds if any prefix matches.
    skip_prefixes = tuple(ALLOWED_FILES)
    to_scan = [arg for arg in sys.argv[1:] if not arg.startswith(skip_prefixes)]

    status = 0
    for candidate in to_scan:
        status |= scan_file(candidate)
    return status


if __name__ == "__main__":
    # Use main()'s status as the process exit code (non-zero on any finding).
    raise SystemExit(main())