wrapper.py 5.75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
import os
import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
9
from typing import Callable, Optional
10
11
12

import torch

13
14
from vllm.config import (CompilationLevel, CUDAGraphMode,
                         get_current_vllm_config)
15
16
17
from vllm.logger import init_logger

# Module-level logger; used below to report where Dynamo-transformed code
# is dumped for debugging.
logger = init_logger(__name__)
18

19

youkaichao's avatar
youkaichao committed
20
class TorchCompileWrapperWithCustomDispatcher:
    """
    A wrapper class for torch.compile, with a custom dispatch logic.
    Subclasses should:
    1. Implement the forward method
    2. Implement the dispatch logic in the __call__ method
        It can use `self.compiled_codes` to access the compiled bytecode,
        and `with self.dispatch_to_code(index):` to dispatch to
        the compiled code.
    3. Implement the `__init__` method to determine how to call
        `torch.compile` over the forward method.
    """

    def __init__(self,
                 compiled_callable: Optional[Callable] = None,
                 compilation_level: int = 0):
        """Set up the compiled callable and register the bytecode hook.

        Args:
            compiled_callable: Callable used by ``__call__``. When ``None``,
                ``self.forward`` is wrapped with ``torch.compile``
                (``fullgraph=True``) using the backend returned by the
                current vLLM compilation config's ``init_backend``.
            compilation_level: Compared against
                ``CompilationLevel.DYNAMO_ONCE`` to set
                ``self.use_custom_dispatcher``.
        """
        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        if compiled_callable is None:
            # Default compilation settings: compile the forward method.
            backend = vllm_config.compilation_config.init_backend(vllm_config)

            options = None
            if isinstance(backend, str) and backend == "inductor":
                # Reuse the config fetched above instead of looking up the
                # current config a second time.
                compilation_config = vllm_config.compilation_config
                options = compilation_config.inductor_compile_config

            compiled_callable = torch.compile(self.forward,
                                              fullgraph=True,
                                              backend=backend,
                                              options=options)

        self.compiled_callable = compiled_callable
        # The untransformed bytecode of `forward`. `bytecode_hook` uses it to
        # recognize compilations of this method, and `dispatch_to_code`
        # restores it on exit.
        self.original_code_object = self.__class__.forward.__code__
        # Dynamo-transformed bytecode for `forward`, appended by
        # `bytecode_hook` in the order it is produced.
        self.compiled_codes: list[CodeType] = []
        torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)

        # Use the compilation level to decide whether to use the custom
        # dispatcher. Subclasses can use this flag to switch between the
        # custom dispatcher and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            compilation_level >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
        NOTE: this function can have additional arguments beyond the forward
         method, for directly dispatching to the compiled code.
        """
        return self.compiled_callable(*args, **kwargs)

    @abstractmethod
    def forward(self, *args, **kwargs):
        """The computation to be compiled; implemented by subclasses."""
        ...

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
        """Hook to save the compiled bytecode for direct execution.

        Registered with Dynamo in ``__init__``. New bytecode is recorded in
        ``self.compiled_codes`` only when it is a transformation of this
        class's original ``forward`` code, for a frame whose ``self`` is
        this instance.

        Args:
            old_code: The code object Dynamo transformed.
            new_code: The transformed bytecode produced by Dynamo.

        Raises:
            RuntimeError: If cudagraphs are enabled and the transformed code
                references ``update`` (which the error message attributes to
                buffer modification during the forward pass).
        """
        if old_code is not self.original_code_object:
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to Dynamo's `_compile` in convert_frame.py. Its
        # local `frame` is the frame being compiled, which tells us which
        # instance (`self`) triggered this compilation.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self.compiled_codes.append(new_code)

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo.
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf
                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s",
                                 decompiled_file)
                except Exception:
                    # Dumping is best-effort debug output. Record why it
                    # failed instead of silently discarding the error.
                    logger.debug(
                        "Failed to save Dynamo transformed code to %s",
                        decompiled_file,
                        exc_info=True)

        if self.vllm_config.compilation_config.cudagraph_mode != \
            CUDAGraphMode.NONE and "update" in new_code.co_names:
            import depyf
            src = depyf.decompile(new_code)
            msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src  # noqa
            raise RuntimeError(msg)

    @contextmanager
    def dispatch_to_code(self, index: int):
        """Context manager to dispatch to the compiled code.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.

        The swap is applied to the class-level ``forward`` function, so it
        affects every instance of the class while the context is active.
        The original code object is restored on exit, also when the body
        of the ``with`` statement raises.
        """ # noqa
        self.__class__.forward.__code__ = self.compiled_codes[index]
        try:
            yield
        finally:
            # Always restore; otherwise an exception would leave `forward`
            # permanently pointing at the compiled bytecode.
            self.__class__.forward.__code__ = self.original_code_object