register.py 19.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
vLLM Helion kernel registration with pre-tuned config selection.

This module leverages Helion's internal config selection infrastructure to use
pre-tuned configs instead of runtime autotuning.

How Helion Normally Works
-------------------------
For each kernel invocation, Helion:
1. Computes a cache key from input arguments
2. Looks up the key in its internal compilation cache
3. On cache miss, runs autotuning to find the best config
4. Compiles and caches the kernel with that config

How We Override It
------------------
We override two Helion hooks to use pre-tuned configs:

1. **key**: We provide a key function (derived from config_picker) that
   computes cache keys matching our pre-tuned config keys. This ensures Helion's
   internal cache uses keys that correspond to configs we've prepared.

2. **autotuner_fn**: We provide PresetConfigSearch which, instead of autotuning,
   simply returns the pre-tuned config for the computed key. On cache miss,
   Helion calls our autotuner which returns the author-prepared config.

Both hooks use the same config_picker logic to ensure the cache key computed
by key matches the config returned by the autotuner.

Key Classes
-----------
- HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
- ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
- PresetConfigSearch: Custom autotuner that returns pre-tuned configs
"""

from collections.abc import Callable
40
from typing import Any, cast, overload
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

import torch
from torch.library import Library

from vllm.logger import init_logger
from vllm.utils.import_utils import has_helion
from vllm.utils.torch_utils import direct_register_custom_op

if not has_helion():
    raise ImportError(
        "register module requires helion to be installed. "
        "Install it with: pip install helion"
    )

import helion
56
from helion._compat import requires_torch_version
57
58
59
60
from helion.autotuner.base_search import BaseAutotuner
from helion.runtime.config import Config
from helion.runtime.settings import default_autotuner_fn

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# TODO(gmagogsfm): Remove CustomOp fallback path (_get_or_register_custom_op,
# vllm_helion_lib, direct_register_custom_op) once vLLM requires PyTorch >= 2.11.
_HOP_AVAILABLE = requires_torch_version("2.11")

if _HOP_AVAILABLE:
    import torch.utils._pytree as pytree
    from helion._compiler._dynamo.higher_order_ops import (
        helion_kernel_side_table,
        helion_kernel_wrapper_mutation,
    )
    from helion._compiler._dynamo.variables import infer_output_spec
    from torch.fx.experimental.proxy_tensor import (
        disable_proxy_modes_tracing,
        get_proxy_mode,
    )

77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Module-level logger used for the debug/info/warning messages in this file.
logger = init_logger(__name__)

# torch.library fragment for the "vllm_helion" namespace. It is handed to
# direct_register_custom_op as target_lib on the CustomOp fallback path.
vllm_helion_lib = Library("vllm_helion", "FRAGMENT")  # noqa


def validate_helion_settings(
    helion_settings: "helion.Settings | None", op_name: str
) -> None:
    """Check user-supplied Helion settings for conflicts with preset configs.

    Args:
        helion_settings: Optional settings object; ``None`` skips every check.
        op_name: Kernel name, used only in the error and warning messages.

    Raises:
        ValueError: If the settings carry an ``autotuner_fn`` that is neither
            ``None`` nor ``default_autotuner_fn``.
    """
    if helion_settings is None:
        return

    user_settings = helion_settings.to_dict()

    # A missing key and an explicit None both mean "no custom autotuner".
    requested_autotuner = user_settings.get("autotuner_fn")
    if (
        requested_autotuner is not None
        and requested_autotuner is not default_autotuner_fn
    ):
        raise ValueError(
            f"HelionKernelWrapper for '{op_name}' uses a custom autotuner via "
            f"config picker. Remove 'autotuner_fn' from helion_settings and use "
            f"@{op_name}.register_config_picker instead."
        )

    # Only an explicit True is worth a warning; anything else is left alone.
    if user_settings.get("static_shapes") is not True:
        return

    # Most vLLM ops need dynamic shapes for variable batch sizes and
    # sequence lengths, so static shapes are allowed but flagged.
    logger.warning(
        "Kernel '%s' has static_shapes=True in helion_settings. "
        "Most vLLM ops require dynamic shapes for variable batch sizes "
        "and sequence lengths. Consider removing this setting.",
        op_name,
    )


112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
def create_helion_decorated_kernel(
    raw_kernel_func: Callable,
    helion_settings: "helion.Settings | None" = None,
    extra_kwargs: dict[str, Any] | None = None,
) -> Any:
    """Apply ``helion.kernel`` to a raw kernel function.

    The decorator keyword arguments are assembled in three layers, later
    layers overriding earlier ones:

    1. the entries of ``helion_settings.to_dict()`` (if settings are given),
    2. ``static_shapes=False`` unless the settings set it to exactly ``True``,
    3. ``extra_kwargs`` (if given).

    Args:
        raw_kernel_func: The undecorated kernel function.
        helion_settings: Optional settings used as the base keyword arguments.
        extra_kwargs: Optional keyword arguments applied last.

    Returns:
        Whatever ``helion.kernel(**kwargs)(raw_kernel_func)`` returns.
    """
    decorator_options: dict[str, Any] = (
        dict(helion_settings.to_dict()) if helion_settings else {}
    )

    # vLLM runs with varying batch sizes and sequence lengths, so dynamic
    # shapes are the default unless static shapes were requested explicitly.
    if decorator_options.get("static_shapes") is not True:
        decorator_options["static_shapes"] = False

    decorator_options.update(extra_kwargs or {})

    decorate = helion.kernel(**decorator_options)
    return decorate(raw_kernel_func)


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
class PresetConfigSearch(BaseAutotuner):
    """Autotuner stand-in that looks up a preset config instead of searching."""

    def __init__(
        self,
        args: tuple[Any, ...],
        config_selector: Callable[[tuple[Any, ...]], Config],
    ):
        """Store the kernel arguments and the selector to consult later.

        Args:
            args: Kernel arguments, passed as one tuple.
            config_selector: Callable mapping that argument tuple to a Config.
        """
        self.config_selector = config_selector
        self.args = args

    def autotune(self, *, skip_cache: bool = False) -> Config:
        """Return the config chosen by the selector for the stored arguments.

        ``skip_cache`` is accepted but not used by this implementation.
        """
        choose_config = self.config_selector
        return choose_config(self.args)


class ConfiguredHelionKernel:
    """A configured Helion kernel bound to a specific platform."""

    def __init__(
        self,
        op_name: str,
153
        config_picker: Callable[[tuple[Any, ...], list[str]], str | None] | None,
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        raw_kernel_func: Callable,
        helion_settings: "helion.Settings | None" = None,
    ):
        self.op_name = op_name
        self.config_picker = config_picker
        self.raw_kernel_func = raw_kernel_func
        self.helion_settings = helion_settings
        self._decorated_kernel = self._create_decorated_kernel()

    def __call__(self, *args, **kwargs):
        return self._decorated_kernel(*args, **kwargs)

    def _create_key_computer(self):
        """
        Create a key computer function derived from the config picker.

        The returned function receives kernel arguments unpacked (*args) to match
        Helion's key signature (called as self._key_fn(*args)).
        """
        if self.config_picker is None:
            raise RuntimeError(
                f"No config picker registered for kernel '{self.op_name}'. "
                f"Use @{self.op_name}.register_config_picker to register one."
            )

179
180
181
        # After None check, config_picker is guaranteed to be non-None
        assert self.config_picker is not None

182
183
        def key_computer(*args):
            config_keys = list(self.configs.keys())
184
185
186
187
188
            # Cast is safe because we checked for None above
            config_picker = cast(
                Callable[[tuple[Any, ...], list[str]], str | None], self.config_picker
            )
            selected_key = config_picker(args, config_keys)
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
            if selected_key:
                return selected_key
            return "default" if "default" in self.configs else None

        return key_computer

    def _create_config_selector(self, key_computer):
        def config_selector(args):
            # args is a tuple; key_computer expects unpacked args
            selected_config_key = key_computer(*args)

            if selected_config_key is None:
                raise ValueError(
                    f"Config picker returned None for kernel '{self.op_name}' "
                    f"with available config keys: {list(self.configs.keys())}"
                )

            if selected_config_key not in self.configs:
                raise ValueError(
                    f"Config picker returned invalid config key "
                    f"'{selected_config_key}' for kernel '{self.op_name}'. "
                    f"Available keys: {list(self.configs.keys())}"
                )

            return self.configs[selected_config_key]

        return config_selector

    def _load_platform_configs(self) -> None:
        from vllm.kernels.helion.config_manager import ConfigManager
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        self.platform = get_canonical_gpu_name()
        config_manager = ConfigManager.get_instance()
        self.configs = config_manager.get_platform_configs(self.op_name, self.platform)

        if not self.configs:
            raise ValueError(
                f"No configs available for kernel '{self.op_name}' "
                f"on platform '{self.platform}'"
            )

    def _create_decorated_kernel(self) -> Callable[..., Any]:
        self._load_platform_configs()

        key_computer = self._create_key_computer()
        config_selector = self._create_config_selector(key_computer)

237
238
239
240
        extra_kwargs = {
            "autotuner_fn": lambda _, args: PresetConfigSearch(args, config_selector),
            "key": key_computer,
        }
241
242
243
244
245
246

        logger.debug(
            "Creating decorated kernel %s with custom autotuner on platform %s",
            self.op_name,
            self.platform,
        )
247
248
249
        return create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )
250
251
252


class HelionKernelWrapper:
    """Wrapper for Helion kernels with pre-tuned config selection and HOP support."""

    def __init__(
        self,
        raw_kernel_func: Callable,
        op_name: str,
        fake_impl: Callable,
        helion_settings: "helion.Settings | None" = None,
    ):
        """Validate the settings and record everything needed to build the op.

        Args:
            raw_kernel_func: The undecorated kernel function.
            op_name: Name under which the kernel is registered and reported.
            fake_impl: Fake implementation used on the CustomOp fallback path.
            helion_settings: Optional Helion settings for the kernel.

        Raises:
            ValueError: If ``helion_settings`` conflicts with the preset-config
                autotuner (see ``validate_helion_settings``).
        """
        # Reject settings that would fight with our custom autotuner.
        validate_helion_settings(helion_settings, op_name)

        self.raw_kernel_func = raw_kernel_func
        self.op_name = op_name
        self._fake_impl = fake_impl
        self.helion_settings = helion_settings

        # Filled in later through the register_* decorators.
        self._config_picker: (
            Callable[[tuple[Any, ...], list[str]], str | None] | None
        ) = None
        self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None

        # Built lazily on the first get_configured_op() call.
        self._configured_kernel: ConfiguredHelionKernel | None = None

    def __call__(self, *args, **kwargs):
        """Run the kernel through whichever dispatch path applies.

        Without HOP support the call goes through a registered torch custom
        op; with HOP support it is either recorded as a higher-order op (when
        a proxy mode is active) or executed directly.
        """
        if _HOP_AVAILABLE:
            # HOP tracing: record HigherOrderOp in the FX graph
            if get_proxy_mode() is not None:
                return self._call_via_hop(args, kwargs)
            # Eager: run the configured kernel directly
            return self.get_configured_op()(*args, **kwargs)

        # CustomOp fallback: register as torch custom op for torch.compile
        # compatibility on older PyTorch lacking HOP/EffectType support
        custom_op = self._get_or_register_custom_op()
        return custom_op(*args, **kwargs)

    def _call_via_hop(
        self,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> Any:
        """Invoke the kernel through ``helion_kernel_wrapper_mutation``.

        Args:
            args: Positional arguments of the kernel call.
            kwargs: Keyword arguments of the kernel call.

        Returns:
            ``None`` when the inferred output spec has no tree spec; otherwise
            the outputs unflattened into the structure the spec describes.
        """
        kernel = self.get_configured_op()._decorated_kernel
        kernel_idx = helion_kernel_side_table.add_kernel(kernel)

        constant_args, tensor_args = self._partition_args(kernel, args, kwargs)

        # Rebuild the call in signature order: supplied values first choice,
        # declared defaults second, parameters with neither are left out.
        bound_values = {**constant_args, **tensor_args}
        ordered_values = []
        for param_name, param in kernel.signature.parameters.items():  # type: ignore[attr-defined]
            if param_name in bound_values:
                ordered_values.append(bound_values[param_name])
            elif param.default is not param.empty:
                ordered_values.append(param.default)
        full_args = tuple(ordered_values)

        with disable_proxy_modes_tracing():
            output_spec = infer_output_spec(kernel, full_args)

        hop_result = helion_kernel_wrapper_mutation(
            kernel_idx=kernel_idx,
            constant_args=constant_args,
            tensor_args=tensor_args,
            output_spec=output_spec,
        )

        tree_spec_str = output_spec.get("tree_spec_str")
        if tree_spec_str is None:
            return None
        tree_spec = pytree.treespec_loads(tree_spec_str)

        # Scalar leaves that are not SymInts come straight from the spec;
        # every other leaf is taken, in order, from the HOP result.
        hop_outputs = iter(hop_result)
        leaves = []
        for leaf_spec in output_spec["leaf_specs"]:
            if leaf_spec["type"] != "scalar" or isinstance(
                leaf_spec.get("scalar_value"), torch.SymInt
            ):
                leaves.append(next(hop_outputs))
            else:
                leaves.append(leaf_spec["scalar_value"])
        return pytree.tree_unflatten(leaves, tree_spec)

    @staticmethod
    def _partition_args(
        kernel: Any,
        args: tuple[Any, ...],
        kwargs: dict[str, Any],
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Split call arguments into non-tensor and tensor groups, by name.

        Positional arguments are named after the kernel signature's parameters
        in order; keyword arguments keep their own names.

        Returns:
            ``(constant_args, tensor_args)`` where ``tensor_args`` holds the
            ``torch.Tensor`` values and ``constant_args`` everything else.
        """
        param_names = list(kernel.signature.parameters)
        named_values = [
            (param_names[position], value) for position, value in enumerate(args)
        ]
        named_values.extend(kwargs.items())

        constant_args: dict[str, Any] = {}
        tensor_args: dict[str, Any] = {}
        for arg_name, value in named_values:
            bucket = tensor_args if isinstance(value, torch.Tensor) else constant_args
            bucket[arg_name] = value
        return constant_args, tensor_args

    def register_config_picker(
        self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
    ) -> Callable[[tuple[Any, ...], list[str]], str | None]:
        """Store the config picker and return it unchanged (decorator usage)."""
        self._config_picker = picker_func
        return picker_func

    def register_input_generator(
        self, generator_func: Callable[[], dict[str, tuple[Any, ...]]]
    ) -> Callable[[], dict[str, tuple[Any, ...]]]:
        """
        Register the function that produces inputs for autotuning/benchmarking.

        Args:
            generator_func: Zero-argument function returning dict[str, tuple]:
                - key: Configuration identifier (e.g., "4096", "hidden_4096")
                - value: Tuple of arguments to pass to the kernel

        Returns:
            The same function, so this method can be used as a decorator.

        Example:
            @kernel_wrapper.register_input_generator
            def generate_inputs():
                return {
                    "4096": (torch.randn(4096, device="cuda"), 0.5),
                    "8192": (torch.randn(8192, device="cuda"), 0.5),
                }
        """
        self._input_generator = generator_func
        return generator_func

    def get_inputs(self) -> dict[str, tuple[Any, ...]]:
        """Call the registered input generator and return its result.

        Raises:
            NotImplementedError: If no input generator has been registered.
        """
        generator = self._input_generator
        if generator is not None:
            return generator()
        raise NotImplementedError(
            f"No input generator registered for kernel '{self.op_name}'. "
            f"Use @{self.op_name}.register_input_generator to register one."
        )

    def run_autotune(
        self,
        inputs: tuple[Any, ...],
        autotune_effort: str = "quick",
    ) -> Config:
        """Run autotuning for a single input configuration.

        A fresh decorated kernel is built for the run, without the preset
        ``key``/``autotuner_fn`` overrides and with autotune errors ignored.
        """
        autotune_kernel = create_helion_decorated_kernel(
            self.raw_kernel_func,
            self.helion_settings,
            {
                "autotune_effort": autotune_effort,
                "autotune_ignore_errors": True,
            },
        )
        return autotune_kernel.autotune(inputs)

    def get_configured_op(self) -> ConfiguredHelionKernel:
        """Return the configured kernel, building it on first use.

        A config picker must have been registered beforehand.
        """
        assert self._config_picker is not None, (
            f"No config picker registered for kernel '{self.op_name}'. "
            f"Use @{self.op_name}.register_config_picker to register one."
        )

        configured = self._configured_kernel
        if configured is None:
            configured = ConfiguredHelionKernel(
                op_name=self.op_name,
                config_picker=self._config_picker,
                raw_kernel_func=self.raw_kernel_func,
                helion_settings=self.helion_settings,
            )
            self._configured_kernel = configured

        return configured

    def _get_or_register_custom_op(self) -> Any:
        """Return ``torch.ops.vllm_helion.<op_name>``, registering it if absent."""
        helion_ops = torch.ops.vllm_helion
        if hasattr(helion_ops, self.op_name):
            return getattr(helion_ops, self.op_name)

        configured_kernel = self.get_configured_op()

        logger.info("Registering op: vllm_helion::%s", self.op_name)
        direct_register_custom_op(
            op_name=self.op_name,
            op_func=configured_kernel._decorated_kernel,
            mutates_args=None,
            fake_impl=self._fake_impl,
            target_lib=vllm_helion_lib,
        )
        return getattr(torch.ops.vllm_helion, self.op_name)
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549


# Global registry of HelionKernelWrapper instances, keyed by op name.
# register_kernel adds entries and rejects a name that is already present.
_REGISTERED_KERNELS: dict[str, HelionKernelWrapper] = {}


def get_registered_kernels() -> dict[str, HelionKernelWrapper]:
    """Return a shallow copy of the registry, so callers cannot modify it."""
    return dict(_REGISTERED_KERNELS)


def get_kernel_by_name(kernel_name: str) -> HelionKernelWrapper | None:
    """Return the wrapper registered under ``kernel_name``, or ``None``."""
    try:
        return _REGISTERED_KERNELS[kernel_name]
    except KeyError:
        return None


def infer_fake_impl(
    kernel_func: Callable,
    helion_settings: "helion.Settings | None" = None,
) -> Callable:
    """Build a fake implementation for a Helion kernel function.

    Args:
        kernel_func: The undecorated kernel function.
        helion_settings: Optional settings passed to ``helion.kernel``.

    Returns:
        A function that, on each call, decorates ``kernel_func``, binds it to
        the positional arguments, compiles it with the config spec's default
        config and runs the result with a do-nothing ``_launcher``.
    """

    def helion_fake_kernel(*args, **kwargs):
        settings_kwargs = dict(helion_settings.to_dict()) if helion_settings else {}
        fake_kernel = helion.kernel(**settings_kwargs)(kernel_func)

        # Binding (positional args only) exposes the config_spec, which can
        # supply a default config that is valid for these arguments.
        bound_kernel = fake_kernel.bind(args)
        fallback_config = bound_kernel.config_spec.default_config()
        runner = bound_kernel.compile_config(fallback_config)

        def _noop_launcher(*launch_args, **launch_kwargs):
            # Presumably replaces the real launch so nothing runs on the
            # device -- confirm against Helion's _launcher contract.
            return None

        return runner(*args, **kwargs, _launcher=_noop_launcher)

    return helion_fake_kernel


# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
#   wrapper = register_kernel(func)  # Should return HelionKernelWrapper
#   wrapper._fake_impl  # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@overload
def register_kernel(
    op_name_or_func: Callable,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper: ...


@overload
def register_kernel(
    op_name_or_func: str | None = None,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> Callable[[Callable], HelionKernelWrapper]: ...


def register_kernel(
    op_name_or_func: str | Callable | None = None,
    *,
    fake_impl: Callable | None = None,
    helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]:
    """
    Register a Helion kernel function as a HelionKernelWrapper.

    Usable bare (``@register_kernel``) or with arguments
    (``@register_kernel("name", fake_impl=...)``). The wrapper is stored in
    the global kernel registry under the given name, or under the function's
    ``__name__`` when no name is given. A fake_impl is generated with
    ``infer_fake_impl`` when none is provided.

    Raises:
        ValueError: If the resolved op name is already registered.
    """
    # Only a string first argument counts as an explicit op name.
    explicit_name = op_name_or_func if isinstance(op_name_or_func, str) else None

    def decorator(kernel_func: Callable) -> HelionKernelWrapper:
        final_op_name = explicit_name or kernel_func.__name__

        if final_op_name in _REGISTERED_KERNELS:
            raise ValueError(
                f"Helion kernel '{final_op_name}' is already registered. "
                f"Use a different op_name or check for duplicate registrations."
            )

        if fake_impl is not None:
            resolved_fake_impl = fake_impl
        else:
            resolved_fake_impl = infer_fake_impl(kernel_func, helion_settings)
            logger.debug(
                "Auto-generated fake_impl for Helion kernel '%s'",
                kernel_func.__name__,
            )

        kernel_wrapper = HelionKernelWrapper(
            raw_kernel_func=kernel_func,
            op_name=final_op_name,
            fake_impl=resolved_fake_impl,
            helion_settings=helion_settings,
        )
        _REGISTERED_KERNELS[final_op_name] = kernel_wrapper

        logger.info(
            "Registered Helion kernel '%s' as HelionKernelWrapper",
            kernel_func.__name__,
        )
        return kernel_wrapper

    # Decorator with arguments: @register_kernel(...)
    if isinstance(op_name_or_func, str) or not callable(op_name_or_func):
        return decorator

    # Bare decorator usage: @register_kernel
    return decorator(op_name_or_func)