# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import inspect
from collections.abc import Callable
from pathlib import Path
from typing import Any, ClassVar, overload

import torch
from torch.library import Library, infer_schema

from vllm.ir.tolerances import DEFAULT_TOLERANCES, ToleranceSpec
from vllm.ir.util import hash_source, weak_cache
from vllm.logger import init_logger
from vllm.logging_utils import lazy, tensors_str_no_data

# Type of the callables accepted by IrOp.register_input_generator:
# any arguments in, a tuple of values out.
InputGenerator = Callable[..., tuple[Any, ...]]

# Torch library fragment ("vllm_ir" namespace) in which every IrOp defines
# its torch custom op (see IrOp.__init__).
vllm_ir_lib = Library("vllm_ir", "FRAGMENT")

logger = init_logger(__name__)

RESERVED_PROVIDERS = ["native", "unfused"]
"""Providers that are reserved and cannot be used for custom implementations."""

# Only modified through the enable_torch_wrap() context manager;
# read by IrOp.__call__ on every op invocation.
_ENABLE_TORCH_WRAP: bool = True
"""Global override flag to control torch op layer wrapping."""


@contextlib.contextmanager
def enable_torch_wrap(enable: bool = True):
    """
    Temporarily turn torch custom op wrapping for vLLM IR ops on or off.

    While wrapping is off, calling an IR op bypasses the torch custom op layer
    and dispatches straight to the selected implementation. This avoids torch
    dispatch overhead in eager mode and removes the need for lowering on
    platforms that do not use Inductor.

    The previous value of the flag is restored when the context exits,
    including when the body raises.

    :param enable: the value the wrapping flag takes inside the context
    """

    global _ENABLE_TORCH_WRAP
    # Swap in the requested value while remembering the one to restore.
    previous, _ENABLE_TORCH_WRAP = _ENABLE_TORCH_WRAP, enable
    try:
        yield
    finally:
        _ENABLE_TORCH_WRAP = previous


# 0-param decorator overload
@overload
def register_op(f: Callable[..., Any]) -> "IrOp": ...


# parametrized decorator overload
@overload
def register_op(
    *,
    name: str | None = None,
) -> Callable[[Callable[..., Any]], "IrOp"]: ...


def register_op(
    f: Callable | None = None,
    *,
    name: str | None = None,
) -> "IrOp | Callable[[Callable], IrOp]":
    """
    Register a new vLLM IR op.

    Usable both as a bare decorator and as a decorator factory.
    The created op is stored in ``IrOp.registry`` under its name;
    registering a name that is already present trips an assertion.

    :param f: the native implementation of the op
    :param name: the name of the op, defaults to the function name
    :return: the IrOp object if f is provided, otherwise a decorator

    Example usage:
    ```python
    @vllm.ir.register_op
    def my_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


    @vllm.ir.register_op(name="custom_mul")
    def multiply(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * y
    ```
    """

    def _register(native_fn: Callable):
        # An explicit name wins; otherwise fall back to the function's own name.
        resolved_name: str = name if name is not None else native_fn.__name__
        assert resolved_name not in IrOp.registry
        new_op = IrOp(resolved_name, native_fn)
        IrOp.registry[resolved_name] = new_op
        return new_op

    # Bare-decorator form gets the op right away; the parametrized form
    # gets the decorator to apply next.
    return _register if f is None else _register(f)


class IrOp:
    registry: ClassVar[dict[str, "IrOp"]] = {}

    name: str
    impls: dict[str, "IrOpImpl"]

    def __init__(self, name: str, native_impl: Callable):
        self._py_signature = inspect.signature(native_impl)
        if any(
            p.kind == inspect.Parameter.KEYWORD_ONLY
            for p in self._py_signature.parameters.values()
        ):
            raise ValueError(
                f"Op {name} has keyword-only arguments which are not currently "
                f"supported. That's because kwargs are not allowed during lowering."
            )

        self.name = name
        self.impls: dict[str, IrOpImpl] = {}
        self._priority_impls: list[IrOpImpl] = []
        self._schema_str = infer_schema(native_impl, mutates_args=[])
119
120
        self._input_generator: InputGenerator | None = None
        self._tolerance_overrides: ToleranceSpec = {}
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

        # native implementation
        self.impls["native"] = IrOpImpl(
            self, "native", native_impl, supported=True, supports_args=None
        )

        # By default, fake routes directly to native,
        # can be overridden by register_fake
        self._fake_fn = native_impl

        # torch registration
        vllm_ir_lib.define(self.name + self._schema_str)
        # CompositeExplicitAutograd is not decomposed
        # by ATen IR normalization in AOTAutograd
        vllm_ir_lib.impl(
            self.name, self._inner_call, dispatch_key="CompositeExplicitAutograd"
        )
        vllm_ir_lib._register_fake(self.name, self._fake_call)
        assert hasattr(torch.ops.vllm_ir, name)
        self.torch_op: torch._ops.OpOverload = getattr(torch.ops.vllm_ir, name).default

    def register_fake(self, fn: Callable) -> Callable:
        """
        Register a fake impl for the torch custom op. If this method is not called,
        the native implementation is used directly for the fake implementation.
        """
        self._fake_fn = fn
        return fn

    def _fake_call(self, *args, **kwargs) -> Any:
        """
        Call to the fake implementation of the op. We use indirection because we want
        users to be able to register fake later but also want it to fall back to native
        directly by default, instead of going through the dispatching mechanism.
        """
        return self._fake_fn(*args, **kwargs)

    def register_impl(
        self,
        provider: str,
        *,
        supported: bool = True,
        supports_args: Callable[..., bool] | None = None,
    ):
        """
        Register an implementation for this custom op.
        :param provider: The name of the provider, must be unique.
        :param supported: Static support check, use this to check platform support.
        :param supports_args: Dynamic arg support check, used for types and shapes.
        :return: A decorator that registers the implementation.

        The decorated function must have the same semantics and signature as
        the native implementation.

        The provider name must be unique and not one of the RESERVED_PROVIDERS.
        The supported and supports_args parameters should not be used to implement
        custom enablement logic based on global state (e.g. environment variables).
        Instead, supported param should only be used to check for platform support
        (e.g. whether a specific hardware or library is available).
        supports_args should be used to check whether the provided arguments are
        compatible with the implementation.
        For custom enablement logic, set op impl priority.

        Example:
        ```python
        @my_op.register_impl("my_provider", supported=torch.cuda.is_available())
        def my_provider_impl(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ...
        ```

        """
        assert provider not in RESERVED_PROVIDERS, (
            f"Provider name {provider} is reserved."
        )

        def _register_impl(f: Callable):
            impl = IrOpImpl(self, provider, f, supported, supports_args)
            self.impls[provider] = impl

            if self.get_priority():
                logger.warning(
                    "Warning: registering new impl %s for op %s while priority is set.",
                    provider,
                    self.name,
                )

            return impl

        return _register_impl

    def _inner_call(self, *args, **kwargs) -> Any:
        """
        Eager call to torch op lands here. When torch wrapping is disabled,
        __call__ routes straight here instead of going through torch op dispatching.
        """
        impl = self.dispatch(*args, **kwargs)
        return impl.impl_fn(*args, **kwargs)

    def apply_arg_defaults(self, args) -> tuple:
        """
        Return args with default values applied.
        Defaults are taken from the native implementation signature.

        SHOULD NOT BE USED IN THE DISPATCH PATH (SLOW).
        Only for Inductor lowering.
        """
        bound_args = self._py_signature.bind(*args)
        bound_args.apply_defaults()
        return bound_args.args

    def dispatch(self, *args, **kwargs) -> "IrOpImpl":
        """
        Dispatch to the appropriate implementation based on current priority
        and argument support checks. Returns the selected IrOpImpl.

        THIS FUNCTION IS ON THE HOT PATH (OP DISPATCH), MUST BE FAST.
        """
        if not self._priority_impls:
            if not torch.compiler.is_compiling():
                # Logging not compatible with Dynamo tracing
                # (this code is exposed when torch wrapping is disabled)
                logger.warning_once(
                    "Priority not set for op %s, using native implementation.",
                    self.name,
                )
            return self.impls["native"]

        for impl in self._priority_impls:
            if not impl.supported:
                raise ValueError(
                    f"Implementation {impl.provider} for op {self.name} not supported. "
                    f"All implementations in priority list must be supported."
                )
            if impl.supports_args(*args, **kwargs):
                return impl

            if not torch.compiler.is_compiling():
                logger.debug(
                    "Skipping provider %s because it does not support "
                    "%s with args=%s kwargs=%s",
                    impl.provider,
                    self.name,
                    lazy(lambda: tensors_str_no_data(args)),
                    lazy(lambda: tensors_str_no_data(kwargs)),
                )

        raise RuntimeError(
            "Priority set incorrectly: the last implementation must "
            "support all args (can be native). This is likely an internal bug"
        )

    def __call__(self, *args, **kwargs) -> Any:
        if not _ENABLE_TORCH_WRAP:
            return self._inner_call(*args, **kwargs)

        return self.torch_op(*args, **kwargs)

    def get_priority(self) -> list[str]:
        """Get the current dispatch priority for implementations for this op."""
        return [p.provider for p in self._priority_impls]

    @contextlib.contextmanager
    def set_priority(self, priority: list[str]):
        """
        Context manager to set the dispatch priority for implementations for this op.
        """
        assert all(p in self.impls for p in priority), (
            "All providers in priority must be registered implementations."
        )

        def filter_priority_impls(p_list: list[str]) -> list[IrOpImpl]:
            filtered_impls = []
            for p in p_list:
                impl = self.impls[p]
                if not impl.supported:
                    # Skip unsupported implementations
                    continue

                filtered_impls.append(impl)

                # If all args are supported, skip other implementations
                if impl.supports_all_args:
                    return filtered_impls

            logger.warning_once(
                "Op %s: No implementation in priority list supports all args, "
                "execution fallback to native is possible. To silence this warning, "
                "explicitly add 'native' to the end of the priority list",
                self.name,
            )
            filtered_impls.append(self.impls["native"])
            return filtered_impls

        # Temporarily set priority
        old_priority_impls = self._priority_impls
        try:
            self._priority_impls = filter_priority_impls(priority)
            yield
        finally:
            self._priority_impls = old_priority_impls

    def supported_providers(self) -> list[str]:
        return [p.provider for p in self.impls.values() if p.supported]

324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
    @property
    def has_input_generator(self) -> bool:
        return self._input_generator is not None

    def register_input_generator(self, fn: InputGenerator) -> InputGenerator:
        self._input_generator = fn
        return fn

    def generate_inputs(self, **kwargs: Any) -> tuple[Any, ...]:
        if self._input_generator is None:
            raise RuntimeError(
                f"No input generator registered for op '{self.name}'. "
                f"Use @ir.ops.{self.name}.register_input_generator"
            )
        return self._input_generator(**kwargs)

    def override_tolerance(
        self, dtype: torch.dtype, *, atol: float, rtol: float
    ) -> None:
        self._tolerance_overrides[dtype] = {"atol": atol, "rtol": rtol}

    def get_tolerance(self, dtype: torch.dtype) -> dict[str, float]:
        if dtype in self._tolerance_overrides:
            return self._tolerance_overrides[dtype]
        if dtype in DEFAULT_TOLERANCES:
            return DEFAULT_TOLERANCES[dtype]
        raise ValueError(
            f"No tolerance defined for dtype {dtype} in op '{self.name}'. "
            f"Use op.override_tolerance({dtype}, atol=..., rtol=...) "
            f"or add {dtype} to DEFAULT_TOLERANCES."
        )

356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451

class IrOpImpl:
    """
    One provider's implementation of an IrOp.

    The implementation function must have exactly the same inferred schema as
    the op's native implementation; the optional supports_args check must
    mirror the native parameter names and defaults.
    """

    def __init__(
        self,
        op: IrOp,
        provider: str,
        impl_fn: Callable,
        supported: bool,
        supports_args: Callable[..., bool] | None,
    ):
        """
        Validate and store an implementation. This does not add the
        implementation to ``op.impls``; the caller does that.

        :param op: the op this implementation belongs to
        :param provider: the provider name, must not be registered on op yet
        :param impl_fn: the implementation function
        :param supported: static support flag for this implementation
        :param supports_args: optional per-call argument support check,
            None means all arguments are supported
        :raises ValueError: if the schema of impl_fn does not match the native
            schema, or if supports_args is not a callable whose parameters
            match the native implementation
        """
        assert provider not in op.impls, (
            f"Implementation for provider {provider} already registered."
        )
        # Native also uses this path, so we allow it here.
        assert provider == "native" or provider not in RESERVED_PROVIDERS

        # Enforce the exact same schema as the native implementation.
        # This takes care of names, types, and defaults.
        schema = infer_schema(impl_fn, mutates_args=[])
        if schema != op._schema_str:
            raise ValueError(
                f"Implementation for provider {provider} has schema '{schema}' which "
                f"does not match native schema '{op._schema_str}' for op {op.name}."
            )

        if supports_args is not None:
            if not callable(supports_args):
                raise ValueError(
                    f"supports_args for provider {provider} must be a callable"
                )

            # We also manually validate the supports_args signature.
            # Matching signatures allow faster dispatch on the hotpath.

            # Check that supports_args does not have keyword-only parameters
            supports_args_signature = inspect.signature(supports_args)
            params = supports_args_signature.parameters
            if any(p.kind == inspect.Parameter.KEYWORD_ONLY for p in params.values()):
                raise ValueError(
                    f"supports_args for provider {provider} "
                    f"cannot have keyword-only parameters"
                )

            # Check that supports_args has the same total number of parameters
            op_params = op._py_signature.parameters
            if len(params) != len(op_params):
                raise ValueError(
                    f"supports_args for provider {provider} must have the same number "
                    f"of parameters ({len(params)}) as the native implementation "
                    f"({len(op_params)})"
                )

            # Check that names and defaults match for supports_args
            for p, op_p in zip(params.values(), op_params.values()):
                if p.name != op_p.name:
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' which does not match native parameter "
                        f"'{op_p.name}'"
                    )
                if p.default != op_p.default:
                    raise ValueError(
                        f"supports_args for provider {provider} has parameter "
                        f"'{p.name}' with default {p.default} which does not match "
                        f"native default {op_p.default}"
                    )

        self.op = op
        self.provider = provider
        self.impl_fn = impl_fn
        self.supported = supported
        self._supports_args = supports_args

    @property
    def supports_all_args(self) -> bool:
        """Check if this implementation supports all args unconditionally."""
        return self._supports_args is None

    def supports_args(self, *args, **kwargs) -> bool:
        """
        Check whether this implementation supports the given call arguments.
        Always True when no supports_args check was registered.
        """
        if self._supports_args is None:
            return True

        return self._supports_args(*args, **kwargs)

    @weak_cache
    def uuid(self):
        """
        Compile-time hash to uniquely determine whether the implementation has changed.
        Used by vllm-compile hash mechanism and torch.compile lowering pass uuid to
        control the vLLM compile cache and AOTAutograd/Inductor caches respectively.

        Source file contents do not change so we cache uuid.
        TODO(luka): Cache the file hash as multiple impls are likely in the same file.
        """
        sources = [Path(inspect.getfile(self.impl_fn))]
        return hash_source(*sources)