generate_argparse.py 9.32 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import importlib.metadata
import importlib.util
5
6
import logging
import sys
7
import traceback
8
9
10
from argparse import SUPPRESS, Action, HelpFormatter
from collections.abc import Iterable
from importlib.machinery import ModuleSpec
11
from pathlib import Path
12
from typing import TYPE_CHECKING, Literal
13
14
from unittest.mock import MagicMock, patch

15
16
17
18
from pydantic_core import core_schema

logger = logging.getLogger("mkdocs")

19
ROOT_DIR = Path(__file__).parent.parent.parent.parent
20
ARGPARSE_DOC_DIR = ROOT_DIR / "docs/generated/argparse"
21
22

sys.path.insert(0, str(ROOT_DIR))
23
24


25
26
27
28
29
def mock_if_no_torch(mock_module: str, mock: MagicMock):
    """Register `mock` in `sys.modules` as `mock_module` when torch is absent.

    Args:
        mock_module: Dotted module name to register the stand-in under.
        mock: Object to place in `sys.modules` for that name.
    """
    torch_available = importlib.util.find_spec("torch") is not None
    if torch_available:
        # The real modules can be imported, nothing to replace
        return
    sys.modules[mock_module] = mock


30
31
32
33
34
35
36
37
38
39
# Mock custom op code
class MockCustomOp:
    """Minimal custom-op stand-in whose `register` decorator does nothing."""

    @staticmethod
    def register(name):
        """Return a decorator that hands the decorated class back unchanged.

        `name` is accepted but not used.
        """
        return lambda cls: cls


40
41
42
43
44
45
# When torch is not installed, replace the modules below with mocks that
# expose just the attributes configured here.
mock_if_no_torch(mock_module="vllm._C", mock=MagicMock())
mock_if_no_torch(
    mock_module="vllm.model_executor.custom_op",
    mock=MagicMock(CustomOp=MockCustomOp),
)
mock_if_no_torch(
    mock_module="vllm.utils.torch_utils",
    mock=MagicMock(direct_register_custom_op=lambda *_args, **_kwargs: None),
)

46
47
48
49
50
51

# Mock any version checks by reading from compiled CI requirements
def _parse_pinned_requirement(line: str):
    """Return `(name, version)` for a `name==version` pin, else `None`.

    Trailing `# comments` and `; environment markers` are dropped before
    parsing, so a marker that itself contains `==` cannot produce a
    malformed entry, and commented-out pins are ignored.
    """
    requirement = line.split("#", 1)[0].split(";", 1)[0].strip()
    name, separator, version = requirement.partition("==")
    if not separator:
        return None
    return name.strip(), version.strip()


# Specify encoding so the file reads the same regardless of platform default
with open(ROOT_DIR / "requirements/test.txt", encoding="utf-8") as f:
    VERSIONS = dict(pin for pin in map(_parse_pinned_requirement, f) if pin)
# Unknown packages report a placeholder version instead of raising
importlib.metadata.version = lambda name: VERSIONS.get(name) or "0.0.0"

52

53
# Make torch.nn.Parameter safe to inherit from by substituting `object` for it
# in the mocked module
mock_if_no_torch(mock_module="torch.nn", mock=MagicMock(Parameter=object))
55
56


57
58
59
class PydanticMagicMock(MagicMock):
    """`MagicMock` that pydantic-core can build a schema for.

    Every instance also carries a `__spec__` built from the `name` keyword
    it was created with (presumably so it can pass for an imported module).
    """

    def __init__(self, *args, **kwargs):
        # `name` is consumed here and is not forwarded to `MagicMock`
        spec_name = kwargs.pop("name", None)
        super().__init__(*args, **kwargs)
        self.__spec__ = ModuleSpec(name=spec_name, loader=None)

    def __get_pydantic_core_schema__(self, source_type, handler):
        """Describe the mock to pydantic as accepting any value."""
        return core_schema.any_schema()


69
def auto_mock(module_name: str, attr: str, max_mocks: int = 100):
    """Import `attr` from `module_name`, mocking missing modules on the way.

    Each time the import raises `ModuleNotFoundError`, the missing module is
    replaced in `sys.modules` by a `PydanticMagicMock` and the import is
    attempted again. Any other exception is logged and the import is also
    attempted again.

    Args:
        module_name: Dotted name of the module to import.
        attr: Attribute of `module_name` to return. If the imported module
            has no such attribute, `module_name.attr` is imported as a
            submodule instead.
        max_mocks: Maximum number of import attempts before giving up.

    Returns:
        The requested attribute or submodule.

    Raises:
        ImportError: If the import has not succeeded after `max_mocks`
            attempts.
    """
    logger.info("Importing %s from %s", attr, module_name)

    for _ in range(max_mocks):
        try:
            module = importlib.import_module(module_name)

            # First treat attr as an attr, then as a submodule
            if hasattr(module, attr):
                return getattr(module, attr)

            return importlib.import_module(f"{module_name}.{attr}")
        except ModuleNotFoundError as e:
            assert e.name is not None
            logger.info("Mocking %s for argparse doc generation", e.name)
            sys.modules[e.name] = PydanticMagicMock(name=e.name)
        except Exception:
            # `logger.exception` appends the traceback on its own, so the
            # message needs exactly one placeholder per argument given here.
            logger.exception("Failed to import %s.%s", module_name, attr)

    raise ImportError(
        f"Failed to import {module_name}.{attr} after mocking {max_mocks} imports"
    )
92
93


94
# Benchmark modules and sweep argument classes
bench_latency = auto_mock(module_name="vllm.benchmarks", attr="latency")
bench_mm_processor = auto_mock(module_name="vllm.benchmarks", attr="mm_processor")
bench_serve = auto_mock(module_name="vllm.benchmarks", attr="serve")
bench_sweep_plot = auto_mock(
    module_name="vllm.benchmarks.sweep.plot",
    attr="SweepPlotArgs",
)
bench_sweep_plot_pareto = auto_mock(
    module_name="vllm.benchmarks.sweep.plot_pareto",
    attr="SweepPlotParetoArgs",
)
bench_sweep_serve = auto_mock(
    module_name="vllm.benchmarks.sweep.serve",
    attr="SweepServeArgs",
)
bench_sweep_serve_sla = auto_mock(
    module_name="vllm.benchmarks.sweep.serve_sla",
    attr="SweepServeSLAArgs",
)
bench_throughput = auto_mock(module_name="vllm.benchmarks", attr="throughput")

# Engine argument classes
AsyncEngineArgs = auto_mock(
    module_name="vllm.engine.arg_utils",
    attr="AsyncEngineArgs",
)
EngineArgs = auto_mock(module_name="vllm.engine.arg_utils", attr="EngineArgs")

# CLI commands and entrypoint modules
ChatCommand = auto_mock(
    module_name="vllm.entrypoints.cli.openai",
    attr="ChatCommand",
)
CompleteCommand = auto_mock(
    module_name="vllm.entrypoints.cli.openai",
    attr="CompleteCommand",
)
openai_cli_args = auto_mock(module_name="vllm.entrypoints.openai", attr="cli_args")
openai_run_batch = auto_mock(module_name="vllm.entrypoints.openai", attr="run_batch")

# Type checkers see the real class; at runtime it goes through `auto_mock`
if TYPE_CHECKING:
    from vllm.utils.argparse_utils import FlexibleArgumentParser
else:
    FlexibleArgumentParser = auto_mock(
        module_name="vllm.utils.argparse_utils",
        attr="FlexibleArgumentParser",
    )
119
120
121
122
123


class MarkdownFormatter(HelpFormatter):
    """Help formatter that renders argument groups and options as markdown.

    Markdown fragments are collected in a list as the parser feeds sections,
    text and arguments to the formatter; `format_help` joins them.
    """

    def __init__(self, prog: str, starting_heading_level: int = 3):
        """Set up the formatter.

        Args:
            prog: Program name, forwarded to `HelpFormatter`.
            starting_heading_level: Markdown heading level used for sections.
                Arguments are rendered one level below it.
        """
        super().__init__(prog, width=sys.maxsize, max_help_position=sys.maxsize)

        section_level = starting_heading_level
        self._section_heading_prefix = "#" * section_level
        self._argument_heading_prefix = "#" * (section_level + 1)
        self._markdown_output = []

    def start_section(self, heading: str):
        """Emit a section heading, except for these two group titles."""
        if heading in {"positional arguments", "options"}:
            return
        prefix = self._section_heading_prefix
        self._markdown_output.append(f"\n{prefix} {heading}\n\n")

    def end_section(self):
        """Do nothing; sections need no closing markup."""

    def add_text(self, text: str):
        """Emit non-empty `text`, stripped, as its own paragraph."""
        if not text:
            return
        self._markdown_output.append(f"{text.strip()}\n\n")

    def add_usage(self, usage, actions, groups, prefix=None):
        """Do nothing; usage lines are left out of the markdown."""

    def add_arguments(self, actions: Iterable[Action]):
        """Emit a heading, choices, help text and default for each option.

        Actions without option strings and the `--help` action are skipped.
        """
        emit = self._markdown_output.append

        for action in actions:
            flags = action.option_strings
            if not flags or "--help" in flags:
                continue

            joined_flags = "`, `".join(flags)
            emit(f"{self._argument_heading_prefix} `{joined_flags}`\n\n")

            # Explicit choices win; a list/tuple metavar is the fallback
            allowed = None
            if action.choices:
                allowed = action.choices
            elif action.metavar and isinstance(action.metavar, (list, tuple)):
                allowed = action.metavar
            if allowed is not None:
                joined_allowed = "`, `".join(str(item) for item in allowed)
                emit(f"Possible choices: `{joined_allowed}`\n\n")

            if action.help:
                emit(f"{action.help}\n\n")

            default = action.default
            # None usually means the default is determined at runtime
            has_static_default = default != SUPPRESS and default is not None
            if has_static_default:
                # Make empty string defaults visible
                shown = '""' if default == "" else default
                emit(f"Default: `{shown}`\n\n")

    def format_help(self):
        """Return the formatted help as markdown."""
        return "".join(self._markdown_output)


177
def create_parser(add_cli_args, **kwargs) -> FlexibleArgumentParser:
    """Create a markdown-formatting parser populated by `add_cli_args`.

    If `add_cli_args` raises `ModuleNotFoundError`, the module whose file
    appears last in the traceback is re-imported through `auto_mock` and
    parser creation is attempted again.

    Args:
        add_cli_args: Callable that receives the parser (plus `**kwargs`)
            and adds CLI arguments to it. It may return a parser or modify
            the given one in place.
        **kwargs: Additional keyword arguments to pass to `add_cli_args`.

    Returns:
        FlexibleArgumentParser: The parser returned by `add_cli_args`, or
        the parser that was passed to it when nothing is returned.

    Raises:
        ModuleNotFoundError: If the failing import has no traceback or its
            file is not located under `ROOT_DIR`, so it cannot be mapped to
            a module to auto-mock.
    """
    try:
        parser = FlexibleArgumentParser(add_json_tip=False)
        parser.formatter_class = MarkdownFormatter
        with patch("vllm.config.DeviceConfig.__post_init__"):
            _parser = add_cli_args(parser, **kwargs)
    except ModuleNotFoundError as e:
        # Auto-mock runtime imports
        tb_list = traceback.extract_tb(e.__traceback__)
        if not tb_list:
            raise

        failing_file = Path(tb_list[-1].filename)
        if not failing_file.is_relative_to(ROOT_DIR):
            # Without a path inside the repository there is no module name
            # to derive; surface the original error rather than the
            # ValueError `relative_to` would raise.
            raise

        path = failing_file.relative_to(ROOT_DIR)
        auto_mock(module_name=".".join(path.parent.parts), attr=path.stem)
        return create_parser(add_cli_args, **kwargs)

    # add_cli_args might be in-place so return parser if _parser is None
    return _parser or parser
202
203


204
205
206
207
208
209
210
211
212
213
214
def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
    """Generate one markdown include file per documented argument parser.

    Every parser is created first; afterwards each one is written to
    `ARGPARSE_DOC_DIR/<stem>.inc.md`.

    Args:
        command: Accepted but not used by this function.
        dirty: Accepted but not used by this function.
    """
    logger.info("Generating argparse documentation")
    logger.debug("Root directory: %s", ROOT_DIR.resolve())
    logger.debug("Output directory: %s", ARGPARSE_DOC_DIR.resolve())

    # Create the ARGPARSE_DOC_DIR if it doesn't exist
    if not ARGPARSE_DOC_DIR.exists():
        ARGPARSE_DOC_DIR.mkdir(parents=True)

    # (file stem, callable adding the CLI args, extra kwargs for the callable)
    documented = (
        # Engine args
        ("engine_args", EngineArgs.add_cli_args, {}),
        (
            "async_engine_args",
            AsyncEngineArgs.add_cli_args,
            {"async_args_only": True},
        ),
        # CLI
        ("serve", openai_cli_args.make_arg_parser, {}),
        ("chat", ChatCommand.add_cli_args, {}),
        ("complete", CompleteCommand.add_cli_args, {}),
        ("run-batch", openai_run_batch.make_arg_parser, {}),
        # Benchmark CLI
        ("bench_latency", bench_latency.add_cli_args, {}),
        ("bench_mm_processor", bench_mm_processor.add_cli_args, {}),
        ("bench_serve", bench_serve.add_cli_args, {}),
        ("bench_sweep_plot", bench_sweep_plot.add_cli_args, {}),
        ("bench_sweep_plot_pareto", bench_sweep_plot_pareto.add_cli_args, {}),
        ("bench_sweep_serve", bench_sweep_serve.add_cli_args, {}),
        ("bench_sweep_serve_sla", bench_sweep_serve_sla.add_cli_args, {}),
        ("bench_throughput", bench_throughput.add_cli_args, {}),
    )

    # Create parsers to document
    parsers = {
        stem: create_parser(add_cli_args, **extra)
        for stem, add_cli_args, extra in documented
    }

    # Generate documentation for each parser
    for stem, parser in parsers.items():
        doc_path = ARGPARSE_DOC_DIR / f"{stem}.inc.md"
        # Specify encoding for building on Windows
        with doc_path.open("w", encoding="utf-8") as f:
            # Skips `format_help` as defined on the parser's own class (if
            # any) and uses the next implementation in its MRO
            f.write(super(type(parser), parser).format_help())
        logger.info("Argparse generated: %s", doc_path.relative_to(ROOT_DIR))
243
244
245
246


if __name__ == "__main__":
    # Run the generator directly when the file is executed as a script
    on_startup(command="build", dirty=False)