# NOTE: repository-viewer residue captured with the file; not part of the module:
#   "vscode:/vscode.git/clone" did not exist on "7b80cd8ac382851527225ed1f3475c138a4b7c01"
#   wrapper.py 14.8 KB
#   Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from collections.abc import Callable, Generator
8
from contextlib import contextmanager, nullcontext
9
from types import CodeType
10
from typing import Any, ParamSpec, TypeVar
11
12

import torch
13
import torch._C._dynamo.guards
14

15
import vllm.envs as envs
16
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
17
from vllm.config.compilation import DynamicShapesType
18
from vllm.logger import init_logger
19
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
20
21

logger = init_logger(__name__)
22

23
24
R = TypeVar("R")
P = ParamSpec("P")
25

26
27
28
29

def _noop_add_global_state_guard(
    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
30
31
32
33
    """No-op to skip the GLOBAL_STATE guard entirely"""
    pass


34
35
36
def _noop_add_torch_function_mode_stack_guard(
    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
37
38
39
40
41
    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
    pass


@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Context manager for compilation settings and patches.

    This manager:
    1. Sets higher dynamo cache limits for compilation. (Needed for
        qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
        Generally a recompilation can happen whenever we use a new
        backend instance in torch.compile.
    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
    3. Patches out add_torch_function_mode_stack_guard to skip
        TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes

    The patches are applied to process-wide objects (the ``GuardManager``
    class and ``torch._dynamo.config``), so they are visible to everything
    that runs while the context is active.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Save original values *before* entering the try block so that the
    # finally clause always has something valid to restore.
    original_global_state_guard = guard_manager.add_global_state_guard
    original_torch_function_mode_stack_guard = (
        guard_manager.add_torch_function_mode_stack_guard
    )
    original_cache_size = dynamo_config.cache_size_limit
    original_accumulated_cache = dynamo_config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        # Patch guard manager
        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values, even if the wrapped code raised.
        guard_manager.add_global_state_guard = original_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        dynamo_config.cache_size_limit = original_cache_size
        dynamo_config.accumulated_cache_size_limit = original_accumulated_cache


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.
    """

    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``_check_shape_invariants`` on the inputs, then ``forward``.

        ``__init__`` compiles this method instead of ``forward`` when the
        dynamic shapes type is ``DynamicShapesType.UNBACKED``. The instance
        must provide ``_check_shape_invariants``; it is not defined in this
        class.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        """Call ``callable_fn(*args, **kwargs)`` and return its result.

        When ``self.layerwise_nvtx_tracing_enabled`` is set, the call is made
        inside ``layerwise_nvtx_marker_context`` and the result is also stored
        on the context object (``ctx.result``) before the context exits.
        """
        if not self.layerwise_nvtx_tracing_enabled:
            return callable_fn(*args, **kwargs)

        # Snapshot the inputs as plain containers for the marker context.
        args_list = list(args)
        kwargs_dict = dict(kwargs)
        with layerwise_nvtx_marker_context(
            f"Torch Compiled Module (input):{self.__class__.__name__}",
            self,
            in_tensor=args_list,
            kwargs=kwargs_dict,
        ) as ctx:
            ctx.result = callable_fn(*args, **kwargs)
        return ctx.result

    def __init__(self) -> None:
        """Build the ``torch.compile`` callable from the current vLLM config.

        Reads the configuration via ``get_current_vllm_config()``, selects
        the backend and options, installs a ``guard_filter_fn`` unless the
        mode is ``STOCK_TORCH_COMPILE``, and optionally registers
        ``bytecode_hook`` with Dynamo.

        Raises:
            RuntimeError: if the configured compilation mode is ``None``.
            AssertionError: if ``evaluate_guards`` or UNBACKED dynamic shapes
                are combined with ``VLLM_USE_BYTECODE_HOOK``, or if UNBACKED
                dynamic shapes are combined with ``evaluate_guards``.
        """
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        compilation_config = vllm_config.compilation_config
        mode = compilation_config.mode
        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = compilation_config.init_backend(vllm_config)
        options: dict[str, Any] = {}

        if isinstance(backend, str) and backend == "inductor":
            # NOTE(review): `options` now aliases the config's own dict, so
            # the `guard_filter_fn` assignment below is written into
            # `inductor_compile_config` itself -- confirm this is intended.
            options = compilation_config.inductor_compile_config

        self.first_compile = True
        self.evaluate_guards = compilation_config.dynamic_shapes_config.evaluate_guards

        ds_type = compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )

                # Keep only SHAPE_ENV guards; every other guard is dropped.
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward

        # Validate that unbacked dynamic shapes require
        # VLLM_USE_BYTECODE_HOOK=False.
        if ds_type == DynamicShapesType.UNBACKED:
            # reason is that bytecode does torch._dynamo.eval_frame.
            # remove_from_cache(self.original_code_object()) to force a new
            # re-compilation. And if we use
            # compiled_ptr = self.check_invariants_and_forward
            # it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                logger.warning(
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled and please "
                    "upgrade PyTorch version to use AOT compile."
                )

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            # Filled in by `bytecode_hook` once Dynamo has transformed
            # `forward` for this instance.
            self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        """Forward ``(args, kwargs)`` to the compiled callable's ``aot_compile``.

        Raises:
            RuntimeError: if the compiled callable has no ``aot_compile``
                attribute.
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the wrapped computation, compiling it if necessary.

        With ``VLLM_USE_BYTECODE_HOOK`` enabled:
          * in ``STOCK_TORCH_COMPILE`` mode the compiled callable is invoked
            directly;
          * otherwise, while no bytecode has been captured, the Dynamo cache
            entry for ``forward`` is removed and the compiled callable is
            invoked;
          * once bytecode has been captured, ``forward`` is executed with the
            captured bytecode swapped in.

        With the hook disabled, the compiled callable is invoked inside
        ``_compilation_context()``. From the second call onwards, if
        ``evaluate_guards`` is set, it also runs under
        ``torch.compiler.set_stance("fail_on_recompile")``.
        """
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            else:
                with self._dispatch_to_compiled_code():
                    return self._call_with_optional_nvtx_range(
                        self.forward, *args, **kwargs
                    )
        else:
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Computation to be compiled; to be provided by subclasses."""
        ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution.

        The hook ignores code objects other than this class's ``forward`` and
        frames whose ``self`` is a different instance. If a debug dump path is
        configured, a decompiled copy of ``new_code`` is written there on a
        best-effort basis.

        Raises:
            RuntimeError: if cudagraph mode is not ``CUDAGraphMode.NONE`` and
                ``new_code`` references the name ``update``.
        """
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to the `_compile` frame in convert_frame.py; its
        # local variable `frame` is the frame being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # The dump is a debugging aid only: never fail the
                    # compilation because of it, but leave a trace of why
                    # the file is missing.
                    logger.debug(
                        "Could not save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        # noqa: E501
        """
        Context manager to dispatch to internally compiled code for torch<2.8.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.

        The swap is done on ``self.__class__.forward``, i.e. on the class
        attribute, and is undone when the context exits.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            # Always restore the original code object, even on error.
            self.__class__.forward.__code__ = original
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.

    If ``model`` is not itself a ``TorchCompileWithNoGuardsWrapper`` but has a
    ``model`` attribute, that attribute is used instead. Nothing happens when
    the resulting object is not a wrapper or has a truthy ``do_not_compile``.
    Otherwise the global compilation counters are zeroed, AOT state and cache
    directories are cleared, and the wrapper's ``__init__`` is run again.
    """
    wrapper = model
    if not isinstance(wrapper, TorchCompileWithNoGuardsWrapper):
        # Look one level down for the wrapped module.
        if hasattr(wrapper, "model"):
            wrapper = wrapper.model
        if not isinstance(wrapper, TorchCompileWithNoGuardsWrapper):
            return

    # model.do_not_compile is set by the @support_torch_compile decorator
    if getattr(wrapper, "do_not_compile", False):
        return

    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    for counter_name in (
        "num_models_seen",
        "num_graphs_seen",
        "num_piecewise_graphs_seen",
        "num_piecewise_capturable_graphs_seen",
        "num_backend_compilations",
        "num_gpu_runner_capture_triggers",
        "num_cudagraph_captured",
        "num_inductor_compiles",
        "num_eager_compiles",
        "num_cache_entries_updated",
        "num_compiled_artifacts_saved",
        "stock_torch_compile_count",
    ):
        setattr(compilation_counter, counter_name, 0)

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(wrapper, "aot_compiled_fn"):
        wrapper.aot_compiled_fn = None
    if hasattr(wrapper, "was_aot_compile_fn_loaded_from_disk"):
        wrapper.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    config = wrapper.vllm_config.compilation_config
    config.cache_dir = ""
    config.local_cache_dir = ""

    # Re-run the wrapper's constructor to rebuild the compiled callable.
    type(wrapper).forward.__code__ = wrapper.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(wrapper)