# wrapper.py (13.1 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from collections.abc import Callable, Generator
8
from contextlib import contextmanager, nullcontext
9
from types import CodeType
10
from typing import Any, ParamSpec, TypeVar
11
12

import torch
13
import torch._C._dynamo.guards
14

15
import vllm.envs as envs
16
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
17
from vllm.config.compilation import DynamicShapesType
18
from vllm.logger import init_logger
19
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context
20
21

# Module-level logger, created through vLLM's logging helper.
logger = init_logger(__name__)

# Generic placeholders used to annotate pass-through callables:
# ``P`` captures a callable's parameters and ``R`` its return type.
R = TypeVar("R")
P = ParamSpec("P")
25

26
27
28
29

def _noop_add_global_state_guard(
    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
30
31
32
33
    """No-op to skip the GLOBAL_STATE guard entirely"""
    pass


34
35
36
def _noop_add_torch_function_mode_stack_guard(
    self: torch._C._dynamo.guards.GuardManager, *args: Any, **kwargs: Any
) -> None:
37
38
39
40
41
    """No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
    pass


@contextmanager
def _compilation_context(
    cache_size_limit: int = 2048,
    accumulated_cache_size_limit: int = 8192,
) -> Generator[None, None, None]:
    """Context manager for compilation settings and patches.

    This manager:
    1. Sets higher dynamo cache limits for compilation. (Needed for
        qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
        Generally a recompilation can happen whenever we use a new
        backend instance in torch.compile.
    2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
    3. Patches out add_torch_function_mode_stack_guard to skip
        TORCH_FUNCTION_MODE_STACK guards.
    4. Restores everything when compilation completes

    Args:
        cache_size_limit: Value assigned to
            ``torch._dynamo.config.cache_size_limit`` while the context
            is active.
        accumulated_cache_size_limit: Value assigned to
            ``torch._dynamo.config.accumulated_cache_size_limit`` while
            the context is active.

    Yields:
        None. The body of the ``with`` statement runs with the patches
        applied.
    """
    guard_manager_cls = torch._C._dynamo.guards.GuardManager
    dynamo_config = torch._dynamo.config

    # Save original values so they can be put back even if the body raises.
    original_global_state_guard = guard_manager_cls.add_global_state_guard
    original_torch_function_mode_stack_guard = (
        guard_manager_cls.add_torch_function_mode_stack_guard
    )
    original_cache_size = dynamo_config.cache_size_limit
    original_accumulated_cache = dynamo_config.accumulated_cache_size_limit

    try:
        # Set higher cache limits for compilation
        dynamo_config.cache_size_limit = cache_size_limit
        dynamo_config.accumulated_cache_size_limit = accumulated_cache_size_limit

        # Patch guard manager: both guard-adding methods become no-ops.
        guard_manager_cls.add_global_state_guard = _noop_add_global_state_guard
        guard_manager_cls.add_torch_function_mode_stack_guard = (
            _noop_add_torch_function_mode_stack_guard
        )
        yield
    finally:
        # Restore original values
        guard_manager_cls.add_global_state_guard = original_global_state_guard
        guard_manager_cls.add_torch_function_mode_stack_guard = (
            original_torch_function_mode_stack_guard
        )
        dynamo_config.cache_size_limit = original_cache_size
        dynamo_config.accumulated_cache_size_limit = original_accumulated_cache


class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile, it ensures that all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered. Dynamo should never be traced again after that
    since we drop all guards.

    Subclasses provide ``forward``; ``__call__`` routes invocations to the
    compiled version of it.
    """

    # NOTE(review): ``forward`` is decorated with ``@abstractmethod`` but this
    # class does not use an ABC metaclass, so nothing in this class enforces
    # the override at instantiation time.

    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``_check_shape_invariants`` on the inputs, then ``forward``.

        This is the callable handed to ``torch.compile`` when the dynamic
        shapes type is UNBACKED (see ``__init__``).

        Raises:
            AssertionError: If the instance has no ``_check_shape_invariants``
                attribute.
        """
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        """Call ``callable_fn``, wrapped in an NVTX marker context if enabled.

        Args:
            callable_fn: The callable to invoke.
            *args: Positional arguments forwarded to ``callable_fn``.
            **kwargs: Keyword arguments forwarded to ``callable_fn``.

        Returns:
            Whatever ``callable_fn`` returns.
        """
        if not self.layerwise_nvtx_tracing_enabled:
            return callable_fn(*args, **kwargs)

        # The marker context receives copies of the inputs; the result is
        # stored on the context object before the context exits.
        args_list = list(args)
        kwargs_dict = dict(kwargs)
        with layerwise_nvtx_marker_context(
            f"Torch Compiled Module (input):{self.__class__.__name__}",
            self,
            in_tensor=args_list,
            kwargs=kwargs_dict,
        ) as ctx:
            ctx.result = callable_fn(*args, **kwargs)
        return ctx.result

    def __init__(self) -> None:
        """Read the current vLLM config and build the compiled callable.

        Raises:
            RuntimeError: If the configured compilation mode is ``None``.
            AssertionError: If ``evaluate_guards`` or UNBACKED dynamic shapes
                are combined with settings they are incompatible with (see
                the assertion messages below).
        """
        self.compiled = False

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        compilation_config = vllm_config.compilation_config
        mode = compilation_config.mode
        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = compilation_config.init_backend(vllm_config)
        options = {}

        if isinstance(backend, str) and backend == "inductor":
            # NOTE(review): ``options`` aliases the config object here (no
            # copy), so the ``guard_filter_fn`` assignments below are written
            # into ``inductor_compile_config`` as well - confirm intended.
            options = compilation_config.inductor_compile_config

        # ``first_compile`` lets __call__ treat the very first invocation
        # differently from later ones when guards are evaluated.
        self.first_compile = True
        self.evaluate_guards = compilation_config.dynamic_shapes_config.evaluate_guards

        ds_type = compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )

                if envs.VLLM_USE_AOT_COMPILE:
                    # disabled until https://github.com/pytorch/pytorch/pull/169239
                    # is picked up.
                    assert ds_type != DynamicShapesType.BACKED, (
                        "evaluate_guards for backed shapes requires "
                        "VLLM_USE_AOT_COMPILE=False. "
                    )

                # Keep only the SHAPE_ENV guards.
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                # Keep no guard at all.
                options["guard_filter_fn"] = lambda x: [False for _ in x]

        compiled_ptr: Any = self.forward

        # Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
        if ds_type == DynamicShapesType.UNBACKED:
            # reason is that bytecode does torch._dynamo.eval_frame.
            # remove_from_cache(self.original_code_object()) to force a new
            # re-compilation. And if we use
            # compiled_ptr = self.check_invariants_and_forward
            # it will reset all entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes requires VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                logger.warning(
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled and please "
                    "upgrade PyTorch version to use AOT compile."
                )

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
            # Filled in by ``bytecode_hook`` once Dynamo has transformed
            # ``forward``.
            self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        """Ahead-of-time compile with the given example inputs.

        Args:
            *args: Positional example inputs.
            **kwargs: Keyword example inputs.

        Returns:
            The result of ``self._compiled_callable.aot_compile((args, kwargs))``.

        Raises:
            RuntimeError: If the compiled callable has no ``aot_compile``
                attribute.
        """
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Invoke the wrapped ``forward`` through the compiled path.

        Without the bytecode hook, the compiled callable is called inside
        ``_compilation_context()``; when ``evaluate_guards`` is set, every
        call after the first additionally runs under the
        ``fail_on_recompile`` stance.

        With the bytecode hook, STOCK_TORCH_COMPILE mode calls the compiled
        callable directly. Otherwise the compiled callable is used until the
        transformed bytecode has been captured, after which ``forward`` is
        executed with its code object swapped for that bytecode.
        """
        if not envs.VLLM_USE_BYTECODE_HOOK:
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

        if (
            self.vllm_config.compilation_config.mode
            == CompilationMode.STOCK_TORCH_COMPILE
        ):
            return self._compiled_callable(*args, **kwargs)

        if not self._compiled_bytecode:
            # Make sure a compilation is triggered by clearing dynamo
            # cache.
            torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
            return self._call_with_optional_nvtx_range(
                self._compiled_callable, *args, **kwargs
            )

        with self._dispatch_to_compiled_code():
            return self._call_with_optional_nvtx_range(self.forward, *args, **kwargs)

    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Computation to be compiled; expected to be supplied by subclasses."""
        ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution.

        Args:
            old_code: Code object that was transformed.
            new_code: Transformed code object.

        Raises:
            RuntimeError: If a cudagraph mode other than NONE is configured
                and ``new_code`` references the name ``update``.
        """
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the call stack looking for ``_compile`` in
        # ``convert_frame.py``; its local ``frame`` is the frame being compiled.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # The code object is shared by all instances of the class; only record
        # the bytecode when the compiled frame belongs to this instance.
        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w", encoding="utf-8") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # Best-effort debug dump: never let it break compilation,
                    # but leave a trace instead of swallowing it silently.
                    logger.debug(
                        "Failed to dump Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        """
        Context manager to dispatch to internally compiled code for torch<2.8.

        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        # The swap is done on the class-level function, so it is visible to
        # every instance of the class until the original is restored.
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original