custom_op.py 7.95 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
from typing import Optional

6
7
import torch.nn as nn

8
from vllm.config import get_current_vllm_config
9
from vllm.logger import init_logger
10
from vllm.platforms import current_platform
11
12

# Module-level logger; CustomOp uses it for debug messages (OOT substitution)
# and one-time warnings (unregistered ops).
logger = init_logger(__name__)
13
14
15


class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.

    Subclasses provide one or more ``forward_<backend>`` implementations.
    At construction time :meth:`dispatch_forward` selects the implementation
    for the current platform (or ``forward_native`` when the op is disabled),
    and :meth:`forward` simply delegates to that selection.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate an instance, substituting a registered out-of-tree class.

        If a replacement class was registered under ``cls.__name__`` via
        :meth:`register_oot`, an instance of that class is allocated instead
        of ``cls``. Constructor arguments are not passed to
        ``object.__new__``; Python then calls ``__init__`` on the returned
        object (only when it is an instance of ``cls``).
        """
        # Every class has ``__name__``, so no AttributeError handling is
        # needed for this lookup key.
        op_name = cls.__name__
        op_cls_to_instantiate = cls.op_registry_oot.get(op_name)
        if op_cls_to_instantiate is None:
            op_cls_to_instantiate = cls
        else:
            logger.debug("Instantiating custom op: %s using %s", op_name,
                         str(op_cls_to_instantiate))
        return super().__new__(op_cls_to_instantiate)

    def __init__(self):
        super().__init__()
        # Resolve the backend-specific implementation once, up front, so that
        # forward() is a plain delegation.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        # By default, we assume that Gaudi ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_neuron(self, *args, **kwargs):
        # By default, we assume that Neuron ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Return the bound forward implementation to use for this op.

        Returns ``forward_native`` when the op is disabled; otherwise returns
        the ``forward_<backend>`` method matching ``current_platform``
        (falling back to ``forward_cuda``). For registered ops, the op's
        name is also added to the compilation config's
        ``enabled_custom_ops`` or ``disabled_custom_ops``.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_current_vllm_config().compilation_config
        enabled = self.enabled()

        # Unregistered ops have no ``name`` attribute. enabled() explicitly
        # supports them (warning + global default), so skip the bookkeeping
        # here instead of raising AttributeError.
        op_name = getattr(self.__class__, "name", None)
        if op_name is not None:
            if enabled:
                compilation_config.enabled_custom_ops.update([op_name])
            else:
                compilation_config.disabled_custom_ops.update([op_name])

        if not enabled:
            return self.forward_native

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_neuron():
            return self.forward_neuron
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    @classmethod
    def enabled(cls) -> bool:
        """Return whether the custom (non-native) implementation is used.

        ``+<name>`` in ``compilation_config.custom_ops`` turns the op on and
        ``-<name>`` turns it off. Otherwise :meth:`default_on` decides.
        Unregistered ops (no ``name`` attribute) always follow
        :meth:`default_on`.

        Raises:
            AssertionError: if both ``+<name>`` and ``-<name>`` are given.
        """
        # if no name, then it was not registered
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
                cls.__name__,
            )
            return CustomOp.default_on()

        custom_ops = get_current_vllm_config().compilation_config.custom_ops
        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled
                    and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        On by default if PyTorch Inductor is not used.
        Specifying 'all' or 'none' in custom_ops takes precedence; if both
        are given, 'all' wins.
        """
        from vllm.config import CompilationLevel
        compilation_config = get_current_vllm_config().compilation_config
        custom_ops = compilation_config.custom_ops
        # Custom ops default to off only when compiling at PIECEWISE level
        # or above with Inductor enabled.
        default_on = (compilation_config.level < CompilationLevel.PIECEWISE
                      or not compilation_config.use_inductor)
        return (default_on and "none" not in custom_ops) or "all" in custom_ops

    # Dictionary of all custom ops (classes, indexed by registered name).
    # To check if an op with a name is enabled, call .enabled() on the class.
    # Examples:
    # - MyOp.enabled()
    # - op_registry["my_op"].enabled()
    op_registry: dict[str, type['CustomOp']] = {}
    # Out-of-tree replacements, keyed by the in-tree class name (or an
    # explicit name); consulted by __new__ when instantiating an op.
    op_registry_oot: dict[str, type['CustomOp']] = {}

    # Decorator to register custom ops.
    @classmethod
    def register(cls, name: str):
        """Return a class decorator registering ``op_cls`` under ``name``.

        The decorated class gets ``name`` as its ``name`` attribute. That
        attribute is what ``+name`` / ``-name`` entries in ``custom_ops``
        refer to.

        Raises:
            AssertionError: (from the decorator) if ``name`` is already
                registered.
        """

        def decorator(op_cls):
            assert name not in cls.op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree(oot) custom ops.
    # For OOT custom ops:
    #   if in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
    def register_oot(cls, _decorated_op_cls=None, name: Optional[str] = None):
        """Register an out-of-tree replacement class.

        Usable bare (``@Base.register_oot``, keyed by ``Base.__name__``) or
        called (``@CustomOp.register_oot(name="...")``). The replacement is
        stored in ``op_registry_oot`` and picked up by ``__new__``.

        Raises:
            TypeError: if applied to something that is not a class.
            AssertionError: if the registration key is already taken.
        """

        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in cls.op_registry_oot, \
                f"Duplicate op name: {reg_name}"
            # NOTE(review): this overwrites any ``name`` the OOT class
            # inherited from an in-tree @register, so enabled() will then
            # look for ``+<reg_name>`` instead. Confirm this is intended.
            op_cls.name = reg_name
            cls.op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOp.register_oot()
            # or @CustomOp.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOp.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")