# custom_op.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
import functools
import inspect

6
import torch
7
8
import torch.nn as nn

9
from vllm.config import get_cached_compilation_config
10
from vllm.logger import init_logger
11
from vllm.model_executor.utils import maybe_disable_graph_partition
12
from vllm.platforms import current_platform
13
14

logger = init_logger(__name__)
15

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
    ``torch.nn.Module`` objects as sub-layers, and its functionality depends on
    these sub-layers following a generalized invocation sequence. Also, it is stateful
    and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        try:
            layer_class_name = cls.__name__
        except AttributeError:
            raise TypeError(
                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
                f"was not set, possibly because it was not decorated with "
                f"@PluggableLayer.register, or it's the PluggableLayer itself."
            ) from None

        if layer_class_name not in op_registry_oot:
            layer_cls_to_instantiate = cls
        else:
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
            logger.debug(
                "Instantiating pluggable layer: %s using %s",
                layer_class_name,
                str(layer_cls_to_instantiate),
            )
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree(oot) pluggable layers.
    # For OOT pluggable layers:
    #   if in-tree layer class is registered with an oot_custom_layer,
    #   the oot_custom_layer will be used instead.
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")


96
class CustomOp(nn.Module):
97
98
99
100
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.
    """
101

102
103
104
105
106
107
108
109
110
111
    def __new__(cls, *args, **kwargs):
        try:
            op_name = cls.__name__
        except AttributeError:
            raise TypeError(
                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
                f"was not set, possibly because it was not decorated with "
                f"@CustomOp.register, or it's the CustomOp base class itself."
            ) from None

112
        if op_name not in op_registry_oot:
113
114
            op_cls_to_instantiate = cls
        else:
115
            op_cls_to_instantiate = op_registry_oot[op_name]
116
117
118
119
120
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
121
122
        return super().__new__(op_cls_to_instantiate)

123
    def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
124
        super().__init__()
125
        self._enforce_enable = enforce_enable
126
        self._forward_method = self.dispatch_forward(compile_native=compile_native)
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
147
148
149
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
150
151

    def forward_cpu(self, *args, **kwargs):
152
153
154
        # By default, we assume that CPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
155
156
157
158
159
160
161

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

162
163
164
165
166
    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

167
    def dispatch_forward(self, compile_native: bool):
168
169
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
170
        compilation_config = get_cached_compilation_config()
171

172
173
174
175
176
177
        # NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
        # enable device-specific kernels in ViT models when enabling graph
        # mode. By default, it will follow the compilation_config to determine
        # whether enable itself.
        # This enforce_enable mechanism will be removed after we adding a
        # separate compilation_config for multi-modal part.
178
        enabled = self._enforce_enable or self.enabled()
179
180
181
        if enabled:
            compilation_config.enabled_custom_ops.update([self.__class__.name])
        else:
182
            compilation_config.disabled_custom_ops.update([self.__class__.name])
183
184

        if not enabled:
185
186
187
            # Compile forward_native to avoid eager torch ops if inside
            # opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
            return self.maybe_compile(self.forward_native, enable=compile_native)
188

189
        if current_platform.is_rocm():
190
            return self.forward_hip
191
        elif current_platform.is_cpu():
192
            return self.forward_cpu
193
        elif current_platform.is_tpu():
194
            return self.forward_tpu
195
        elif current_platform.is_xpu():
196
            return self.forward_xpu
197
198
        elif current_platform.is_out_of_tree():
            return self.forward_oot
199
200
        else:
            return self.forward_cuda
201

202
203
204
205
206
207
208
209
210
211
212
    def maybe_compile(self, fn, *, enable: bool = True):
        """
        Compile fn if compilation enabled.
        Useful for CustomOp instances called from within a torch custom op,
        meaning the forward call is hidden from the model-level torch.compile.

        NOTE: this does not enable fusion across ops, so opaque custom ops
        should still be unwrapped wherever possible.
        """
        from vllm.config.compilation import CompilationMode

213
        # Do not compile if compilation disabled
214
215
216
217
218
219
220
221
222
223
224
225
        if not enable:
            return fn

        # Do not compile if global compilation disabled
        compilation_config = get_cached_compilation_config()
        if compilation_config.mode == CompilationMode.NONE:
            return fn

        # If eager backend is used, do not compile either
        if compilation_config.backend == "eager":
            return fn

226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
        compile_options = maybe_disable_graph_partition(
            current_platform.simple_compile_backend
        )
        backend = current_platform.simple_compile_backend

        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
        if dynamic_arg_dims is not None:
            compiled_fn = torch.compile(
                fn,
                dynamic=False,
                backend=backend,
                options=compile_options,
            )
            sig = inspect.signature(fn)

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                for name, dims in dynamic_arg_dims.items():
                    arg = bound.arguments.get(name)
                    if arg is not None and isinstance(arg, torch.Tensor):
                        dims_list = [dims] if isinstance(dims, int) else dims
                        for d in dims_list:
                            real_d = arg.ndim + d if d < 0 else d
                            torch._dynamo.mark_dynamic(arg, real_d)
                return compiled_fn(*args, **kwargs)

            return wrapper

256
257
258
259
        # dynamic=True to avoid recompilations
        return torch.compile(
            fn,
            dynamic=True,
260
261
            backend=backend,
            options=compile_options,
262
263
        )

264
265
266
    @classmethod
    def enabled(cls) -> bool:
        # if no name, then it was not registered
267
        compilation_config = get_cached_compilation_config()
268
        custom_ops = compilation_config.custom_ops
269
        if not hasattr(cls, "name"):
270
            logger.warning_once(
271
272
273
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
274
275
                cls.__name__,
            )
276
277
            return CustomOp.default_on()

278
279
        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
280
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"
281
282
283
284
285

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
286
        """
287
288
289
290
        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
        'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
291
        """
292
        compilation_config = get_cached_compilation_config()
293
294
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
295
296
297
        assert count_none + count_all == 1

        return not count_none > 0 or count_all > 0
298
299
300

    # Decorator to register custom ops.
    @classmethod
301
302
303
304
305
    def register(
        cls,
        name: str,
        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    ):
306
        def decorator(op_cls):
307
            assert name not in op_registry, f"Duplicate op name: {name}"
308
            op_cls.name = name
309
            op_cls._dynamic_arg_dims = dynamic_arg_dims
310
            op_registry[name] = op_cls
311
312
313
            return op_cls

        return decorator
314
315
316
317
318
319
320
321
322
323
324

    # Decorator to register out-of-tree(oot) custom ops.
    # For OOT custom ops:
    #   if in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
325
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
326
327
        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
328
            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
329
            op_cls.name = reg_name
330
            op_registry_oot[reg_name] = op_cls
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOP.register_oot()
            # or @CustomOP.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOP.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")