custom_op.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import functools
import inspect

6
import torch
7
8
import torch.nn as nn

9
from vllm.config import get_cached_compilation_config
10
from vllm.logger import init_logger
11
from vllm.model_executor.utils import maybe_disable_graph_partition
12
from vllm.platforms import current_platform
13
14

logger = init_logger(__name__)
15

16
17
18
19
20
21
22
23
24
# Registry of all in-tree custom ops and pluggable layers (classes, indexed
# by registered name). Entries are added by the ``register`` decorators of
# CustomOp and PluggableLayer.
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
# Registry of out-of-tree (OOT) replacement classes, indexed by the name
# under which they were registered via the ``register_oot`` decorators.
# It is looked up by class ``__name__`` when an in-tree class is instantiated.
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


25
26
def maybe_get_oot_by_class(class_type: type) -> type:
    """Return the out-of-tree replacement for ``class_type``, if any.

    The lookup key is ``class_type.__name__``. When no entry with that key
    exists in ``op_registry_oot``, ``class_type`` itself is returned.
    """
    return op_registry_oot.get(class_type.__name__, class_type)
30
31


32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
    ``torch.nn.Module`` objects as sub-layers, and its functionality depends on
    these sub-layers following a generalized invocation sequence. Also, it is stateful
    and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate the instance, substituting an OOT class if one exists.

        The OOT registry is keyed by the in-tree class's ``__name__``. When an
        entry exists for ``cls.__name__``, the object is allocated from the
        registered replacement class instead of ``cls``.
        """
        # NOTE: every class has ``__name__``, so no AttributeError handling is
        # needed here; the lookup does not depend on the registered ``name``.
        layer_class_name = cls.__name__

        if layer_class_name not in op_registry_oot:
            layer_cls_to_instantiate = cls
        else:
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
            logger.debug(
                "Instantiating pluggable layer: %s using %s",
                layer_class_name,
                str(layer_cls_to_instantiate),
            )
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a layer under ``name``.

        The decorated class gets a ``name`` attribute and is stored in
        ``op_registry``. Registering the same name twice trips an assertion.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree(oot) pluggable layers.
    # For OOT pluggable layers:
    #   if in-tree layer class is registered with an oot_custom_layer,
    #   the oot_custom_layer will be used instead.
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        """Register an out-of-tree replacement layer class.

        Usable both bare (``@SomeLayer.register_oot``) and called
        (``@SomeLayer.register_oot()`` / ``@SomeLayer.register_oot(name="...")``).

        Args:
            _decorated_layer_cls: The class being decorated when the decorator
                is applied without parentheses; ``None`` otherwise.
            name: Registry key to use. Defaults to the ``__name__`` of the
                class on which ``register_oot`` is invoked (``cls``).

        Returns:
            The decorated class (bare form) or the decorator (called form).

        Raises:
            TypeError: If the first positional argument is neither ``None``
                nor a class.
        """

        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")


103
class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.

    The backend-specific method is chosen once, in ``__init__``, via
    :meth:`dispatch_forward` and stored in ``self._forward_method``;
    :meth:`forward` simply delegates to it.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate the instance, substituting an OOT class if one exists.

        The OOT registry is keyed by the in-tree class's ``__name__``. When an
        entry exists for ``cls.__name__``, the object is allocated from the
        registered replacement class instead of ``cls``.
        """
        # NOTE: every class has ``__name__``, so no AttributeError handling is
        # needed here; the lookup does not depend on the registered ``name``.
        op_name = cls.__name__

        if op_name not in op_registry_oot:
            op_cls_to_instantiate = cls
        else:
            op_cls_to_instantiate = op_registry_oot[op_name]
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
        return super().__new__(op_cls_to_instantiate)

    def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
        """Initialize the op and bind its forward implementation.

        Args:
            enforce_enable: If True, treat the op as enabled regardless of
                what :meth:`enabled` reports.
            compile_native: Passed to :meth:`dispatch_forward`; controls
                whether ``forward_native`` may be compiled when the op is
                disabled.
        """
        super().__init__()
        self._enforce_enable = enforce_enable
        self._forward_method = self.dispatch_forward(compile_native=compile_native)

    def forward(self, *args, **kwargs):
        """Call the forward implementation selected at construction time."""
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation; subclasses are expected to override it."""
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self, compile_native: bool):
        """Select the forward implementation for this instance.

        Records the op in the compilation config's ``enabled_custom_ops`` or
        ``disabled_custom_ops`` and returns the callable to use as forward.

        Args:
            compile_native: When the op is disabled, whether
                ``forward_native`` may be wrapped by :meth:`maybe_compile`.

        Returns:
            ``forward_native`` (possibly compiled) if the op is disabled,
            otherwise the ``forward_*`` method matching the current platform.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_cached_compilation_config()

        # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
        # enable device-specific kernels in ViT models when enabling graph
        # mode. By default, it will follow the compilation_config to determine
        # whether to enable itself.
        # This enforce_enable mechanism will be removed after we add a
        # separate compilation_config for the multi-modal part.
        enabled = self._enforce_enable or self.enabled()

        # enabled() tolerates classes that were never registered (no ``name``
        # attribute), so the bookkeeping below must tolerate them too instead
        # of raising AttributeError; fall back to the class name.
        op_cls = self.__class__
        op_name = getattr(op_cls, "name", op_cls.__name__)

        if not enabled:
            compilation_config.disabled_custom_ops.update([op_name])
            # Compile forward_native to avoid eager torch ops if inside
            # opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
            return self.maybe_compile(self.forward_native, enable=compile_native)

        compilation_config.enabled_custom_ops.update([op_name])

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    def maybe_compile(self, fn, *, enable: bool = True):
        """
        Compile fn if compilation enabled.
        Useful for CustomOp instances called from within a torch custom op,
        meaning the forward call is hidden from the model-level torch.compile.

        NOTE: this does not enable fusion across ops, so opaque custom ops
        should still be unwrapped wherever possible.

        Args:
            fn: The callable to (maybe) compile.
            enable: If False, ``fn`` is returned unchanged.

        Returns:
            ``fn`` itself when compilation is skipped, otherwise a compiled
            callable (wrapped to mark dynamic dims when the class declares
            ``_dynamic_arg_dims``).
        """
        from vllm.config.compilation import CompilationMode

        # Do not compile if compilation disabled
        if not enable:
            return fn

        # Do not compile if global compilation disabled
        compilation_config = get_cached_compilation_config()
        if compilation_config.mode == CompilationMode.NONE:
            return fn

        # If eager backend is used, do not compile either
        if compilation_config.backend == "eager":
            return fn

        backend = current_platform.simple_compile_backend
        compile_options = maybe_disable_graph_partition(backend)

        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
        if dynamic_arg_dims is not None:
            # Only the declared dims are dynamic: compile with dynamic=False
            # and mark those dims explicitly on each call.
            compiled_fn = torch.compile(
                fn,
                dynamic=False,
                backend=backend,
                options=compile_options,
            )
            sig = inspect.signature(fn)

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # Bind to parameter names so dims can be looked up by name
                # regardless of positional/keyword passing.
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                for name, dims in dynamic_arg_dims.items():
                    arg = bound.arguments.get(name)
                    if isinstance(arg, torch.Tensor):
                        dims_list = [dims] if isinstance(dims, int) else dims
                        for d in dims_list:
                            # Normalize negative dims to non-negative indices.
                            real_d = arg.ndim + d if d < 0 else d
                            torch._dynamo.mark_dynamic(arg, real_d)
                return compiled_fn(*args, **kwargs)

            return wrapper

        # dynamic=True to avoid recompilations
        return torch.compile(
            fn,
            dynamic=True,
            backend=backend,
            options=compile_options,
        )

    @classmethod
    def enabled(cls) -> bool:
        """Return whether this op class is enabled by the compilation config.

        A class without a ``name`` attribute (i.e. not registered) follows
        the global default from :meth:`default_on`. Otherwise ``"+<name>"``
        in ``custom_ops`` enables the op and ``"-<name>"`` disables it, on
        top of the global default.
        """
        compilation_config = get_cached_compilation_config()
        custom_ops = compilation_config.custom_ops
        # if no name, then it was not registered
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
                cls.__name__,
            )
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
        'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
        """
        compilation_config = get_cached_compilation_config()
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
        # Exactly one of 'none' / 'all' must be present, exactly once.
        assert count_none + count_all == 1

        # Same truth table as the former ``not count_none > 0 or count_all > 0``
        # (which parses as ``(not (count_none > 0)) or (count_all > 0)``),
        # written without the precedence trap.
        return count_none == 0 or count_all > 0

    # Decorator to register custom ops.
    @classmethod
    def register(
        cls,
        name: str,
        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    ):
        """Return a class decorator that registers a custom op under ``name``.

        Args:
            name: Registry key; also stored on the class as ``name``.
            dynamic_arg_dims: Optional mapping from forward-argument name to
                the dim (or dims) to mark dynamic when the op's native forward
                is compiled by :meth:`maybe_compile`. Stored on the class as
                ``_dynamic_arg_dims``.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_cls._dynamic_arg_dims = dynamic_arg_dims
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree(oot) custom ops.
    # For OOT custom ops:
    #   if in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
        """Register an out-of-tree replacement op class.

        Args:
            _decorated_op_cls: The class being decorated when the decorator
                is applied without parentheses; ``None`` otherwise.
            name: Registry key to use. Defaults to the ``__name__`` of the
                class on which ``register_oot`` is invoked (``cls``).

        Returns:
            The decorated class (bare form) or the decorator (called form).

        Raises:
            TypeError: If the first positional argument is neither ``None``
                nor a class.
        """

        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
            op_cls.name = reg_name
            op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOp.register_oot()
            # or @CustomOp.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOp.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")