#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Exit non-zero when any file passed on the command line (and not listed in
# ALLOWED_FILES) imports pickle or cloudpickle.
import sys

import regex as re

# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
#  The pickle and cloudpickle modules are known to be unsafe when deserializing
#  data from potentially untrusted parties. They have resulted in multiple CVEs
#  for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
#  Before adding new uses of pickle/cloudpickle, please consider safer
#  alternatives like msgpack or pydantic that are already in use in vLLM. Only
#  add to this list if absolutely necessary and after careful security review.
# Repo-root-relative paths exempt from the check. main() compares each
# command-line argument against this set by exact string equality.
ALLOWED_FILES = {
    # --- files importing ``pickle`` ---
    "benchmarks/cutlass_benchmarks/sparse_benchmarks.py",
    "benchmarks/cutlass_benchmarks/w8a8_benchmarks.py",
    "benchmarks/fused_kernels/layernorm_rms_benchmarks.py",
    "benchmarks/kernels/benchmark_lora.py",
    "benchmarks/kernels/benchmark_machete.py",
    "benchmarks/kernels/graph_machete_bench.py",
    "tests/tokenization/test_cached_tokenizer.py",
    "tests/utils_/test_utils.py",
    "vllm/distributed/device_communicators/all_reduce_utils.py",
    "vllm/distributed/device_communicators/shm_broadcast.py",
    "vllm/distributed/device_communicators/shm_object_storage.py",
    "vllm/distributed/parallel_state.py",
    "vllm/distributed/utils.py",
    "vllm/model_executor/models/registry.py",
    "vllm/multimodal/hasher.py",
    "vllm/transformers_utils/config.py",
    "vllm/v1/executor/multiproc_executor.py",
    "vllm/v1/serial_utils.py",
    # --- files importing ``cloudpickle`` ---
    "tests/utils.py",
    "vllm/entrypoints/llm.py",
    "vllm/executor/mp_distributed_executor.py",
    "vllm/executor/ray_distributed_executor.py",
    # --- files importing both ---
    "vllm/utils/__init__.py",
}

# Matches a line whose leading statement imports pickle or cloudpickle:
#   * ``import pickle`` / ``import cloudpickle as cp``
#   * the module anywhere in a comma-separated list, e.g.
#     ``import os, pickle`` or ``import json as j, cloudpickle``
#   * dotted submodules, e.g. ``import cloudpickle.cloudpickle``
#   * ``from pickle import ...`` / ``from cloudpickle.cloudpickle import ...``
# The ``\b`` after the module name rejects look-alikes such as ``pickleas``
# or ``pickle_utils``. Matching is anchored at the start of the line (after
# indentation), so commented-out imports and string literals are not flagged.
PICKLE_RE = re.compile(
    r"^\s*(?:"
    # ``import a, b as c, pickle`` -- skip any preceding list items.
    r"import\s+(?:[\w.]+(?:\s+as\s+\w+)?\s*,\s*)*(?:pickle|cloudpickle)\b"
    # ``from pickle import x`` / ``from cloudpickle.sub import x``
    r"|from\s+(?:pickle|cloudpickle)(?:\.\w+)*\s+import\b"
    r")"
)


def scan_file(path: str) -> int:
    """Report every pickle/cloudpickle import line found in *path*.

    Each offending line is printed as ``<path>:<lineno>: error: Found
    pickle/cloudpickle import``, with the ``error:`` tag colored red via ANSI
    escape codes.

    Args:
        path: File to scan; it is decoded as UTF-8.

    Returns:
        1 if at least one line matches ``PICKLE_RE``, otherwise 0.
    """
    found = 0
    with open(path, encoding="utf-8") as f:
        for i, line in enumerate(f, 1):
            if PICKLE_RE.match(line):
                # Keep scanning: report every offending import in one run
                # instead of forcing the author to fix them one at a time.
                print(
                    f"{path}:{i}: "
                    "\033[91merror:\033[0m "  # red color
                    "Found pickle/cloudpickle import"
                )
                found = 1
    return found


def main():
    """Scan each command-line path that is not explicitly allow-listed.

    Arguments found in ``ALLOWED_FILES`` (exact string match) are skipped;
    every other path is handed to ``scan_file``.

    Returns:
        1 if any scanned file contained a pickle/cloudpickle import, else 0,
        suitable for passing straight to ``sys.exit``.
    """
    candidates = (path for path in sys.argv[1:] if path not in ALLOWED_FILES)
    status = 0
    for path in candidates:
        # OR the per-file results so one failing file fails the whole run,
        # while still scanning (and reporting) every remaining file.
        status |= scan_file(path)
    return status


def test_regex():
    """Self-test ``PICKLE_RE`` against representative import lines.

    Every case is evaluated and all mismatches are reported together. The
    failure is raised explicitly rather than via ``assert`` statements,
    because ``python -O`` strips asserts and would otherwise make this
    self-test pass vacuously.

    Raises:
        AssertionError: if any line is classified differently than expected.
    """
    test_cases = [
        # Should match
        ("import pickle", True),
        ("import cloudpickle", True),
        ("import pickle as pkl", True),
        ("import cloudpickle as cpkl", True),
        ("from pickle import *", True),
        ("from cloudpickle import dumps", True),
        ("from pickle import dumps, loads", True),
        ("from cloudpickle import (dumps, loads)", True),
        ("    import pickle", True),
        ("\timport cloudpickle", True),
        ("from   pickle   import   loads", True),
        ("import pickle  # noqa", True),
        # Should not match
        ("import somethingelse", False),
        ("from somethingelse import pickle", False),
        ("# import pickle", False),
        ("print('import pickle')", False),
        ("import pickleas as asdf", False),
        ("import pickle_utils", False),
        ("import mypickle", False),
        ("from mypkg.pickle import loads", False),
    ]
    failures = []
    for i, (line, should_match) in enumerate(test_cases):
        result = bool(PICKLE_RE.match(line))
        if result != should_match:
            failures.append(
                f"Test case {i} failed: '{line}' "
                f"(expected {should_match}, got {result})"
            )
    if failures:
        raise AssertionError("\n".join(failures))
    print("All regex tests passed.")


if __name__ == "__main__":
    # ``--test-regex`` runs the built-in regex self-test instead of scanning
    # files; otherwise the scan result becomes the process exit status.
    if "--test-regex" not in sys.argv:
        sys.exit(main())
    test_regex()