# wrapper.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import os
import sys
from abc import abstractmethod
7
from collections.abc import Callable
8
9
10
11
12
from contextlib import contextmanager
from types import CodeType

import torch

13
import vllm.envs as envs
14
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
15
16
17
from vllm.logger import init_logger

# Module-level logger, named after this module.
logger = init_logger(__name__)


class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(
        self,
        compiled_callable: Callable | None = None,
        compilation_mode: CompilationMode = CompilationMode.NONE,
    ):
        """Set up the compiled callable and the Dynamo bytecode hook.

        Args:
            compiled_callable: A pre-built compiled callable. When ``None``,
                ``self.forward`` is compiled with ``torch.compile`` using the
                backend produced by the current vLLM compilation config.
            compilation_mode: Compilation mode. Modes at or above
                ``CompilationMode.DYNAMO_TRACE_ONCE`` set
                ``self.use_custom_dispatcher`` to ``True``.
        """
        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config

        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method
            backend = vllm_config.compilation_config.init_backend(vllm_config)

            options = None
            if isinstance(backend, str) and backend == "inductor":
                options = vllm_config.compilation_config.inductor_compile_config

            if envs.VLLM_USE_AOT_COMPILE:
                # Copy before adding keys: `options` may be the dict owned by
                # the shared compilation config, which must not be mutated.
                options = dict(options or {})
                # This effectively drops all the guards.
                # We need this because the bytecode hook is no longer used to
                # drop guards in the AOT compile mode.
                options["guard_filter_fn"] = lambda guards: [False for _ in guards]
                if hasattr(torch._dynamo.config, "enable_aot_compile"):
                    torch._dynamo.config.enable_aot_compile = True
                else:
                    msg = "torch._dynamo.config.enable_aot_compile is not "
                    msg += "available. AOT compile is disabled and please "
                    msg += "upgrade PyTorch version to use AOT compile."
                    logger.warning(msg)

            compiled_callable = torch.compile(
                self.forward, fullgraph=True, backend=backend, options=options
            )

        self.compiled_callable = compiled_callable
        # The unmodified bytecode of `forward`; restored after each dispatch.
        self.original_code_object = self.__class__.forward.__code__
        # Dynamo-transformed bytecode for `forward`, collected by bytecode_hook.
        self.compiled_codes: list[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # Whether subclasses should dispatch directly to `self.compiled_codes`
        # (custom dispatcher) instead of relying on Dynamo's default guard
        # mechanism. Derived from `compilation_mode`, not from an env var.
        self.use_custom_dispatcher: bool = (
            compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
        )

    def aot_compile(self, *args, **kwargs):
        """Ahead-of-time compile the wrapped callable for the given inputs.

        The positional and keyword arguments are passed to
        ``self.compiled_callable.aot_compile`` as a single ``(args, kwargs)``
        tuple, and its result is returned.

        Raises:
            RuntimeError: if the compiled callable has no ``aot_compile``
                attribute (e.g. torch.compile is disabled or PyTorch is too
                old).
        """
        if not hasattr(self.compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                + "Please make sure torch.compile is enabled with the latest "
                + f"version of PyTorch (current using torch: {torch.__version__})"
            )
        return self.compiled_callable.aot_compile((args, kwargs))

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile mode.

        The default implementation simply forwards to
        ``self.compiled_callable``.

        NOTE: this function can have additional arguments beyond the forward
        method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs): ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Called by Dynamo (registered in ``__init__``) for every transformed
        code object. It only acts when ``old_code`` is this class's original
        ``forward`` code and the frame being compiled belongs to this
        instance; then it appends ``new_code`` to ``self.compiled_codes``.

        Raises:
            RuntimeError: if cudagraph mode is not ``NONE`` and the
                transformed code references a name ``update`` (used as a
                signal that module buffers are being modified in forward).
        """
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to Dynamo's `_compile` in convert_frame.py, whose
        # local variable `frame` is the user frame being compiled.
        # NOTE(review): if that frame is never found, the `f_locals["frame"]`
        # lookup below presumably fails with KeyError.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        # The hook is global; ignore compilations triggered by other instances.
        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w", encoding="utf-8") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    # Best effort: the dump is a debugging aid only, so a
                    # failure must not break compilation, but leave a trace.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True,
                    )

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        The swap is done on the class attribute ``forward``, so it is visible
        to every instance of the class while the context is active. The
        original code object is restored on exit, including when the body
        raises.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
        for more details.
        """
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always restore, otherwise an exception inside the `with` body
            # would leave `forward` permanently pointing at compiled code.
            self.__class__.forward.__code__ = self.original_code_object