custom_op.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import torch
4
5
import torch.nn as nn

6
from vllm.config import get_cached_compilation_config
7
from vllm.logger import init_logger
8
from vllm.model_executor.utils import maybe_disable_graph_partition
9
from vllm.platforms import current_platform
10
11

logger = init_logger(__name__)

# Registry of in-tree op / layer classes, keyed by the name given to their
# ``register`` decorator (``CustomOp.register`` or ``PluggableLayer.register``).
# For CustomOp subclasses, ask whether an op is active through its
# ``enabled()`` classmethod, e.g.:
#   - MyOp.enabled()
#   - op_registry["my_op"].enabled()
# (PluggableLayer subclasses do not define ``enabled()``.)
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = dict()

# Registry of out-of-tree (OOT) replacement classes, keyed by the ``__name__``
# of the in-tree class they replace, or by an explicit ``name=`` passed to
# ``register_oot``.
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = dict()


class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
    ``torch.nn.Module`` objects as sub-layers, and its functionality depends on
    these sub-layers following a generalized invocation sequence. It is also
    stateful and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate an instance, substituting a registered OOT class if any.

        ``op_registry_oot`` is looked up by ``cls.__name__``, which is the
        default key used by :meth:`register_oot`. When a replacement is
        registered under that key, an instance of the replacement class is
        allocated instead of ``cls``.

        Note: per Python's data model, ``__init__`` is only invoked
        automatically when the returned object is an instance of ``cls``, so
        an OOT replacement is expected to subclass the in-tree layer.
        """
        layer_class_name = cls.__name__

        if layer_class_name in op_registry_oot:
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
            logger.debug(
                "Instantiating pluggable layer: %s using %s",
                layer_class_name,
                str(layer_cls_to_instantiate),
            )
        else:
            layer_cls_to_instantiate = cls
        return super().__new__(layer_cls_to_instantiate)

    @classmethod
    def register(cls, name: str):
        """Decorator registering an in-tree pluggable layer under ``name``.

        Sets ``name`` as the decorated class's ``name`` attribute and stores
        the class in ``op_registry`` (shared with CustomOp). Duplicate names
        trip an assertion.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        """Decorator registering an out-of-tree (OOT) replacement layer.

        If an in-tree layer class has an OOT replacement registered, the
        replacement is instantiated instead (see ``__new__``). The
        registration key is ``name`` when given, otherwise the ``__name__`` of
        the class the decorator is accessed through. The decorated class's
        ``name`` attribute is set to that key.

        Usable as ``@Layer.register_oot``, ``@Layer.register_oot()`` or
        ``@PluggableLayer.register_oot(name="...")``.

        Raises:
            TypeError: if applied without parentheses to something that is
                not a class.
        """

        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")


93
class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.

    At construction time :meth:`dispatch_forward` picks one of the
    ``forward_*`` methods (based on whether the op is enabled and on the
    current platform); :meth:`forward` then calls the selected method.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate an instance, substituting a registered OOT class if any.

        ``op_registry_oot`` is looked up by ``cls.__name__``, which is the
        default key used by :meth:`register_oot`. When a replacement is
        registered under that key, an instance of the replacement class is
        allocated instead of ``cls``.
        """
        op_name = cls.__name__

        if op_name in op_registry_oot:
            op_cls_to_instantiate = op_registry_oot[op_name]
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
        else:
            op_cls_to_instantiate = cls
        return super().__new__(op_cls_to_instantiate)

    def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
        """
        Args:
            enforce_enable: treat the op as enabled regardless of
                :meth:`enabled`.
            compile_native: when the op is disabled, pass ``forward_native``
                through :meth:`maybe_compile`.
        """
        super().__init__()
        self._enforce_enable = enforce_enable
        self._forward_method = self.dispatch_forward(compile_native=compile_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self, compile_native: bool):
        """Select the forward implementation for this op instance.

        Records the op's name in the compilation config's
        ``enabled_custom_ops`` / ``disabled_custom_ops`` collections, then
        returns ``forward_native`` (possibly wrapped by :meth:`maybe_compile`)
        if the op is disabled, or the platform-specific ``forward_*`` method
        otherwise.

        Args:
            compile_native: when the op is disabled, whether to pass
                ``forward_native`` through :meth:`maybe_compile`.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_cached_compilation_config()

        # Unregistered ops have no ``name`` attribute (enabled() explicitly
        # tolerates this); fall back to the class name rather than raising
        # AttributeError here.
        op_name = getattr(self.__class__, "name", self.__class__.__name__)

        # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
        # enable device-specific kernels in ViT models when enabling graph
        # mode. By default, it will follow the compilation_config to determine
        # whether enable itself.
        # This enforce_enable mechanism will be removed after we adding a
        # separate compilation_config for multi-modal part.
        enabled = self._enforce_enable or self.enabled()

        if not enabled:
            compilation_config.disabled_custom_ops.update([op_name])
            # Compile forward_native to avoid eager torch ops if inside
            # opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
            return self.maybe_compile(self.forward_native, enable=compile_native)

        compilation_config.enabled_custom_ops.update([op_name])

        if current_platform.is_rocm():
            # sparse_attn_indexer is pinned to the CUDA path on ROCm; every
            # other op goes through forward_hip so ROCm-specific overrides are
            # honored (forward_hip defaults to forward_cuda).
            if op_name == "sparse_attn_indexer":
                return self.forward_cuda
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    def maybe_compile(self, fn, *, enable: bool = True):
        """
        Compile fn if compilation enabled.
        Useful for CustomOp instances called from within a torch custom op,
        meaning the forward call is hidden from the model-level torch.compile.

        NOTE: this does not enable fusion across ops, so opaque custom ops
        should still be unwrapped wherever possible.
        """
        from vllm.config.compilation import CompilationMode

        # Do not compile if the caller opted out
        if not enable:
            return fn

        # Do not compile if global compilation disabled
        compilation_config = get_cached_compilation_config()
        if compilation_config.mode == CompilationMode.NONE:
            return fn

        # If eager backend is used, do not compile either
        if compilation_config.backend == "eager":
            return fn

        # dynamic=True to avoid recompilations
        return torch.compile(
            fn,
            dynamic=True,
            backend=current_platform.simple_compile_backend,
            options=maybe_disable_graph_partition(
                current_platform.simple_compile_backend
            ),
        )

    @classmethod
    def enabled(cls) -> bool:
        """Whether this op is enabled by ``CompilationConfig.custom_ops``.

        ``"+<name>"`` force-enables and ``"-<name>"`` force-disables the op;
        otherwise the global default from :meth:`default_on` applies.
        Unregistered ops (no ``name`` attribute) log a one-time warning and
        use the global default.
        """
        compilation_config = get_cached_compilation_config()
        custom_ops = compilation_config.custom_ops
        # if no name, then it was not registered
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
                cls.__name__,
            )
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
        'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
        """
        compilation_config = get_cached_compilation_config()
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
        # Exactly one of "none" / "all" is expected in custom_ops.
        assert count_none + count_all == 1

        return not count_none > 0 or count_all > 0

    @classmethod
    def register(cls, name: str):
        """Decorator registering an in-tree custom op under ``name``.

        Sets ``name`` as the decorated class's ``name`` attribute (used by
        :meth:`enabled`) and stores the class in ``op_registry``. Duplicate
        names trip an assertion.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    @classmethod
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
        """Decorator registering an out-of-tree (OOT) replacement custom op.

        If an in-tree op class has an OOT replacement registered, the
        replacement is instantiated instead (see ``__new__``). The
        registration key is ``name`` when given, otherwise the ``__name__`` of
        the class the decorator is accessed through. The decorated class's
        ``name`` attribute is set to that key.

        Examples:
            - ``@UnquantizedFusedMoEMethod.register_oot`` applied to
              ``class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)``
            - ``@CustomOp.register_oot(name="UnquantizedFusedMoEMethod")``

        Raises:
            TypeError: if applied without parentheses to something that is
                not a class.
        """

        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
            op_cls.name = reg_name
            op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOp.register_oot()
            # or @CustomOp.register_oot(name="..."); return the decorator.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOp.register_oot
            # The first argument is the class itself; decorate it now.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")