# SPDX-License-Identifier: Apache-2.0

import torch.nn as nn

from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.

    Subclasses provide one or more ``forward_<backend>`` implementations.
    The backend is chosen once, at construction time, by
    :meth:`dispatch_forward`; :meth:`forward` then delegates every call to
    that choice. A disabled op always uses :meth:`forward_native`.
    """

    def __init__(self):
        super().__init__()
        # Resolve the backend once; forward() just calls the cached method.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        """Run the backend implementation selected at construction time."""
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation. Subclasses must override it to support CUDA.

        The HIP and CPU defaults below also delegate here.
        """
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        # By default, we assume that Gaudi ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_neuron(self, *args, **kwargs):
        # By default, we assume that Neuron ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    @classmethod
    def _op_name(cls) -> str:
        """Return the registered op name, or the class name if unregistered.

        Only classes decorated with :meth:`register` receive a ``name``
        attribute, so unregistered subclasses fall back to ``__name__``.
        """
        return getattr(cls, "name", cls.__name__)

    def dispatch_forward(self):
        """Pick the forward implementation for this op on this platform.

        Records the op in the compilation config's enabled/disabled set, then
        returns ``forward_native`` if the op is disabled. Otherwise it returns
        the ``forward_<backend>`` method matching ``current_platform``, with
        ``forward_cuda`` as the fallback.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_current_vllm_config().compilation_config
        enabled = self.enabled()
        # Unregistered ops have no ``name`` attribute; previously this raised
        # AttributeError even though enabled() explicitly supports them.
        op_name = self._op_name()
        if enabled:
            compilation_config.enabled_custom_ops.update([op_name])
        else:
            compilation_config.disabled_custom_ops.update([op_name])

        if not enabled:
            return self.forward_native

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_hpu():
            return self.forward_hpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_neuron():
            return self.forward_neuron
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    @classmethod
    def enabled(cls) -> bool:
        """Return whether this op's custom (non-native) path should be used.

        A registered op is enabled when ``+<name>`` is in ``custom_ops`` or
        the global default is on. It is disabled when ``-<name>`` is present;
        ``-<name>`` wins over the default. Unregistered ops follow the global
        default only.
        """
        compilation_config = get_current_vllm_config().compilation_config
        custom_ops = compilation_config.custom_ops
        # if no name, then it was not registered
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear in the op registry. It will be enabled/disabled based on the global settings.",  # noqa: E501
                cls.__name__,
            )
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled
                    and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        On by default if level < CompilationLevel.PIECEWISE
        Specifying 'all' or 'none' in custom_ops takes precedence.
        """
        from vllm.config import CompilationLevel
        compilation_config = get_current_vllm_config().compilation_config
        custom_ops = compilation_config.custom_ops
        count_none = custom_ops.count("none")
        count_all = custom_ops.count("all")
        # 'all' forces ops on. Otherwise they are on only below PIECEWISE and
        # only when 'none' was not requested. The parentheses spell out the
        # precedence the previous `a and not b or c` form relied on.
        return ((compilation_config.level < CompilationLevel.PIECEWISE
                 and count_none == 0) or count_all > 0)

    # Dictionary of all custom ops (classes, indexed by registered name).
    # To check if an op with a name is enabled, call .enabled() on the class.
    # Examples:
    # - MyOp.enabled()
    # - op_registry["my_op"].enabled()
    op_registry: dict[str, type['CustomOp']] = {}

    # Decorator to register custom ops.
    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers an op under ``name``.

        The decorated class gets a ``name`` attribute and is stored in
        :attr:`op_registry`. Reusing a name fails the duplicate-name assertion.
        """

        def decorator(op_cls):
            assert name not in cls.op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            cls.op_registry[name] = op_cls
            return op_cls

        return decorator