#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import sys

import regex as re

# List of files (relative to repo root) that are allowed to import pickle or
# cloudpickle
#
# STOP AND READ BEFORE YOU ADD ANYTHING ELSE TO THIS LIST:
#  The pickle and cloudpickle modules are known to be unsafe when deserializing
#  data from potentially untrusted parties. They have resulted in multiple CVEs
#  for vLLM and numerous vulnerabilities in the Python ecosystem more broadly.
#  Before adding new uses of pickle/cloudpickle, please consider safer
#  alternatives like msgpack or pydantic that are already in use in vLLM. Only
#  add to this list if absolutely necessary and after careful security review.
18
# Paths are compared verbatim against the command-line arguments, so each
# entry must be spelled exactly as the path is passed to this script.
ALLOWED_FILES = {
    # pickle
    'vllm/v1/serial_utils.py',
    'vllm/v1/executor/multiproc_executor.py',
    'vllm/multimodal/hasher.py',
    'vllm/transformers_utils/config.py',
    'vllm/model_executor/models/registry.py',
    'tests/utils_/test_utils.py',
    'tests/tokenization/test_cached_tokenizer.py',
    'vllm/distributed/utils.py',
    'vllm/distributed/parallel_state.py',
    'vllm/distributed/device_communicators/all_reduce_utils.py',
    'vllm/distributed/device_communicators/shm_broadcast.py',
    'vllm/distributed/device_communicators/shm_object_storage.py',
    'benchmarks/kernels/graph_machete_bench.py',
    'benchmarks/kernels/benchmark_lora.py',
    'benchmarks/kernels/benchmark_machete.py',
    'benchmarks/fused_kernels/layernorm_rms_benchmarks.py',
    'benchmarks/cutlass_benchmarks/w8a8_benchmarks.py',
    'benchmarks/cutlass_benchmarks/sparse_benchmarks.py',
    # cloudpickle
    'vllm/executor/mp_distributed_executor.py',
    'vllm/executor/ray_distributed_executor.py',
    'vllm/entrypoints/llm.py',
    'tests/utils.py',
    # pickle and cloudpickle
    'vllm/utils/__init__.py',
}
46
47
48
49
50

# Matches a source line that imports the pickle or cloudpickle module:
#   * `import X`, `import X as Y`, `import X.sub`
#   * X anywhere in a comma-separated `import a, X, b` list (the list prefix
#     may not contain `#`, so a module named only in a trailing comment is
#     not flagged)
#   * `from X import ...` and `from X.sub import ...`
# The `\b` after the module name keeps look-alikes such as `pickleas` or
# `pickle_utils` from matching.
PICKLE_RE = re.compile(
    r"^\s*(import\s+(?:[^#\n]*,\s*)?(pickle|cloudpickle)\b"
    r"|from\s+(pickle|cloudpickle)(?:\.\w+)*\s+import\b)")


51
def scan_file(path: str) -> int:
    """Scan one file for a pickle/cloudpickle import.

    Reads ``path`` as UTF-8 line by line and stops at the first line that
    ``PICKLE_RE`` matches, printing a ``path:line:`` diagnostic for it.
    Later offending lines in the same file are not reported.

    Args:
        path: Path of the file to scan.

    Returns:
        1 if an offending import line was found, otherwise 0.
    """
    with open(path, encoding='utf-8') as f:
        # Line numbers are 1-based so the diagnostic matches editor output.
        for i, line in enumerate(f, 1):
            if PICKLE_RE.match(line):
                print(f"{path}:{i}: "
                      "\033[91merror:\033[0m "  # red color
                      "Found pickle/cloudpickle import")
                return 1
    return 0
60
61
62


def main() -> int:
    """Scan every file named on the command line and report the result.

    Files listed in ``ALLOWED_FILES`` are skipped. Every other file is
    scanned even after a failure, so one run reports each offending file.

    Returns:
        The bitwise OR of the per-file ``scan_file`` results: 0 when no
        scanned file has a disallowed import, non-zero otherwise.
    """
    returncode = 0
    for filename in sys.argv[1:]:
        if filename in ALLOWED_FILES:
            continue
        returncode |= scan_file(filename)
    return returncode
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103


def test_regex():
    """Self-check ``PICKLE_RE`` against known import and non-import lines.

    Raises ``AssertionError`` naming the first case whose match result
    differs from the expectation; prints a confirmation when all pass.
    """
    # Lines the pattern must flag.
    flagged = [
        "import pickle",
        "import cloudpickle",
        "import pickle as pkl",
        "import cloudpickle as cpkl",
        "from pickle import *",
        "from cloudpickle import dumps",
        "from pickle import dumps, loads",
        "from cloudpickle import (dumps, loads)",
        "    import pickle",
        "\timport cloudpickle",
        "from   pickle   import   loads",
    ]
    # Lines the pattern must leave alone.
    ignored = [
        "import somethingelse",
        "from somethingelse import pickle",
        "# import pickle",
        "print('import pickle')",
        "import pickleas as asdf",
    ]
    # Flagged cases come first so the indices in failure messages stay stable.
    expectations = [(text, True) for text in flagged]
    expectations += [(text, False) for text in ignored]

    for i, (line, should_match) in enumerate(expectations):
        result = PICKLE_RE.match(line) is not None
        assert result == should_match, (
            f"Test case {i} failed: '{line}' "
            f"(expected {should_match}, got {result})")
    print("All regex tests passed.")

if __name__ == '__main__':
    # `--test-regex` runs the regex self-check instead of scanning files.
    if '--test-regex' in sys.argv:
        test_regex()
    else:
        # The process exit status is whatever main() returns.
        sys.exit(main())