wrapper.py 10.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
9
from typing import Any
10
11

import torch
12
import torch._C._dynamo.guards
13

14
import vllm.envs as envs
15
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
16
17
18
from vllm.logger import init_logger

# Module-level logger, obtained from vLLM's ``init_logger`` under this module's name.
logger = init_logger(__name__)
19

20

21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
def _noop_add_global_state_guard(self, *args, **kwargs):
    """No-op to skip the GLOBAL_STATE guard entirely"""
    pass


def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
    pass


@contextmanager
def _compilation_context():
    """Context manager for compilation settings and patches.

    This manager:
    1. Sets higher dynamo cache limits for compilation. (Needed for
        qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
        Generally a recompilation can happen whenever we use a new
        backend instance in torch.compile.
    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
    3. Patches out add_torch_function_mode_stack_guard to skip
        TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes

    The patches are set on the ``GuardManager`` class itself and on the
    module-level ``torch._dynamo.config`` object (not on per-call copies),
    and are undone in a ``finally`` clause, so the previous values come back
    even when the wrapped code raises.
    """
    guard_manager = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Save original values
    original_global_state_guard = guard_manager.add_global_state_guard
    original_torch_function_mode_stack_guard = (
        guard_manager.add_torch_function_mode_stack_guard
    )
    original_cache_size = dynamo_config.cache_size_limit
    original_accumulated_cache = dynamo_config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        dynamo_config.cache_size_limit = 2048
        dynamo_config.accumulated_cache_size_limit = 8192

        # Patch guard manager
        guard_manager.add_global_state_guard = _noop_add_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values
        guard_manager.add_global_state_guard = original_global_state_guard
        guard_manager.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        dynamo_config.cache_size_limit = original_cache_size
        dynamo_config.accumulated_cache_size_limit = original_accumulated_cache


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile. It ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.

    Subclasses provide ``forward``; calling the instance dispatches to the
    compiled version of it.
    """

    def check_invariants_and_forward(self, *args, **kwargs):
        """Call ``_check_shape_invariants`` on the inputs, then ``forward``.

        This is the callable handed to ``torch.compile`` when the dynamic
        shapes type is UNBACKED (see ``__init__``). The instance is expected
        to already carry a ``_check_shape_invariants`` attribute.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def __init__(self):
        """Build the compiled callable from the current vLLM config.

        Raises:
            RuntimeError: if the configured compilation mode is ``None``.
            ValueError: if dynamic shapes are UNBACKED while
                ``VLLM_USE_BYTECODE_HOOK`` is enabled.
        """
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config

        mode = vllm_config.compilation_config.mode
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = vllm_config.compilation_config.init_backend(vllm_config)
        options = {}

        if isinstance(backend, str) and backend == "inductor":
            options = vllm_config.compilation_config.inductor_compile_config

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            options["guard_filter_fn"] = lambda x: [False for _ in x]

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        from vllm.compilation.decorators import DynamicShapesType

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
        compiled_ptr: Any = self.forward
        if ds_type == DynamicShapesType.UNBACKED:
            if envs.VLLM_USE_BYTECODE_HOOK:
                # The reason is that the bytecode-hook path relies on the hack
                # torch._dynamo.eval_frame.remove_from_cache(
                #     self.original_code_object())
                # to force a new re-compilation.
                raise ValueError(
                    "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
                )
            compiled_ptr = self.check_invariants_and_forward

        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                torch._dynamo.config.enable_aot_compile = True
            else:
                # Older PyTorch without the flag: warn and carry on.
                msg = (
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled and please "
                    "upgrade PyTorch version to use AOT compile."
                )
                logger.warning(msg)

        self._compiled_callable = torch.compile(
            compiled_ptr,
            fullgraph=True,
            dynamic=False,
            backend=backend,
            options=options,
        )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            # Filled in by bytecode_hook once Dynamo has transformed forward.
            self._compiled_bytecode = None

    def aot_compile(self, *args, **kwargs):
        """Ahead-of-time compile for the given example arguments.

        Delegates to ``self._compiled_callable.aot_compile`` with a single
        ``(args, kwargs)`` tuple.

        Raises:
            RuntimeError: if the compiled callable has no ``aot_compile``
                attribute.
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                + "Please make sure torch.compile is enabled with the latest "
                + f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args, **kwargs):
        """Run the wrapped ``forward`` through the appropriate compiled path.

        With ``VLLM_USE_BYTECODE_HOOK`` enabled:
          * STOCK_TORCH_COMPILE mode calls the compiled callable directly;
          * otherwise, until bytecode has been captured, the Dynamo cache entry
            for ``forward`` is cleared and the compiled callable is invoked;
          * once bytecode has been captured, ``forward`` is executed with its
            code object temporarily swapped for the captured bytecode.

        With the hook disabled, the compiled callable is invoked inside
        ``_compilation_context()``.
        """
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing dynamo
                # cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._compiled_callable(*args, **kwargs)
            else:
                with self._dispatch_to_compiled_code():
                    return self.forward(*args, **kwargs)
        else:
            with _compilation_context():
                return self._compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        """The computation to compile; supplied by subclasses."""
        ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Registered with Dynamo in ``__init__``. Invocations whose ``old_code``
        is not this class's ``forward`` code object, or whose compiled frame
        belongs to a different instance, are ignored. Otherwise ``new_code``
        is stored in ``self._compiled_bytecode`` and, when a debug dump path is
        configured, a decompiled copy is written there on a best-effort basis.

        Raises:
            RuntimeError: if the cudagraph mode is not ``CUDAGraphMode.NONE``
                and ``new_code`` references the name ``update``.
        """
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack to the `_compile` frame in convert_frame.py;
        # its local variable `frame` is the frame being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # Same code object, but the compilation is for another instance.
        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w", encoding="utf-8") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # The dump is only a debugging aid: never fail compilation
                    # because of it, but leave a trace of what went wrong.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self):
        """
        Context manager to dispatch to internally compiled code for torch<2.8.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        The swap is made on the class-level ``forward`` function and is
        reverted when the context exits, even if the body raises.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original