decorators.py 7.25 KB
Newer Older
1
import inspect
2
from typing import Dict, List, Optional, Union
3
4
5
6
7
8

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
9
from vllm.logger import init_logger
10
11
12
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo

13
# Module logger, created through vLLM's `init_logger` helper.
logger = init_logger(__name__)
14

15
16
17
18

def support_torch_compile(
        cls: Optional[type] = None,
        dynamic_arg_dims: Optional[Dict[str, Union[int, List[int]]]] = None):
    """
    A decorator to add support for compiling the forward method of a class.

    Usage 1: use directly as a decorator without arguments:

    ```python
    @support_torch_compile
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    Usage 2: use as a decorator with arguments:

    ```python
    @support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
    class MyModel(nn.Module):
        def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
            ...
    ```

    `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
    dimensions of the argument. The dynamic dimensions can be either a single
    integer or a list of integers.

    If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
    of the `forward` method, based on the following default rules:

    - if the argument is annotated as `torch.Tensor` or
        `Optional[torch.Tensor]`, the first dimension will be
        marked as dynamic.
    - if the argument is annotated as `IntermediateTensors` or
        `Optional[IntermediateTensors]`, the first dimension of all the
        tensors in the intermediate tensors will be marked as dynamic.

    Annotations are compared by equality against exactly these types, so
    string annotations (e.g. under `from __future__ import annotations`) are
    not recognized; pass `dynamic_arg_dims` explicitly in that case.

    During runtime, when we actually mark dimensions of tensors, it depends
    on the runtime value of each argument listed in `dynamic_arg_dims`:

    - if it is a `torch.Tensor`, its configured dimension(s) will be marked
        as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of all the
        tensors in the intermediate tensors will be marked as dynamic.
    - otherwise, a `ValueError` is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    Args:
        cls: The class being decorated when used as a bare decorator;
            `None` when used as `@support_torch_compile(...)`.
        dynamic_arg_dims: Optional explicit mapping from `forward` argument
            names to their dynamic dimension(s).

    Returns:
        The decorated class when `cls` is given, otherwise a class decorator.

    Raises:
        TypeError: if `cls` is not a class, or the class has no `forward`.
        ValueError: if no dynamic dimensions are given or inferred, or a key
            of `dynamic_arg_dims` is not a parameter of `forward`.
    """

    def cls_decorator_helper(cls: type):
        # helper to pass `dynamic_arg_dims` to `_support_torch_compile`
        # to avoid too much indentation for `_support_torch_compile`
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)

        inferred_dynamic_arg_dims = dynamic_arg_dims
        if inferred_dynamic_arg_dims is None:
            inferred_dynamic_arg_dims = {}
            for k, v in sig.parameters.items():
                if v.annotation in [
                        torch.Tensor, Optional[torch.Tensor],
                        IntermediateTensors, Optional[IntermediateTensors]
                ]:
                    inferred_dynamic_arg_dims[k] = 0

            logger.debug(("Inferred dynamic dimensions for "
                          "forward method of %s: %s"), cls,
                         list(inferred_dynamic_arg_dims.keys()))

        if not inferred_dynamic_arg_dims:
            raise ValueError(
                "No dynamic dimensions found in the forward method of "
                f"{cls}. Please provide dynamic_arg_dims explicitly.")

        # every configured name must be a real `forward` parameter, otherwise
        # it would silently never be marked at runtime
        for k in inferred_dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")

        return _support_torch_compile(cls, inferred_dynamic_arg_dims)

    if cls is not None:
        # use `support_torch_compile` as a decorator without arguments.
        # Explicit check instead of `assert` so it still runs under `-O`.
        if not isinstance(cls, type):
            raise TypeError(
                "support_torch_compile expects a class when used without "
                f"arguments, got {type(cls)}.")
        return cls_decorator_helper(cls)

    # use `support_torch_compile(...)` as a decorator factory
    return cls_decorator_helper


def _support_torch_compile(cls: type,
                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    A decorator to add support for compiling the forward method of a class.

    When compilation is active, `cls` is mutated in place:
    `TorchCompileWrapperWithCustomDispatcher` is appended to its bases (once),
    and its `__init__` and `__call__` are replaced.

    Args:
        cls: The class whose `forward` should be compiled.
        dynamic_arg_dims: Mapping from `forward` argument names to the
            dimension(s) marked as dynamic before the first compilation.

    Returns:
        `cls` itself. It is returned untouched when `VLLM_TORCH_COMPILE_LEVEL`
        is `NO_COMPILATION` or `DYNAMO_AS_IS`, or when Dynamo is unsupported.
    """
    import functools

    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
    # will handle the compilation, so we don't need to do anything here.
    if envs.VLLM_TORCH_COMPILE_LEVEL in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
    ] or not supports_dynamo():
        return cls

    # take care of method resolution order
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    if TorchCompileWrapperWithCustomDispatcher not in cls.__bases__:
        # support decorating multiple times
        cls.__bases__ = cls.__bases__ + (
            TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__  # type: ignore

    # functools.wraps keeps the original constructor's name/docstring and,
    # via `__wrapped__`, lets inspect.signature() report its real parameters
    # instead of `(self, *args, **kwargs)`.
    @functools.wraps(old_init)
    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        # explicitly initialize the compile wrapper after the model's own
        # constructor has run
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__  # type: ignore

    def __call__(self, *args, **kwargs):
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if not self.compiled_codes:
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is None:
                    # `None` arguments are skipped (see NOTE in
                    # `support_torch_compile`)
                    continue
                if isinstance(arg, torch.Tensor):
                    torch._dynamo.mark_dynamic(arg, dims)
                elif isinstance(arg, IntermediateTensors):
                    for tensor in arg.tensors.values():
                        torch._dynamo.mark_dynamic(tensor, dims)
                else:
                    raise ValueError(
                        "Unsupported dynamic dimensions"
                        f" {dims} for argument {k} with type {type(arg)}.")

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if not self.compiled_codes or not self.use_custom_dispatcher:
            # it seems Dynamo reuse the compilation across instances,
            # while we need to make sure the compiled code is not reused.
            # we need to control all the compilation of the model.
            torch._dynamo.eval_frame.remove_from_cache(
                self.original_code_object)
            return self.compiled_callable(*args, **kwargs)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            return self.forward(*args, **kwargs)

    cls.__call__ = __call__  # type: ignore
    return cls