"""Decorators that add ``torch.compile`` support to model classes."""
import functools
import inspect
from typing import Dict, List, Union

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo


13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    A decorator to add support for compiling the forward method of a class.

    `dynamic_arg_dims` is a dictionary that maps argument names of the
    class's `forward` method to the dimension(s) of that argument which
    should be marked as dynamic. Each value can be either a single integer
    or a list of integers.

    When compilation is enabled, on the first call (before the first
    compilation) the runtime value passed for each listed argument is
    handled according to its type:

    - if it is a `torch.Tensor`, the given dimension(s) of the tensor will
        be marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the given dimension(s) of every tensor
        in the intermediate tensors will be marked as dynamic.
    - otherwise, a `ValueError` is raised.

    NOTE: if an argument is `None`, it should always be passed as `None` during
    the lifetime of the model, otherwise, it cannot be captured as a single
    computation graph.

    Raises (when the returned decorator is applied to a class):
        TypeError: if the class has no `forward` attribute.
        ValueError: if a key of `dynamic_arg_dims` is not a parameter of the
            class's `forward` method.
    """

    def cls_decorator_helper(cls: type):
        # Validate `cls` against `dynamic_arg_dims` here, then hand off to
        # `_support_torch_compile`, which keeps the class-patching logic at a
        # shallower indentation level.
        if not hasattr(cls, 'forward'):
            raise TypeError("decorated class should have a forward method.")
        sig = inspect.signature(cls.forward)
        for k in dynamic_arg_dims:
            if k not in sig.parameters:
                raise ValueError(
                    f"Argument {k} not found in the forward method of {cls}")
        return _support_torch_compile(cls, dynamic_arg_dims)

    return cls_decorator_helper


def _support_torch_compile(cls: type,
                           dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    Patch `cls` in place so that calling an instance runs a compiled version
    of its `forward` method.

    This is the implementation behind `support_torch_compile`. It expects
    that `cls` has a `forward` method and that every key of
    `dynamic_arg_dims` names one of its parameters.

    `cls` is returned untouched when `envs.VLLM_TORCH_COMPILE_LEVEL` is
    `CompilationLevel.NO_COMPILATION` or `CompilationLevel.DYNAMO_AS_IS`, or
    when `supports_dynamo()` returns false. Otherwise `cls` is modified:

    - `TorchCompileWrapperWithCustomDispatcher` is appended to its bases;
    - `__init__` is wrapped so the wrapper base is initialized after the
      original constructor runs;
    - `__call__` is replaced so that dynamic dimensions are marked before
      the first compilation, and later calls are dispatched to the compiled
      code.

    Returns:
        The class `cls` (modified in place unless compilation is disabled).
    """

    # for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
    # will handle the compilation, so we don't need to do anything here.
    if envs.VLLM_TORCH_COMPILE_LEVEL in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
    ] or not supports_dynamo():
        return cls

    # take care of method resolution order: the wrapper is appended (not
    # prepended) so that `super().__init__` inside the original class still
    # reaches its own base class rather than
    # TorchCompileWrapperWithCustomDispatcher; the wrapper is initialized
    # explicitly in the patched `__init__` below instead.
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__  # type: ignore

    # `functools.wraps` copies the original constructor's metadata and sets
    # `__wrapped__`, which `inspect.signature` follows, so introspecting the
    # decorated class still shows its real constructor parameters instead of
    # `(*args, **kwargs)`.
    @functools.wraps(old_init)
    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__  # type: ignore

    def __call__(self, *args, **kwargs):
        """
        Run `forward`, going through the compiled path when possible.

        `compiled_codes`, `compiled_callable`, `use_custom_dispatcher` and
        `dispatch_to_code` come from TorchCompileWrapperWithCustomDispatcher
        (not visible here).
        """
        # torch.compiler.is_compiling() means we are inside the compilation
        # e.g. TPU has the compilation logic in model runner, so we don't
        # need to compile the model inside.
        if torch.compiler.is_compiling():
            return self.forward(*args, **kwargs)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            # Bind against `self.__class__.forward` (not `cls.forward`) so a
            # subclass overriding `forward` is bound with its own signature;
            # `apply_defaults` makes defaulted arguments visible by name.
            sig = inspect.signature(self.__class__.forward)
            bound_args = sig.bind(self, *args, **kwargs)
            bound_args.apply_defaults()
            for k, dims in dynamic_arg_dims.items():
                arg = bound_args.arguments.get(k)
                if arg is not None:
                    if isinstance(arg, torch.Tensor):
                        torch._dynamo.mark_dynamic(arg, dims)
                    elif isinstance(arg, IntermediateTensors):
                        for tensor in arg.tensors.values():
                            torch._dynamo.mark_dynamic(tensor, dims)
                    else:
                        raise ValueError(
                            "Unsupported dynamic dimensions"
                            f" {dims} for argument {k} with type {type(arg)}.")

        # if we don't use custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            return self.compiled_callable(*args, **kwargs)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(*args, **kwargs)
            return model_output

    cls.__call__ = __call__  # type: ignore
    return cls