custom_op.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import functools
import inspect

6
import torch
7
8
import torch.nn as nn

9
from vllm.config import get_cached_compilation_config
10
from vllm.logger import init_logger
11
from vllm.model_executor.utils import maybe_disable_graph_partition
12
from vllm.platforms import current_platform
13
14

logger = init_logger(__name__)

# Registry of every in-tree custom op / pluggable layer class, keyed by the
# name passed to `CustomOp.register` / `PluggableLayer.register`.
# Whether a registered op is enabled is queried through the class itself:
#   - MyOp.enabled()
#   - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}

# Out-of-tree replacement classes, keyed by the name given to `register_oot`
# (by default the name of the in-tree class being replaced). `__new__` of
# both base classes consults this mapping at instantiation time.
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


def get_oot_class_by_name(class_name: str) -> type | None:
    """Look up an out-of-tree replacement class.

    Args:
        class_name: Key the replacement was registered under in
            ``op_registry_oot`` (by default, the name of the in-tree class
            it replaces).

    Returns:
        The registered OOT class, or ``None`` when nothing is registered
        under ``class_name``.
    """
    return op_registry_oot.get(class_name)


class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
    ``torch.nn.Module`` objects as sub-layers, and its functionality depends on
    these sub-layers following a generalized invocation sequence. Also, it is stateful
    and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate an instance, substituting a registered OOT class if any.

        The OOT registry is keyed by the Python class name, which is also the
        default key used by :meth:`register_oot`. When no replacement is
        registered, ``cls`` itself is instantiated.
        """
        # ``cls.__name__`` always exists on a class, so there is no failure
        # path here; unregistered subclasses are instantiated as themselves.
        layer_class_name = cls.__name__

        if layer_class_name not in op_registry_oot:
            layer_cls_to_instantiate = cls
        else:
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
            logger.debug(
                "Instantiating pluggable layer: %s using %s",
                layer_class_name,
                str(layer_cls_to_instantiate),
            )
        # NOTE: Python only calls ``__init__`` on the returned object when it
        # is an instance of ``cls``; an OOT replacement is therefore expected
        # to subclass the in-tree layer it replaces.
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a layer under ``name``.

        The decorated class gets a ``name`` attribute and is stored in the
        shared ``op_registry`` (also used by :class:`CustomOp`), so names must
        be unique across both kinds of registrations.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) pluggable layers.
    # For OOT pluggable layers:
    #   if an in-tree layer class is registered with an oot_custom_layer,
    #   the oot_custom_layer will be used instead.
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        """Register an OOT replacement for the in-tree layer ``cls``.

        Usable bare (``@Layer.register_oot``) or called
        (``@Layer.register_oot()`` / ``@PluggableLayer.register_oot(name=...)``).
        The registry key is ``name`` when given, else ``cls.__name__``.

        Raises:
            TypeError: if applied to something that is not a class.
        """

        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")


class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.

    Subclasses provide ``forward_native`` and, optionally, platform-specific
    ``forward_*`` overrides. The implementation used by :meth:`forward` is
    chosen once per instance, in ``__init__``, by :meth:`dispatch_forward`.
    """

    def __new__(cls, *args, **kwargs):
        # Out-of-tree replacements are keyed by Python class name, which is
        # also the default key used by ``register_oot``. ``cls.__name__`` is
        # always defined, so unregistered subclasses are simply instantiated
        # as themselves.
        op_name = cls.__name__

        if op_name not in op_registry_oot:
            op_cls_to_instantiate = cls
        else:
            op_cls_to_instantiate = op_registry_oot[op_name]
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
        # NOTE: Python only calls ``__init__`` on the returned object when it
        # is an instance of ``cls``; an OOT replacement is therefore expected
        # to subclass the in-tree op it replaces.
        return super().__new__(op_cls_to_instantiate)

    def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
        """
        Args:
            enforce_enable: Use the platform-specific implementation even if
                the compilation config would disable this op.
            compile_native: When the op ends up disabled, pass
                ``forward_native`` through :meth:`maybe_compile`.
        """
        super().__init__()
        self._enforce_enable = enforce_enable
        self._forward_method = self.dispatch_forward(compile_native=compile_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self, compile_native: bool):
        """Select the forward implementation for this op.

        Records the op's name in the compilation config's
        ``enabled_custom_ops`` or ``disabled_custom_ops``. A disabled op uses
        ``forward_native`` (passed through :meth:`maybe_compile`); an enabled
        op uses the ``forward_*`` method for the current platform.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_cached_compilation_config()

        # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
        # enable device-specific kernels in ViT models when enabling graph
        # mode. By default, it will follow the compilation_config to determine
        # whether enable itself.
        # This enforce_enable mechanism will be removed after adding a
        # separate compilation_config for multi-modal part.
        enabled = self._enforce_enable or self.enabled()

        # Unregistered ops have no ``name`` attribute (``enabled()`` explicitly
        # tolerates that case), so fall back to the class name instead of
        # raising AttributeError during ``__init__``.
        op_name = getattr(self.__class__, "name", self.__class__.__name__)

        if not enabled:
            compilation_config.disabled_custom_ops.update([op_name])
            # Compile forward_native to avoid eager torch ops if inside
            # opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
            return self.maybe_compile(self.forward_native, enable=compile_native)

        compilation_config.enabled_custom_ops.update([op_name])

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    def maybe_compile(self, fn, *, enable: bool = True):
        """
        Compile fn if compilation enabled.
        Useful for CustomOp instances called from within a torch custom op,
        meaning the forward call is hidden from the model-level torch.compile.

        NOTE: this does not enable fusion across ops, so opaque custom ops
        should still be unwrapped wherever possible.

        Args:
            fn: Callable to compile (``dispatch_forward`` passes
                ``self.forward_native``).
            enable: Caller-side switch; when False, ``fn`` is returned as-is.

        Returns:
            ``fn`` unchanged when ``enable`` is False, the compilation mode is
            ``CompilationMode.NONE``, or the backend is ``"eager"``. Otherwise
            a ``torch.compile``-d callable. If the op class was registered
            with ``dynamic_arg_dims``, it is compiled with ``dynamic=False``
            and the listed tensor dims are marked dynamic on every call;
            otherwise it is compiled with ``dynamic=True``.
        """
        from vllm.config.compilation import CompilationMode

        # Do not compile if the caller disabled compilation
        if not enable:
            return fn

        # Do not compile if global compilation disabled
        compilation_config = get_cached_compilation_config()
        if compilation_config.mode == CompilationMode.NONE:
            return fn

        # If eager backend is used, do not compile either
        if compilation_config.backend == "eager":
            return fn

        backend = current_platform.simple_compile_backend
        compile_options = maybe_disable_graph_partition(backend)

        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
        if dynamic_arg_dims is not None:
            compiled_fn = torch.compile(
                fn,
                dynamic=False,
                backend=backend,
                options=compile_options,
            )
            sig = inspect.signature(fn)

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # Resolve positional/keyword/default arguments by name so the
                # configured argument names can be looked up uniformly.
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                for arg_name, dims in dynamic_arg_dims.items():
                    arg = bound.arguments.get(arg_name)
                    if isinstance(arg, torch.Tensor):
                        dims_list = [dims] if isinstance(dims, int) else dims
                        for d in dims_list:
                            # Normalize negative dims (e.g. -1 -> last dim).
                            real_d = arg.ndim + d if d < 0 else d
                            torch._dynamo.mark_dynamic(arg, real_d)
                return compiled_fn(*args, **kwargs)

            return wrapper

        # dynamic=True to avoid recompilations
        return torch.compile(
            fn,
            dynamic=True,
            backend=backend,
            options=compile_options,
        )

    @classmethod
    def enabled(cls) -> bool:
        """Return whether this op's custom (non-native) path is enabled.

        ``CompilationConfig.custom_ops`` entries ``"+<name>"`` / ``"-<name>"``
        force-enable / force-disable a registered op; otherwise the global
        default from :meth:`default_on` applies. Unregistered ops (no
        ``name`` attribute) always follow the global default.
        """
        compilation_config = get_cached_compilation_config()
        custom_ops = compilation_config.custom_ops
        # If there is no name, the op was never registered via ``register``.
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
                cls.__name__,
            )
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
        'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
        """
        compilation_config = get_cached_compilation_config()
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
        # Exactly one of 'none' / 'all' must be present.
        assert count_none + count_all == 1

        return not (count_none > 0) or count_all > 0

    # Decorator to register custom ops.
    @classmethod
    def register(
        cls,
        name: str,
        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    ):
        """Return a class decorator that registers a custom op under ``name``.

        Args:
            name: Registry key; also the token used in
                ``CompilationConfig.custom_ops`` (``"+name"`` / ``"-name"``).
            dynamic_arg_dims: Optional mapping from argument name of the
                compiled forward to the dim (or dims, negative allowed) that
                :meth:`maybe_compile` marks dynamic on each call.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_cls._dynamic_arg_dims = dynamic_arg_dims
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) custom ops.
    # For OOT custom ops:
    #   if an in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
        """Register an OOT replacement for the in-tree op ``cls``.

        The registry key is ``name`` when given, else ``cls.__name__``.

        Raises:
            TypeError: if applied to something that is not a class.
        """

        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
            op_cls.name = reg_name
            op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOp.register_oot()
            # or @CustomOp.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOp.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")