wrapper.py 12.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from contextlib import contextmanager, nullcontext
8
from types import CodeType
9
from typing import Any
10
11

import torch
12
import torch._C._dynamo.guards
13

14
import vllm.envs as envs
15
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
16
from vllm.config.compilation import DynamicShapesType
17
from vllm.logger import init_logger
18
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
19
20

# Module-level logger for this compilation wrapper's diagnostics.
logger = init_logger(__name__)
def _noop_add_global_state_guard(self, *args, **kwargs):
    """Replacement for ``GuardManager.add_global_state_guard`` that installs
    nothing, so no GLOBAL_STATE guard is ever added.

    Every argument is accepted and ignored; the result is always ``None``.
    """
    return None


def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
    """Replacement for ``GuardManager.add_torch_function_mode_stack_guard``
    that installs nothing, so no TORCH_FUNCTION_MODE_STACK guard is added.

    Every argument is accepted and ignored; the result is always ``None``.
    """
    return None


@contextmanager
def _compilation_context():
    """Temporarily relax Dynamo cache limits and disable two guard kinds.

    While the context is active:

    * ``torch._dynamo.config.cache_size_limit`` is set to 2048 and
      ``torch._dynamo.config.accumulated_cache_size_limit`` to 8192.
      A recompilation can happen whenever a new backend instance is used
      in ``torch.compile`` (needed e.g. for qwen2_5_vl, see
      ``test_qwen2_5_vl_evs_functionality``).
    * ``GuardManager.add_global_state_guard`` and
      ``GuardManager.add_torch_function_mode_stack_guard`` are swapped for
      no-op functions, so GLOBAL_STATE and TORCH_FUNCTION_MODE_STACK guards
      are skipped.

    All four values are restored on exit, including when the body raises.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    # Snapshot the guard-installing methods we are about to patch out.
    saved_global_state_guard = guard_manager.add_global_state_guard
    saved_mode_stack_guard = guard_manager.add_torch_function_mode_stack_guard

    dynamo_config = torch._dynamo.config
    # Snapshot the cache limits we are about to raise.
    saved_cache_limit = dynamo_config.cache_size_limit
    saved_accumulated_limit = dynamo_config.accumulated_cache_size_limit

    try:
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        guard_manager.add_global_state_guard = saved_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = saved_mode_stack_guard
        dynamo_config.cache_size_limit = saved_cache_limit
        dynamo_config.accumulated_cache_size_limit = saved_accumulated_limit


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.

    With ``dynamic_shapes_config.evaluate_guards`` enabled, SHAPE_ENV guards
    are kept, and calls after the first run under the "fail_on_recompile"
    stance. Subclasses implement ``forward``; calling the instance runs the
    compiled version of it.
    """

    def check_invariants_and_forward(self, *args, **kwargs):
        """Run ``self._check_shape_invariants`` on the inputs, then ``forward``.

        This is the function handed to ``torch.compile`` for UNBACKED dynamic
        shapes (see ``__init__``); subclasses must define
        ``_check_shape_invariants``.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(self, callable_fn, *args, **kwargs):
        """Return ``callable_fn(*args, **kwargs)``.

        When layerwise NVTX tracing is enabled in the observability config, the
        call is wrapped in ``layerwise_nvtx_marker_context`` labelled with this
        class's name.
        """
        if not self.layerwise_nvtx_tracing_enabled:
            return callable_fn(*args, **kwargs)
        with layerwise_nvtx_marker_context(
            f"Torch Compiled Module (input):{self.__class__.__name__}",
            self,
            in_tensor=list(args),
            kwargs=dict(kwargs),
        ) as ctx:
            ctx.result = callable_fn(*args, **kwargs)
        return ctx.result

    def __init__(self):
        """Build the ``torch.compile`` callable for ``forward`` from the
        current vLLM config.

        Outside STOCK_TORCH_COMPILE mode a ``guard_filter_fn`` is installed:
        * With ``evaluate_guards``, it keeps only SHAPE_ENV guards.
        * Otherwise, it drops every guard.

        With ``VLLM_USE_BYTECODE_HOOK`` (non-stock mode), ``bytecode_hook`` is
        registered with Dynamo.

        Raises:
            RuntimeError: if ``compilation_config.mode`` is ``None``.
            AssertionError: on unsupported combinations of
                ``dynamic_shapes_config`` and the ``VLLM_USE_BYTECODE_HOOK`` /
                ``VLLM_USE_AOT_COMPILE`` flags.
        """
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        compilation_config = vllm_config.compilation_config
        mode = compilation_config.mode

        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )

        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = compilation_config.init_backend(vllm_config)
        options = {}

        if isinstance(backend, str) and backend == "inductor":
            # Copy: "guard_filter_fn" is added below and must not leak into
            # the shared compilation config dict.
            options = dict(compilation_config.inductor_compile_config or {})

        self.first_compile = True
        dynamic_shapes_config = compilation_config.dynamic_shapes_config
        self.evaluate_guards = dynamic_shapes_config.evaluate_guards
        ds_type = dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards (all but SHAPE_ENV ones with evaluate_guards).
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )

                if envs.VLLM_USE_AOT_COMPILE:
                    # disabled until https://github.com/pytorch/pytorch/pull/169239
                    # is picked up.
                    assert ds_type != DynamicShapesType.BACKED, (
                        "evaluate_guards for backed shapes requires "
                        "VLLM_USE_AOT_COMPILE=False. "
                    )

                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        if ds_type == DynamicShapesType.UNBACKED:
            # The reason: the bytecode-hook path calls
            # torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
            # to force a new re-compilation, and if we use
            # compiled_ptr = self.check_invariants_and_forward
            # it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                torch._dynamo.config.enable_aot_compile = True
            else:
                logger.warning(
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled and please "
                    "upgrade PyTorch version to use AOT compile."
                )

        self._compiled_callable = torch.compile(
            compiled_ptr,
            fullgraph=True,
            dynamic=False,
            backend=backend,
            options=options,
        )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            self._compiled_bytecode = None

    def aot_compile(self, *args, **kwargs):
        """Ahead-of-time compile via the compiled callable's ``aot_compile``,
        passing the example inputs as ``(args, kwargs)``.

        Raises:
            RuntimeError: if the compiled callable has no ``aot_compile``.
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args, **kwargs):
        """Run the compiled ``forward`` on the given inputs.

        With ``VLLM_USE_BYTECODE_HOOK`` (outside STOCK_TORCH_COMPILE mode):
        * Until ``bytecode_hook`` has captured the transformed bytecode, each
          call clears Dynamo's cache for ``forward`` and goes through
          ``torch.compile``.
        * Afterwards the captured bytecode is executed directly.

        Otherwise every call goes through ``torch.compile`` inside
        ``_compilation_context()``. With ``evaluate_guards``, every call after
        the first uses the "fail_on_recompile" stance.
        """
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if self._compiled_bytecode is None:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            with self._dispatch_to_compiled_code():
                return self._call_with_optional_nvtx_range(
                    self.forward, *args, **kwargs
                )

        ctx = (
            nullcontext()
            if self.first_compile or not self.evaluate_guards
            else torch.compiler.set_stance("fail_on_recompile")
        )
        self.first_compile = False
        with _compilation_context(), ctx:
            return self._call_with_optional_nvtx_range(
                self._compiled_callable, *args, **kwargs
            )

    @abstractmethod
    def forward(self, *args, **kwargs):
        """Model computation to be compiled; implemented by subclasses."""

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    @staticmethod
    def _find_dynamo_compile_frame():
        """Return the nearest caller frame executing ``_compile`` from
        ``convert_frame.py``, or ``None`` if no such frame is on the stack."""
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        frame = sys._getframe()
        while frame.f_back is not None:
            frame = frame.f_back
            code = frame.f_code
            file_name = code.co_filename.split(os.path.sep)[-1]
            if code.co_name == "_compile" and file_name == "convert_frame.py":
                return frame
        return None

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Only acts when ``old_code`` is this class's ``forward`` code and the
        frame being compiled belongs to this instance. In that case it:
        * stores ``new_code`` in ``self._compiled_bytecode``;
        * best-effort dumps its decompiled source under
          ``compile_debug_dump_path()``;
        * rejects code that references ``update`` when CUDA graphs are on.

        Raises:
            RuntimeError: if Dynamo's ``_compile`` frame is not on the stack,
                or if cudagraphs are enabled and the transformed code
                references ``update`` (buffer mutation during forward).
        """
        if old_code is not self.original_code_object():
            return

        compile_frame = self._find_dynamo_compile_frame()
        if compile_frame is None:
            raise RuntimeError(
                "bytecode_hook was called outside of Dynamo's _compile "
                "(convert_frame.py); cannot locate the frame being compiled."
            )
        frame = compile_frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Instances of the same class share forward's code object; only keep
        # the bytecode produced while compiling *this* instance.
        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w", encoding="utf-8") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # The dump is a debugging aid only: record why it failed
                    # but never let it break compilation.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self):
        """
        Context manager to dispatch to internally compiled code for torch<2.8.

        While active, ``self.__class__.forward.__code__`` is replaced by
        ``self._compiled_bytecode``; the original code object is restored on
        exit, even if the body raises.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original