# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch.nn as nn

from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.platforms import current_platform

# Module-level logger shared by CustomOp (debug/warning messages below).
logger = init_logger(__name__)


class CustomOp(nn.Module):
15
16
17
18
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.
    """
19

20
21
22
23
24
25
26
27
28
29
30
31
32
33
    def __new__(cls, *args, **kwargs):
        try:
            op_name = cls.__name__
        except AttributeError:
            raise TypeError(
                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
                f"was not set, possibly because it was not decorated with "
                f"@CustomOp.register, or it's the CustomOp base class itself."
            ) from None

        if op_name not in cls.op_registry_oot:
            op_cls_to_instantiate = cls
        else:
            op_cls_to_instantiate = cls.op_registry_oot[op_name]
34
35
36
37
38
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
39
40
        return super().__new__(op_cls_to_instantiate)

41
    def __init__(self, enforce_enable: bool = False):
42
        super().__init__()
43
        self._enforce_enable = enforce_enable
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
65
66
67
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)
68
69
70
71
72
73
74
75
76
77
78

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

79
80
81
82
83
    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

84
85
86
    def dispatch_forward(self):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
87
        compilation_config = get_cached_compilation_config()
88
89
90
91
92

        # CustomOp object can be enforce enabled, e.g., enable device-specific
        # kernels in ViT models when enabling graph mode. By default, it will
        # follow the compilation_config to determine whether enable itself.
        enabled = self._enforce_enable or self.enabled()
93
94
95
        if enabled:
            compilation_config.enabled_custom_ops.update([self.__class__.name])
        else:
96
            compilation_config.disabled_custom_ops.update([self.__class__.name])
97
98

        if not enabled:
99
100
            return self.forward_native

101
        if current_platform.is_rocm():
102
            return self.forward_hip
103
        elif current_platform.is_cpu():
104
            return self.forward_cpu
105
        elif current_platform.is_tpu():
106
            return self.forward_tpu
107
        elif current_platform.is_xpu():
108
            return self.forward_xpu
109
110
        elif current_platform.is_out_of_tree():
            return self.forward_oot
111
112
        else:
            return self.forward_cuda
113
114
115
116

    @classmethod
    def enabled(cls) -> bool:
        # if no name, then it was not registered
117
        compilation_config = get_cached_compilation_config()
118
        custom_ops = compilation_config.custom_ops
119
        if not hasattr(cls, "name"):
120
            logger.warning_once(
121
122
123
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
124
125
                cls.__name__,
            )
126
127
            return CustomOp.default_on()

128
129
        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
130
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"
131
132
133
134
135

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
136
        """
137
138
139
140
        Behavior controlled by `CompilationConfig.custom_ops`: On by default if
        'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
141
        """
142
        compilation_config = get_cached_compilation_config()
143
144
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
145
146
147
        assert count_none + count_all == 1

        return not count_none > 0 or count_all > 0
148
149
150
151
152
153

    # Dictionary of all custom ops (classes, indexed by registered name).
    # To check if an op with a name is enabled, call .enabled() on the class.
    # Examples:
    # - MyOp.enabled()
    # - op_registry["my_op"].enabled()
154
155
    op_registry: dict[str, type["CustomOp"]] = {}
    op_registry_oot: dict[str, type["CustomOp"]] = {}
156
157
158
159
160
161
162
163
164
165
166

    # Decorator to register custom ops.
    @classmethod
    def register(cls, name: str):
        def decorator(op_cls):
            assert name not in cls.op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator
167
168
169
170
171
172
173
174
175
176
177

    # Decorator to register out-of-tree(oot) custom ops.
    # For OOT custom ops:
    #   if in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
178
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
179
180
        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
181
            assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}"
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
            op_cls.name = reg_name
            cls.op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOP.register_oot()
            # or @CustomOP.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOP.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")