main.py 5.39 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""The CLI entrypoints of vLLM

Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage."""

8
import contextlib
9
import importlib.metadata
10
import os
11
import sys
12
import threading as _threading
13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

# [startup] Kick off torch + transformers .so/module loading in a background
# thread before we touch vllm.logger (which pulls vllm/__init__.py ->
# vllm.env_override -> `import torch` on the main thread). Python import
# lock serializes the same-module import across threads, but the .so dlopen
# inside torch's init releases the GIL during file I/O. Main thread's
# non-torch imports (vllm.envs submodules, stdlib, fastapi, etc.) can make
# progress on the CPU while the background thread pays the ~2 s of cuda
# .so loading. `import transformers` is also ~2 s of cold-disk work and
# depends on torch; chain it after torch in the same thread so subsequent
# `from transformers import ...` lines on the main thread hit a warm
# module cache.
def _bg_preload_torch() -> None:
    try:
        import torch  # noqa: F401
    except Exception:
        return
    with contextlib.suppress(Exception):
        import transformers  # noqa: F401


# Fire-and-forget: the thread is a daemon, so it never keeps the interpreter
# alive if the CLI exits before the preload finishes.
_threading.Thread(
    name="vllm-torch-preload",
    target=_bg_preload_torch,
    daemon=True,
).start()


# [startup] Pre-spawn EngineCore via forkserver preload, in a background
# thread. Only fires for `vllm serve` (the only subcommand that spawns a
# long-running EngineCore). The forkserver process is forked once and
# preloaded with vllm.v1.engine.async_llm (~3-5 s of imports). When
# AsyncLLM.from_vllm_config later runs, Process.start() forks from the
# already-warm forkserver instead of paying spawn() cost (~5 s in child
# for fresh Python + imports).
#
# Kicking the preload in a BG thread lets the ~3-5 s ensure_running cost
# overlap with APIServer's argparse + config resolution (~5-10 s on cold
# disk). Default cli_env_setup sets spawn; we override to forkserver
# before that runs so the path is consistent.
def _bg_prewarm_forkserver() -> None:
    try:
        import multiprocessing
        import multiprocessing.forkserver as forkserver

        # set_start_method MUST be called before ensure_running. It also
        # can only be called once per process; any later override by
        # vllm's build_async_engine_client will just see the existing
        # setting.
        multiprocessing.set_start_method("forkserver", force=False)
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
    except Exception:
        pass


# Only the `serve` subcommand triggers the forkserver prewarm. The slice
# comparison is false when argv has no second element.
if sys.argv[1:2] == ["serve"]:
    # Respect an explicit user setting; only fill in the default.
    os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "forkserver")
    # Daemon thread: an early CLI exit (bad args, --help, import errors)
    # must not block on ensure_running(). The forkserver process is tracked
    # by module-level state in multiprocessing.forkserver, so it outlives
    # this thread and later process starts can reuse it.
    _threading.Thread(
        name="vllm-forkserver-prewarm",
        target=_bg_prewarm_forkserver,
        daemon=True,
    ).start()


from vllm.logger import init_logger  # noqa: E402
82
83

logger = init_logger(__name__)
84
85
86


def main() -> None:
    """Run the ``vllm`` command-line interface.

    Imports the subcommand modules, calls ``cli_env_setup()``, builds the
    top-level parser with the subcommands each module registers, parses the
    command line and dispatches to the selected subcommand. If the parsed
    arguments carry no ``dispatch_function`` (no subcommand was selected),
    the top-level help is printed instead.
    """
    # Imports are kept inside main (see the module docstring): modules must
    # be loaded lazily to avoid eager import breakage.
    import vllm.entrypoints.cli.benchmark.main
    import vllm.entrypoints.cli.collect_env
    import vllm.entrypoints.cli.launch
    import vllm.entrypoints.cli.openai
    import vllm.entrypoints.cli.run_batch
    import vllm.entrypoints.cli.serve
    from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
    from vllm.utils.argparse_utils import FlexibleArgumentParser

    # Each module exposes cmd_init(), which yields the subcommand objects to
    # register on the parser below.
    CMD_MODULES = [
        vllm.entrypoints.cli.openai,
        vllm.entrypoints.cli.serve,
        vllm.entrypoints.cli.launch,
        vllm.entrypoints.cli.benchmark.main,
        vllm.entrypoints.cli.collect_env,
        vllm.entrypoints.cli.run_batch,
    ]

    cli_env_setup()

    # For 'vllm bench *': use CPU instead of UnspecifiedPlatform by default
    if len(sys.argv) > 1 and sys.argv[1] == "bench":
        logger.debug(
            "Bench command detected, must ensure current platform is not "
            "UnspecifiedPlatform to avoid device type inference error"
        )
        from vllm import platforms

        if platforms.current_platform.is_unspecified():
            from vllm.platforms.cpu import CpuPlatform

            platforms.current_platform = CpuPlatform()
            logger.info(
                "Unspecified platform detected, switching to CPU Platform instead."
            )

    parser = FlexibleArgumentParser(
        description="vLLM CLI",
        epilog=VLLM_SUBCMD_PARSER_EPILOG.format(subcmd="[subcommand]"),
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=importlib.metadata.version("vllm"),
    )

    subparsers = parser.add_subparsers(required=False, dest="subparser")
    # Maps subcommand name -> command object, used to run validate() below.
    cmds = {}
    for cmd_module in CMD_MODULES:
        new_cmds = cmd_module.cmd_init()
        for cmd in new_cmds:
            # Attach the handler to the subparser so it can be read back
            # from the parsed args as `dispatch_function`.
            cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
            cmds[cmd.name] = cmd

    args = parser.parse_args()
    if args.subparser in cmds:
        cmds[args.subparser].validate(args)

    if hasattr(args, "dispatch_function"):
        args.dispatch_function(args)
    else:
        parser.print_help()


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()