wrapper.py 13.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from collections.abc import Callable, Generator
8
from contextlib import contextmanager, nullcontext
9
from types import CodeType
10
from typing import Any, ParamSpec, TypeVar
11
12
13

import torch

14
import vllm.envs as envs
15
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
16
from vllm.config.compilation import DynamicShapesType
17
from vllm.logger import init_logger
18
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
19
20

logger = init_logger(__name__)

# Generic parameters for _call_with_optional_nvtx_range: P captures the
# wrapped callable's parameter specification and R its return type.
P = ParamSpec("P")
R = TypeVar("R")
24

25

26
@contextmanager
27
def _compilation_context() -> Generator[None, None, None]:
28
29
30
31
32
33
    """Context manager for compilation settings.

    This manager sets higher dynamo cache limits for compilation.
    (Needed for qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
    Generally a recompilation can happen whenever we use a new
    backend instance in torch.compile.
34
    """
35
36
37
38
39
40
41
42
43
44
45
46
47
    original_cache_size = torch._dynamo.config.cache_size_limit
    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit

    try:
        torch._dynamo.config.cache_size_limit = 2048
        torch._dynamo.config.accumulated_cache_size_limit = 8192
        yield
    finally:
        torch._dynamo.config.cache_size_limit = original_cache_size
        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.

    Subclasses implement ``forward``. ``__init__`` compiles it with
    ``fullgraph=True`` and ``dynamic=False``; for UNBACKED dynamic shapes it
    compiles ``check_invariants_and_forward`` instead. When
    ``dynamic_shapes_config.evaluate_guards`` is set, only ``SHAPE_ENV``
    guards are kept, and every call after the first runs under the
    ``"fail_on_recompile"`` stance.
    """

    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run the subclass's shape-invariant check, then ``forward``.

        This is the compile target for UNBACKED dynamic shapes (see
        ``__init__``). The subclass must define ``_check_shape_invariants``.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        """Call ``callable_fn(*args, **kwargs)``, optionally inside an NVTX range.

        When layerwise NVTX tracing is enabled, the call is wrapped in
        ``layerwise_nvtx_marker_context``. The result is stored on the context
        object and then returned.
        """
        if not self.layerwise_nvtx_tracing_enabled:
            return callable_fn(*args, **kwargs)
        with layerwise_nvtx_marker_context(
            f"Torch Compiled Module (input):{self.__class__.__name__}",
            self,
            in_tensor=list(args),
            kwargs=dict(kwargs),
        ) as ctx:
            ctx.result = callable_fn(*args, **kwargs)
        return ctx.result

    def __init__(self) -> None:
        """Build the torch.compile'd callable for ``forward``.

        Reads the current vLLM config to choose the backend, the guard
        policy, dynamic-shape handling and AOT compilation. When the bytecode
        hook is in use, it also registers ``bytecode_hook`` with Dynamo.

        Raises:
            RuntimeError: if the compilation mode is ``None``.
            AssertionError: on unsupported combinations of
                ``evaluate_guards`` / UNBACKED dynamic shapes with
                ``VLLM_USE_BYTECODE_HOOK``.
        """
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        mode = vllm_config.compilation_config.mode

        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )

        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = vllm_config.compilation_config.init_backend(vllm_config)
        options: dict[str, Any] = {}

        if isinstance(backend, str) and backend == "inductor":
            # Work on a copy. "guard_filter_fn" is inserted below and the dict
            # is then handed to torch.compile. Using the config's own dict
            # would leak those changes into the shared compilation config and
            # into every later wrapper built from it.
            options = dict(vllm_config.compilation_config.inductor_compile_config)

        self.first_compile = True
        self.evaluate_guards = (
            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        )

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop guards: all of them, or all except SHAPE_ENV guards when
            # evaluate_guards is set.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            elif hasattr(torch.compiler, "skip_all_guards_unsafe"):
                # Torch 2.10+ provides skip_all_guards_unsafe
                options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
            else:
                # Equivalent fallback for older PyTorch: skip all guards
                options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        if ds_type == DynamicShapesType.UNBACKED:
            # The bytecode-hook path calls
            # torch._dynamo.eval_frame.remove_from_cache(
            #     self.original_code_object())
            # to force a fresh compilation. If we compiled
            # compiled_ptr = self.check_invariants_and_forward
            # that call would reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                msg = "torch._dynamo.config.enable_aot_compile is not "
                msg += "available. AOT compile is disabled and please "
                msg += "upgrade PyTorch version to use AOT compile."
                logger.warning(msg)

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        # __init__ can run more than once on the same instance (e.g. from
        # reset_compile_wrapper). Remove the hook registered by a previous run
        # so stale hooks for this instance do not stay registered.
        previous_hook = getattr(self, "_bytecode_hook_handle", None)
        if previous_hook is not None:
            previous_hook.remove()
            self._bytecode_hook_handle = None

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            self._bytecode_hook_handle = (
                torch._dynamo.convert_frame.register_bytecode_hook(
                    self.bytecode_hook
                )
            )
            self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        """Ahead-of-time compile the wrapped callable for the given inputs.

        Raises:
            RuntimeError: if the compiled callable has no ``aot_compile``
                method (torch.compile not enabled, or PyTorch too old).
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the compiled ``forward``.

        With ``VLLM_USE_BYTECODE_HOOK`` and a non-stock mode:
        - While no transformed bytecode has been captured, the call clears
          Dynamo's cache entry for ``forward`` and goes through the compiled
          callable, which lets ``bytecode_hook`` record the new bytecode.
        - After that, the captured bytecode runs directly by temporarily
          swapping ``forward.__code__``.

        Otherwise the compiled callable is called with raised Dynamo cache
        limits. When guards are evaluated, calls after the first run under
        the ``"fail_on_recompile"`` stance.
        """
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            else:
                with self._dispatch_to_compiled_code():
                    return self._call_with_optional_nvtx_range(
                        self.forward, *args, **kwargs
                    )
        else:
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution.

        Ignores code objects other than this class's ``forward``, and frames
        belonging to other instances. When a debug dump path is configured,
        it optionally writes the decompiled transformed code there.

        Raises:
            RuntimeError: if cudagraphs are enabled and the transformed code
                references ``update`` (in-forward buffer modification).
        """
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to Dynamo's `_compile` frame in convert_frame.py.
        # Its `frame` local is the user frame being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # All instances share the class's forward code object. Only record
        # bytecode compiled for this instance.
        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # Best effort: the dump is only a debugging aid, so a
                    # failure must not break compilation. Record why it failed
                    # instead of discarding the error silently.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        """
        Context manager to dispatch to internally compiled code for torch<2.8.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            # Always restore, so an exception inside forward cannot leave the
            # class permanently pointing at the compiled bytecode.
            self.__class__.forward.__code__ = original
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.

    ``model`` may be the compile wrapper itself or a module holding it as
    ``model.model``. Other modules, and wrappers flagged ``do_not_compile``,
    are left untouched. Otherwise this function:
    - zeroes the global compilation counters,
    - drops any AOT-compiled function,
    - clears the cache directories,
    - re-runs the wrapper's ``__init__`` so the next call recompiles.
    """
    wrapper = model
    if not isinstance(wrapper, TorchCompileWithNoGuardsWrapper) and hasattr(
        wrapper, "model"
    ):
        wrapper = wrapper.model
    if not isinstance(wrapper, TorchCompileWithNoGuardsWrapper):
        return
    # do_not_compile is set by the @support_torch_compile decorator.
    if getattr(wrapper, "do_not_compile", False):
        return
    from vllm.compilation.counter import compilation_counter

    # Reset the compilation counter, field by field, in a fixed order.
    for counter_field in (
        "num_models_seen",
        "num_graphs_seen",
        "num_piecewise_graphs_seen",
        "num_piecewise_capturable_graphs_seen",
        "num_backend_compilations",
        "num_gpu_runner_capture_triggers",
        "num_cudagraph_captured",
        "num_inductor_compiles",
        "num_eager_compiles",
        "num_cache_entries_updated",
        "num_compiled_artifacts_saved",
        "stock_torch_compile_count",
        "num_aot_compiles",
        "num_aot_artifacts_saved",
        "num_aot_artifacts_loaded",
    ):
        setattr(compilation_counter, counter_field, 0)

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(wrapper, "aot_compiled_fn"):
        wrapper.aot_compiled_fn = None
    if hasattr(wrapper, "was_aot_compile_fn_loaded_from_disk"):
        wrapper.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compilation_config = wrapper.vllm_config.compilation_config
    compilation_config.cache_dir = ""
    compilation_config.local_cache_dir = ""

    # Point forward back at the code object reported by
    # original_code_object(), then rebuild the compiled callable.
    wrapper.__class__.forward.__code__ = wrapper.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(wrapper)