generate_api_readme.py 9.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-Omni project
"""
Hook to automatically generate docs/api/README.md from the codebase.

This script scans the vllm_omni module for public classes and functions,
categorizes them, and generates a summary README file.
"""

import ast
import logging
from pathlib import Path

logger = logging.getLogger("mkdocs")

ROOT_DIR = Path(__file__).parent.parent.parent.parent
API_README_PATH = ROOT_DIR / "docs" / "api" / "README.md"

# Category mappings: module prefix -> category name and description.
# NOTE: scan_package() matches these prefixes as *substrings* of the dotted
# module path, and the first matching entry (in declaration order) wins, so
# the ordering of this dict is significant.
CATEGORIES = {
    "entrypoints": {
        "name": "Entry Points",
        "description": "Main entry points for vLLM-Omni inference and serving.",
    },
    "inputs": {
        "name": "Inputs",
        "description": "Input data structures for multi-modal inputs.",
    },
    "engine": {
        "name": "Engine",
        "description": "Engine classes for offline and online inference.",
    },
    "core": {
        "name": "Core",
        "description": "Core scheduling and caching components.",
    },
    # NOTE(review): the "model_executor" category below is commented out —
    # confirm whether it should be restored or removed entirely.
    # "model_executor": {
    #     "name": "Model Executor",
    #     "description": "Model execution components.",
    # },
    "config": {
        "name": "Configuration",
        "description": "Configuration classes.",
    },
    "worker": {
        "name": "Workers",
        "description": "Worker classes and model runners for distributed inference.",
    },
}


class APIVisitor(ast.NodeVisitor):
    """AST visitor that collects public classes and module-level functions.

    Names starting with an underscore are treated as private and skipped.
    Nested classes are recorded with their full dotted path (e.g.
    ``pkg.mod.Outer.Inner``) so mkdocstrings cross-references resolve.
    Functions defined inside a class (methods) or inside another function
    (locals) are not collected; classes defined inside a function are
    likewise skipped.
    """

    def __init__(self, module_path: str):
        # Dotted import path used as the prefix of every collected symbol.
        self.module_path = module_path
        self.classes: list[str] = []
        self.functions: list[str] = []
        # Stack of enclosing class names: non-empty means "inside a class
        # body", and it also supplies the qualifier for nested classes.
        self._class_stack: list[str] = []
        # Depth of enclosing function definitions: > 0 means "inside a
        # function body", where any definition is a local, not public API.
        self._func_depth: int = 0

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Collect public classes, fully qualified by their enclosing classes."""
        if self._func_depth == 0 and not node.name.startswith("_"):
            qualified = ".".join([self.module_path, *self._class_stack, node.name])
            self.classes.append(qualified)
        self._class_stack.append(node.name)
        self.generic_visit(node)
        self._class_stack.pop()

    def _visit_function(self, node: "ast.FunctionDef | ast.AsyncFunctionDef") -> None:
        """Shared handler for sync and async defs: collect only public,
        truly module-level functions (not methods, not nested locals)."""
        if not self._class_stack and self._func_depth == 0 and not node.name.startswith("_"):
            self.functions.append(f"{self.module_path}.{node.name}")
        self._func_depth += 1
        self.generic_visit(node)
        self._func_depth -= 1

    # Sync and async function definitions share identical collection logic.
    visit_FunctionDef = _visit_function
    visit_AsyncFunctionDef = _visit_function


def parse_file_for_symbols(file_path: Path, module_path: str) -> tuple[list[str], list[str]]:
    """
    Parse a Python file and extract public classes and functions.

    Args:
        file_path: Path of the ``.py`` file to parse.
        module_path: Dotted module path used to qualify extracted symbols.

    Returns:
        Tuple of (classes, functions); both lists are empty when the file
        cannot be read or parsed (best-effort — failures are only logged).
    """
    try:
        # A package's __init__.py is documented under the package itself,
        # so drop the trailing ".__init__" from the module path.
        if file_path.name == "__init__.py":
            module_path = module_path.removesuffix(".__init__")

        content = file_path.read_text(encoding="utf-8")

        tree = ast.parse(content, filename=str(file_path))
        visitor = APIVisitor(module_path)
        visitor.visit(tree)

        return visitor.classes, visitor.functions
    except Exception as e:
        # Best-effort: one unparsable file must not break the docs build.
        logger.debug(f"Could not parse {file_path}: {e}")
        return [], []


def _is_public_api_class(short_name: str, internal_patterns: tuple[str, ...], main_class_markers: tuple[str, ...]) -> bool:
    """Return True if a class short name should appear in the API docs.

    A class is kept unless its name matches an internal-implementation
    pattern — and even then it is kept when it also matches a main
    model/runtime marker (e.g. ``...ForConditionalGeneration``).
    """
    if not any(pattern in short_name for pattern in internal_patterns):
        return True
    return any(marker in short_name for marker in main_class_markers)


def scan_package(package_name: str = "vllm_omni") -> dict[str, list[str]]:
    """
    Scan the vllm_omni package and categorize public symbols.

    Args:
        package_name: Directory name (relative to the repo root) to scan.

    Returns:
        Dict mapping category names to sorted lists of symbol full names.
        Categories with no matching symbols remain as empty lists.
    """
    categorized: dict[str, list[str]] = {cat["name"]: [] for cat in CATEGORIES.values()}

    # Modules excluded to avoid importing vllm during the docs build.
    excluded_prefixes = (
        "vllm_omni.diffusion.models.qwen_image",
        "vllm_omni.entrypoints.async_diffusion",
        "vllm_omni.entrypoints.openai",
    )

    # Name fragments identifying internal implementation classes
    # (DiT layers, sub-blocks, activations, etc.) hidden from the docs.
    internal_patterns = (
        "Block",
        "Layer",
        "Net",
        "Embedding",
        "Norm",
        "Activation",
        "Solver",
        "Pooling",
        "Attention",
        "MLP",
        "DecoderLayer",
        "InputEmbedding",
        "TimestepEmbedding",
        "CodecEmbedding",
        "DownSample",
        "UpSample",
        "Res2Net",
        "SqueezeExcitation",
        "TimeDelay",
        "TorchActivation",
        "SnakeBeta",
        "SinusPosition",
        "RungeKutta",
        "AMPBlock",
        "AdaLayerNorm",
    )

    # Fragments identifying top-level model/runtime classes that must be
    # kept even when their name also matches an internal pattern.
    main_class_markers = (
        "ForConditionalGeneration",
        "Model",
        "Registry",
        "Worker",
        "Runner",
        "Scheduler",
        "Manager",
        "Processor",
        "Config",
    )

    try:
        # Find the package directory
        package_path = ROOT_DIR / package_name
        if not package_path.exists():
            logger.warning(f"Package path not found: {package_path}")
            return categorized

        # Walk through all Python files
        for py_file in package_path.rglob("*.py"):
            # Skip private modules (but keep __init__.py).
            if py_file.name.startswith("_") and py_file.name != "__init__.py":
                continue

            # Derive the dotted module path from the file location.
            relative_path = py_file.relative_to(ROOT_DIR)
            module_path = str(relative_path.with_suffix("")).replace("/", ".").replace("\\", ".")

            if any(module_path.startswith(prefix) for prefix in excluded_prefixes):
                continue

            # A package's __init__.py is documented under the package path.
            if py_file.name == "__init__.py":
                module_path = module_path.removesuffix(".__init__")

            # Determine category from the module path; first match wins in
            # CATEGORIES declaration order.
            # NOTE(review): this is a substring match, so a short prefix like
            # "core" could also match unrelated module names — confirm whether
            # path-segment matching is needed.
            category = next(
                (cat_info["name"] for prefix, cat_info in CATEGORIES.items() if prefix in module_path),
                None,
            )
            if not category:
                continue

            # Parse file for symbols
            classes, functions = parse_file_for_symbols(py_file, module_path)

            # Add classes, filtering out internal implementation classes.
            for class_name in classes:
                if _is_public_api_class(class_name.split(".")[-1], internal_patterns, main_class_markers):
                    categorized[category].append(class_name)

            # Add important helper functions (parse, preprocess, etc.).
            for func_name in functions:
                if any(keyword in func_name.lower() for keyword in ("parse", "preprocess")):
                    categorized[category].append(func_name)

        # Sort symbols within each category for stable output.
        for symbols in categorized.values():
            symbols.sort()

    except Exception as e:
        logger.error(f"Error scanning package: {e}", exc_info=True)

    return categorized


def generate_readme(categorized: dict[str, list[str]]) -> str:
    """Generate the API README markdown content.

    Sections follow CATEGORIES declaration order so the output is stable;
    categories with no symbols are omitted entirely.

    Args:
        categorized: Mapping of category display name to symbol full names.

    Returns:
        Markdown text: a "# Summary" heading followed by one section per
        non-empty category, each symbol rendered as an mkdocstrings
        cross-reference (``- [symbol][]``).
    """
    lines = ["# Summary", ""]

    # Only the category metadata matters here; the module-prefix keys of
    # CATEGORIES are not needed, so iterate the values directly.
    for cat_info in CATEGORIES.values():
        category_name = cat_info["name"]
        symbols = categorized.get(category_name, [])

        if not symbols:
            continue

        lines.extend([f"## {category_name}", "", cat_info["description"], ""])
        lines.extend(f"- [{symbol}][]" for symbol in symbols)
        lines.append("")

    return "\n".join(lines)


def on_startup(command, dirty: bool):
    """MkDocs hook entry point: regenerate docs/api/README.md at startup."""
    logger.info("Generating API README documentation")

    # Scan the package and render the categorized symbols as markdown.
    content = generate_readme(scan_package())

    # Ensure docs/api/ exists, then (over)write the README in one shot.
    API_README_PATH.parent.mkdir(parents=True, exist_ok=True)
    API_README_PATH.write_text(content, encoding="utf-8")

    logger.info(f"API README generated: {API_README_PATH.relative_to(ROOT_DIR)}")