wrapper.py 14.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from collections.abc import Callable, Generator
8
from contextlib import contextmanager, nullcontext
9
from types import CodeType
10
from typing import Any, ParamSpec, TypeVar
11
12
13

import torch

14
import vllm.envs as envs
15
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
16
from vllm.config.compilation import DynamicShapesType
17
from vllm.logger import init_logger
18
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
19
20

logger = init_logger(__name__)

# Return type and parameter spec used to annotate the callables forwarded
# through TorchCompileWithNoGuardsWrapper._call_with_optional_nvtx_range.
R = TypeVar("R")
P = ParamSpec("P")
@contextmanager
27
def _compilation_context() -> Generator[None, None, None]:
28
29
30
31
32
33
    """Context manager for compilation settings.

    This manager sets higher dynamo cache limits for compilation.
    (Needed for qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
    Generally a recompilation can happen whenever we use a new
    backend instance in torch.compile.
34
    """
35
36
37
38
39
40
41
42
43
44
45
46
47
    original_cache_size = torch._dynamo.config.cache_size_limit
    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit

    try:
        torch._dynamo.config.cache_size_limit = 2048
        torch._dynamo.config.accumulated_cache_size_limit = 8192
        yield
    finally:
        torch._dynamo.config.cache_size_limit = original_cache_size
        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.

    Subclasses implement ``forward``. With ``DynamicShapesType.UNBACKED``
    dynamic shapes, torch.compile wraps ``check_invariants_and_forward``
    instead of ``forward``, so subclasses must then also provide
    ``_check_shape_invariants``.
    """

    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``_check_shape_invariants`` and then ``forward`` with the same args.

        Used as the torch.compile target for UNBACKED dynamic shapes so the
        shape checks are compiled together with ``forward``.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        """Call ``callable_fn(*args, **kwargs)``, optionally inside an NVTX range.

        When layerwise NVTX tracing is enabled in the observability config,
        the call runs inside ``layerwise_nvtx_marker_context`` labelled with
        this wrapper's class name; otherwise it is called directly.
        """
        if not self.layerwise_nvtx_tracing_enabled:
            return callable_fn(*args, **kwargs)

        with layerwise_nvtx_marker_context(
            f"Torch Compiled Module (input):{self.__class__.__name__}",
            self,
            in_tensor=list(args),
            kwargs=dict(kwargs),
        ) as ctx:
            ctx.result = callable_fn(*args, **kwargs)
        return ctx.result

    def __init__(
        self,
        compile_prefix: str = "",
        is_encoder: bool = False,
    ) -> None:
        """Set up the torch.compile'd callable for this wrapper.

        Args:
            compile_prefix: Forwarded as ``prefix`` to
                ``compilation_config.init_backend``.
            is_encoder: Forwarded to ``compilation_config.init_backend``.

        Raises:
            RuntimeError: If the current vLLM config's compilation mode is None.
        """
        self.compiled = False

        # Kept so the wrapper can be re-initialized later (see
        # reset_compile_wrapper).
        self._compile_prefix = compile_prefix
        self._is_encoder = is_encoder

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        mode = vllm_config.compilation_config.mode

        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )

        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = vllm_config.compilation_config.init_backend(
            vllm_config, prefix=compile_prefix, is_encoder=is_encoder
        )

        # Work on a copy: "guard_filter_fn" is inserted below, and the
        # inductor config dict belongs to the shared compilation config.
        options: dict[str, Any] = {}
        if isinstance(backend, str) and backend == "inductor":
            options = dict(vllm_config.compilation_config.inductor_compile_config)

        self.first_compile = True
        self.evaluate_guards = (
            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        )

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop the guards: all of them, or all but SHAPE_ENV guards when
            # evaluate_guards is enabled.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                if hasattr(torch.compiler, "skip_all_guards_unsafe"):
                    # Torch 2.10+ provides skip_all_guards_unsafe
                    options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
                else:
                    # Equivalent fallback for older PyTorch: skip all guards
                    options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        if ds_type == DynamicShapesType.UNBACKED:
            # reason is that bytecode does torch._dynamo.eval_frame.
            # remove_from_cache(self.original_code_object()) to force a new
            # re-compilation. And if we use
            # compiled_ptr = self.check_invariants_and_forward
            # it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        # Apply the constrain_to_fx_strides patch before first compilation.
        # This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
        # compile paths call this from their own compile() methods too.
        from vllm.env_override import _apply_constrain_to_fx_strides_patch

        _apply_constrain_to_fx_strides_patch()

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                logger.warning(
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled and please "
                    "upgrade PyTorch version to use AOT compile."
                )

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            # Filled in by bytecode_hook once Dynamo has compiled forward.
            self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        """Ahead-of-time compile the wrapped callable for ``(args, kwargs)``.

        Raises:
            RuntimeError: If the compiled callable has no ``aot_compile``
                attribute.
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run the compiled forward.

        With ``VLLM_USE_BYTECODE_HOOK`` set:
        - In STOCK_TORCH_COMPILE mode, the compiled callable is called directly.
        - Otherwise, until ``bytecode_hook`` has captured compiled bytecode,
          Dynamo's cache entry for ``forward`` is cleared and the compiled
          callable is invoked to trigger compilation.
        - Afterwards the captured bytecode is swapped into ``forward`` and
          called.

        Without the hook, the compiled callable is called under
        ``_compilation_context``. When ``evaluate_guards`` is set, every call
        after the first runs with the ``"fail_on_recompile"`` stance.
        """
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            else:
                with self._dispatch_to_compiled_code():
                    return self._call_with_optional_nvtx_range(
                        self.forward, *args, **kwargs
                    )
        else:
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution.

        Registered via ``torch._dynamo.convert_frame.register_bytecode_hook``
        in ``__init__``. It ignores code objects other than this class's
        ``forward``, and frames whose ``self`` is a different instance.

        Raises:
            RuntimeError: If CUDA graph mode is not NONE and the transformed
                code references the name ``update`` (treated as a sign that
                nn.Module buffers are modified during the forward pass).
        """
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to Dynamo's `_compile` frame in convert_frame.py;
        # its local `frame` is the user frame being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # Best effort: the dump is only a debugging aid and must
                    # not break compilation, but record why it failed.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        # noqa: E501
        """
        Context manager to dispatch to internally compiled code for torch<2.8.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        # Swap the class-level forward's code; restored even if the body raises.
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original


def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.

    ``model`` may be a ``TorchCompileWithNoGuardsWrapper`` or an object whose
    ``.model`` attribute is one. Anything else is left untouched, and so is a
    wrapper whose ``do_not_compile`` flag is set. Otherwise this function:
    - zeroes the global compilation counters,
    - drops the AOT-compiled function and clears the compilation cache
      directories,
    - re-runs ``TorchCompileWithNoGuardsWrapper.__init__`` so the next call
      recompiles.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
        model, "model"
    ):
        model = model.model
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        return
    # model.do_not_compile is set by the @support_torch_compile decorator
    if getattr(model, "do_not_compile", False):
        return
    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    counter_names = (
        "num_models_seen",
        "num_graphs_seen",
        "num_piecewise_graphs_seen",
        "num_piecewise_capturable_graphs_seen",
        "num_backend_compilations",
        "num_gpu_runner_capture_triggers",
        "num_cudagraph_captured",
        "num_inductor_compiles",
        "num_eager_compiles",
        "num_cache_entries_updated",
        "num_compiled_artifacts_saved",
        "stock_torch_compile_count",
        "num_aot_compiles",
        "num_aot_artifacts_saved",
        "num_aot_artifacts_loaded",
    )
    for counter_name in counter_names:
        setattr(compilation_counter, counter_name, 0)

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size for example)
    # baked in as compile-time constants.
    if hasattr(model, "aot_compiled_fn"):
        model.aot_compiled_fn = None
    if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
        model.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compilation_config = model.vllm_config.compilation_config
    compilation_config.cache_dir = ""
    compilation_config.local_cache_dir = ""

    # Point forward back at original_code_object(). With the base-class
    # implementation this assigns forward's current code to itself; kept in
    # case a subclass overrides original_code_object (NOTE(review): confirm).
    model.__class__.forward.__code__ = model.original_code_object()

    TorchCompileWithNoGuardsWrapper.__init__(
        model,
        compile_prefix=model._compile_prefix,
        is_encoder=model._is_encoder,
    )