Unverified commit ff9fbc9a, authored by Yanan Cao, committed by GitHub
Browse files

[Kernel][Helion] [16/N] Refactor register_kernel API to be more Dynamo-friendly (#36705)


Signed-off-by: Yanan Cao <gmagogsfm@gmail.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
parent e6c47977
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import tempfile
from collections.abc import Callable
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import patch
import helion
from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import register_kernel
from vllm.kernels.helion.utils import get_canonical_gpu_name
GPU_PLATFORM = get_canonical_gpu_name()
DEFAULT_CONFIGS: dict[str, helion.Config] = {
"default": helion.Config(block_sizes=[32]),
}
@contextmanager
def dummy_kernel_registry(
    configs: dict[str, helion.Config] | None = None,
):
    """Yield a ``register`` decorator factory backed by throwaway configs.

    On entry the ``ConfigManager`` singleton is reset and a new manager is
    created on top of a temporary directory.  While the context is active,
    ``vllm.kernels.helion.config_manager.ConfigManager`` is patched so that
    calling it returns that manager.

    The yielded ``register(op_name=None, **kwargs)`` forwards to
    ``register_kernel(op_name, **kwargs)``.  Before doing so it writes
    ``<tmpdir>/<name>/<GPU_PLATFORM>.json`` holding the serialized
    ``configs`` (``DEFAULT_CONFIGS`` when ``configs`` is ``None``), where
    ``<name>`` is ``op_name`` or, failing that, ``fn.__name__``.

    The singleton is reset once more when the context exits.
    """
    selected = DEFAULT_CONFIGS if configs is None else configs
    # Each helion.Config is serialized via the "config" entry of its __dict__.
    payload = {key: vars(cfg)["config"] for key, cfg in selected.items()}

    with tempfile.TemporaryDirectory() as scratch:
        root = Path(scratch)
        ConfigManager.reset_instance()
        manager = ConfigManager(base_dir=root)
        manager_patch = patch(
            "vllm.kernels.helion.config_manager.ConfigManager",
            return_value=manager,
        )
        with manager_patch:

            def register(
                op_name: str | None = None,
                **kwargs,
            ) -> Callable:
                def decorator(fn: Callable) -> Callable:
                    # The config file must exist before the real decorator runs.
                    target_dir = root / (op_name or fn.__name__)
                    target_dir.mkdir(parents=True, exist_ok=True)
                    config_file = target_dir / f"{GPU_PLATFORM}.json"
                    config_file.write_text(json.dumps(payload))
                    return register_kernel(op_name, **kwargs)(fn)

                return decorator

            try:
                yield register
            finally:
                ConfigManager.reset_instance()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for autotuning Helion kernels, including disabled kernels with no configs."""
import pytest
import torch
from vllm.utils.import_utils import has_helion
if not has_helion():
pytest.skip(
"Helion is not installed. Install with: pip install vllm[helion]",
allow_module_level=True,
)
import helion
import helion.language as hl
from helion.autotuner.base_search import BaseSearch
from tests.kernels.helion.helpers import dummy_kernel_registry
from vllm.kernels.helion.register import create_helion_decorated_kernel
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Tiled element-wise sum: the result buffer is allocated like ``x`` and
    # filled one ``hl.tile`` block at a time.
    result = torch.empty_like(x)
    for block in hl.tile(x.size()):
        result[block] = x[block] + y[block]
    return result
class NoCompileSearch(BaseSearch):
    """Search strategy whose ``autotune`` just hands back the default config.

    No candidate is compiled or benchmarked here.  Modeled after helion's
    test BasicSearch (pytorch/helion#1649).
    """

    def autotune(self, *, skip_cache: bool = False):
        # ``skip_cache`` is accepted but not consulted by this implementation.
        spec = self.config_spec
        return spec.default_config()
def _no_compile_autotuner_fn(bound_kernel, args, **kwargs):
    """Build a ``NoCompileSearch`` from the given kernel, args and options."""
    search = NoCompileSearch(bound_kernel, args, **kwargs)
    return search
class TestAutotuneDisabledKernel:
    """Exercise the autotune flow for kernels registered without configs."""

    def setup_method(self):
        from vllm.kernels.helion.register import _REGISTERED_KERNELS

        # Snapshot the global registry, then start each test from empty.
        self._saved_registry = dict(_REGISTERED_KERNELS)
        _REGISTERED_KERNELS.clear()

    def teardown_method(self):
        from vllm.kernels.helion.register import _REGISTERED_KERNELS

        # Put back whatever was registered before the test ran.
        _REGISTERED_KERNELS.clear()
        _REGISTERED_KERNELS.update(self._saved_registry)

    def test_autotune_disabled_kernel_produces_valid_config(self):
        """A kernel registered with an empty config set is disabled, yet
        run_autotune still returns the kernel's default helion config."""

        def pick_default(args, keys):
            return "default"

        def no_fake(*a, **kw):
            return None

        def make_inputs():
            return {
                "small": (
                    torch.randn(4, 4, device="cuda"),
                    torch.randn(4, 4, device="cuda"),
                ),
            }

        with dummy_kernel_registry(configs={}) as register:
            decorate = register(
                "autotune_test_kernel",
                config_picker=pick_default,
                fake_impl=no_fake,
                input_generator=make_inputs,
            )
            wrapper = decorate(_add_kernel)
            assert wrapper._disabled is True

            inputs = wrapper.get_inputs()
            assert "small" in inputs
            sample = inputs["small"]

            # Swap in the autotuner that skips GPU compilation.
            tuner_settings = helion.Settings()
            tuner_settings.autotuner_fn = _no_compile_autotuner_fn
            wrapper.helion_settings = tuner_settings

            config = wrapper.run_autotune(sample)

            reference_kernel = create_helion_decorated_kernel(
                _add_kernel, helion_settings=tuner_settings
            )
            bound = reference_kernel.bind(sample)
            expected_default = bound.config_spec.default_config()
            assert config == expected_default
...@@ -52,7 +52,7 @@ def _helion_mock_context(): ...@@ -52,7 +52,7 @@ def _helion_mock_context():
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -87,8 +87,8 @@ class TestMakeFxHop: ...@@ -87,8 +87,8 @@ class TestMakeFxHop:
raw_kernel_func=raw_add_scale, raw_kernel_func=raw_add_scale,
op_name="test_make_fx", op_name="test_make_fx",
fake_impl=lambda *a, **kw: None, fake_impl=lambda *a, **kw: None,
config_picker=lambda args, keys: "default",
) )
wrapper.register_config_picker(lambda args, keys: "default")
def fn(x, y): def fn(x, y):
return wrapper(x, y, scale) return wrapper(x, y, scale)
...@@ -143,8 +143,8 @@ class TestMakeFxHop: ...@@ -143,8 +143,8 @@ class TestMakeFxHop:
raw_kernel_func=raw_silu_mul, raw_kernel_func=raw_silu_mul,
op_name="test_pm_silu_mul", op_name="test_pm_silu_mul",
fake_impl=lambda *a, **kw: None, fake_impl=lambda *a, **kw: None,
config_picker=lambda args, keys: "default",
) )
wrapper.register_config_picker(lambda args, keys: "default")
def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.silu(x) * y return torch.nn.functional.silu(x) * y
......
...@@ -21,7 +21,9 @@ if not has_helion(): ...@@ -21,7 +21,9 @@ if not has_helion():
) )
import helion import helion
import helion.language as hl
from tests.kernels.helion.helpers import dummy_kernel_registry
from vllm.kernels.helion.config_manager import ConfigManager from vllm.kernels.helion.config_manager import ConfigManager
from vllm.kernels.helion.register import ( from vllm.kernels.helion.register import (
_HOP_AVAILABLE, _HOP_AVAILABLE,
...@@ -34,6 +36,13 @@ from vllm.kernels.helion.register import ( ...@@ -34,6 +36,13 @@ from vllm.kernels.helion.register import (
) )
def _add_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Tiled element-wise sum: the result buffer is allocated like ``x`` and
    # filled one ``hl.tile`` block at a time.
    result = torch.empty_like(x)
    for block in hl.tile(x.size()):
        result[block] = x[block] + y[block]
    return result
@pytest.fixture @pytest.fixture
def sample_configs(): def sample_configs():
"""Create real Helion config objects for testing.""" """Create real Helion config objects for testing."""
...@@ -90,7 +99,7 @@ def configured_kernel(sample_kernel, sample_configs, config_manager_with_test_co ...@@ -90,7 +99,7 @@ def configured_kernel(sample_kernel, sample_configs, config_manager_with_test_co
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=config_manager_with_test_configs, return_value=config_manager_with_test_configs,
), ),
patch( patch(
...@@ -158,7 +167,7 @@ def create_configured_kernel_with_configs( ...@@ -158,7 +167,7 @@ def create_configured_kernel_with_configs(
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -189,7 +198,7 @@ class TestConfiguredHelionKernel: ...@@ -189,7 +198,7 @@ class TestConfiguredHelionKernel:
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -266,7 +275,7 @@ class TestConfiguredHelionKernel: ...@@ -266,7 +275,7 @@ class TestConfiguredHelionKernel:
with ( with (
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -310,7 +319,7 @@ class TestConfiguredHelionKernel: ...@@ -310,7 +319,7 @@ class TestConfiguredHelionKernel:
with ( with (
patch("vllm.kernels.helion.register.helion.kernel") as mock_helion_kernel, patch("vllm.kernels.helion.register.helion.kernel") as mock_helion_kernel,
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -346,23 +355,15 @@ class TestConfiguredHelionKernel: ...@@ -346,23 +355,15 @@ class TestConfiguredHelionKernel:
class TestHelionKernelWrapper: class TestHelionKernelWrapper:
"""Test suite for HelionKernelWrapper.""" """Test suite for HelionKernelWrapper."""
def test_get_configured_op_validates_configs_available(self, sample_kernel): def test_init_disables_on_missing_configs(self, sample_kernel):
"""Test get_configured_op validates configs are available.""" """Test __init__ marks wrapper as disabled when configs are missing."""
def fake_impl(*args, **kwargs): def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0]) return torch.zeros_like(args[0])
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
def default_picker(args, config_keys): def default_picker(args, config_keys):
return "default" return "default"
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager) mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock( mock_config_manager.get_platform_configs = Mock(
return_value={} return_value={}
...@@ -370,52 +371,99 @@ class TestHelionKernelWrapper: ...@@ -370,52 +371,99 @@ class TestHelionKernelWrapper:
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name", "vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200", return_value="nvidia_h200",
), ),
pytest.raises(ValueError, match="No configs available"), patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
): ):
wrapper.get_configured_op() mock_kernel.return_value = Mock(return_value=sample_kernel)
def test_get_configured_op_validates_config_picker( wrapper = HelionKernelWrapper(
self, sample_kernel, sample_configs raw_kernel_func=sample_kernel,
): op_name="test_kernel",
"""Test get_configured_op validates config picker.""" fake_impl=fake_impl,
config_picker=default_picker,
)
assert wrapper._disabled is True
assert "No configs available" in wrapper._disabled_reason
def test_disabled_wrapper_raises_on_call(self, sample_kernel):
"""Test __call__ raises RuntimeError on a disabled wrapper."""
def fake_impl(*args, **kwargs): def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0]) return torch.zeros_like(args[0])
wrapper = HelionKernelWrapper( def default_picker(args, config_keys):
raw_kernel_func=sample_kernel, return "default"
op_name="test_kernel",
fake_impl=fake_impl,
)
# Don't set config picker - should raise assertion error
mock_config_manager = Mock(spec=ConfigManager) mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) mock_config_manager.get_platform_configs = Mock(return_value={})
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
with pytest.raises(RuntimeError, match="is disabled"):
wrapper(torch.randn(4, 4), torch.randn(4, 4))
def test_disabled_wrapper_get_configured_op_raises(self, sample_kernel):
"""Test get_configured_op raises RuntimeError on a disabled wrapper."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name", "vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200", return_value="nvidia_h200",
), ),
pytest.raises(AssertionError, match="No config picker registered"), patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
): ):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
with pytest.raises(RuntimeError, match="is disabled"):
wrapper.get_configured_op() wrapper.get_configured_op()
def test_get_configured_op_returns_cached_kernel( def test_disabled_wrapper_supports_get_inputs(self, sample_kernel):
self, sample_kernel, sample_configs """Test get_inputs works on a disabled wrapper."""
):
"""Test get_configured_op returns cached ConfiguredHelionKernel."""
def fake_impl(*args, **kwargs): def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0]) return torch.zeros_like(args[0])
...@@ -423,19 +471,99 @@ class TestHelionKernelWrapper: ...@@ -423,19 +471,99 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys): def default_picker(args, config_keys):
return "default" return "default"
wrapper = HelionKernelWrapper( expected_inputs = {"key1": (torch.randn(4),)}
raw_kernel_func=sample_kernel, input_gen = Mock(return_value=expected_inputs)
op_name="test_kernel",
fake_impl=fake_impl, mock_config_manager = Mock(spec=ConfigManager)
) mock_config_manager.get_platform_configs = Mock(return_value={})
wrapper._config_picker = default_picker
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=sample_kernel)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
input_generator=input_gen,
)
assert wrapper._disabled is True
result = wrapper.get_inputs()
assert result is expected_inputs
def test_disabled_wrapper_supports_run_autotune(self, sample_kernel):
    """run_autotune on a disabled wrapper returns the autotuned config."""

    def fake_impl(*args, **kwargs):
        return torch.zeros_like(args[0])

    def default_picker(args, config_keys):
        return "default"

    # An empty platform-config mapping is what marks the wrapper disabled.
    manager = Mock(spec=ConfigManager)
    manager.get_platform_configs = Mock(return_value={})
    tuned_config = Mock()

    manager_patch = patch(
        "vllm.kernels.helion.config_manager.ConfigManager",
        return_value=manager,
    )
    gpu_patch = patch(
        "vllm.kernels.helion.utils.get_canonical_gpu_name",
        return_value="nvidia_h200",
    )
    kernel_patch = patch("vllm.kernels.helion.register.helion.kernel")

    with manager_patch, gpu_patch, kernel_patch as mock_kernel:
        mock_kernel.return_value = Mock(return_value=sample_kernel)
        wrapper = HelionKernelWrapper(
            raw_kernel_func=sample_kernel,
            op_name="test_kernel",
            fake_impl=fake_impl,
            config_picker=default_picker,
        )
        assert wrapper._disabled is True

        create_patch = patch(
            "vllm.kernels.helion.register.create_helion_decorated_kernel"
        )
        with create_patch as mock_create:
            # The stand-in kernel's autotune() yields the sentinel config.
            autotune_kernel = Mock()
            autotune_kernel.autotune.return_value = tuned_config
            mock_create.return_value = autotune_kernel

            sample = (torch.randn(4, 4),)
            outcome = wrapper.run_autotune(sample)
            assert outcome is tuned_config
def test_init_caches_configured_kernel(self, sample_kernel, sample_configs):
"""Test __init__ eagerly builds and caches ConfiguredHelionKernel."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
def default_picker(args, config_keys):
return "default"
mock_config_manager = Mock(spec=ConfigManager) mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -444,13 +572,77 @@ class TestHelionKernelWrapper: ...@@ -444,13 +572,77 @@ class TestHelionKernelWrapper:
), ),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel, patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
): ):
mock_decorated = Mock() mock_kernel.return_value = Mock(return_value=sample_kernel)
mock_kernel.return_value = Mock(return_value=mock_decorated)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
assert wrapper._configured_kernel is not None
result1 = wrapper.get_configured_op() result1 = wrapper.get_configured_op()
result2 = wrapper.get_configured_op() result2 = wrapper.get_configured_op()
assert result1 is result2 assert result1 is result2
@pytest.mark.skipif(
    not _HOP_AVAILABLE, reason="HOP path only used when HOP available"
)
def test_init_eagerly_initializes_hop_path(self):
    """On the HOP path, registration builds the configured kernel up front,
    so calling the wrapper afterwards performs no GPU-name lookup."""
    from vllm.kernels.helion.utils import get_canonical_gpu_name

    def pick_default(args, keys):
        return "default"

    kernel_configs = {"default": helion.Config(block_sizes=[4, 4])}
    spy_on_gpu_name = patch(
        "vllm.kernels.helion.utils.get_canonical_gpu_name",
        wraps=get_canonical_gpu_name,
    )

    with (
        dummy_kernel_registry(configs=kernel_configs) as register,
        spy_on_gpu_name as mock_gpu,
    ):
        wrapper = register(config_picker=pick_default)(_add_kernel)

        mock_gpu.assert_called_once()
        assert wrapper._configured_kernel is not None

    # Any GPU-name lookup from here on fails the test immediately.
    forbid_gpu_lookup = patch(
        "vllm.kernels.helion.utils.get_canonical_gpu_name",
        side_effect=AssertionError("get_canonical_gpu_name called during __call__"),
    )
    with forbid_gpu_lookup:
        x = torch.randn(4, 4, device="cuda")
        y = torch.randn(4, 4, device="cuda")
        result = wrapper(x, y)

        expected = x + y
        assert torch.allclose(result, expected)
@pytest.mark.skipif(
    _HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
)
def test_init_eagerly_initializes(self):
    """Registration itself looks up the GPU name, builds the configured
    kernel and exposes the op under ``torch.ops.vllm_helion``."""
    from vllm.kernels.helion.utils import get_canonical_gpu_name

    def pick_default(args, keys):
        return "default"

    spy_on_gpu_name = patch(
        "vllm.kernels.helion.utils.get_canonical_gpu_name",
        wraps=get_canonical_gpu_name,
    )

    with dummy_kernel_registry() as register, spy_on_gpu_name as mock_gpu:
        wrapper = register(config_picker=pick_default)(_add_kernel)

        # Init must have detected GPU and built the kernel
        mock_gpu.assert_called_once()
        assert wrapper._configured_kernel is not None
        assert hasattr(torch.ops.vllm_helion, wrapper.op_name)
@pytest.mark.skipif( @pytest.mark.skipif(
_HOP_AVAILABLE, reason="CustomOp path not used when HOP available" _HOP_AVAILABLE, reason="CustomOp path not used when HOP available"
) )
...@@ -463,13 +655,6 @@ class TestHelionKernelWrapper: ...@@ -463,13 +655,6 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys): def default_picker(args, config_keys):
return "default" return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager) mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
...@@ -479,7 +664,7 @@ class TestHelionKernelWrapper: ...@@ -479,7 +664,7 @@ class TestHelionKernelWrapper:
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -491,6 +676,13 @@ class TestHelionKernelWrapper: ...@@ -491,6 +676,13 @@ class TestHelionKernelWrapper:
): ):
mock_decorated = Mock() mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated) mock_kernel.return_value = Mock(return_value=mock_decorated)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
result = wrapper._get_or_register_custom_op() result = wrapper._get_or_register_custom_op()
assert result is existing_op assert result is existing_op
...@@ -506,13 +698,6 @@ class TestHelionKernelWrapper: ...@@ -506,13 +698,6 @@ class TestHelionKernelWrapper:
def default_picker(args, config_keys): def default_picker(args, config_keys):
return "default" return "default"
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
)
wrapper._config_picker = default_picker
mock_config_manager = Mock(spec=ConfigManager) mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value=sample_configs) mock_config_manager.get_platform_configs = Mock(return_value=sample_configs)
...@@ -532,7 +717,7 @@ class TestHelionKernelWrapper: ...@@ -532,7 +717,7 @@ class TestHelionKernelWrapper:
with ( with (
patch( patch(
"vllm.kernels.helion.config_manager.ConfigManager.get_instance", "vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager, return_value=mock_config_manager,
), ),
patch( patch(
...@@ -548,6 +733,13 @@ class TestHelionKernelWrapper: ...@@ -548,6 +733,13 @@ class TestHelionKernelWrapper:
): ):
mock_decorated = Mock() mock_decorated = Mock()
mock_kernel.return_value = Mock(return_value=mock_decorated) mock_kernel.return_value = Mock(return_value=mock_decorated)
wrapper = HelionKernelWrapper(
raw_kernel_func=sample_kernel,
op_name="test_kernel",
fake_impl=fake_impl,
config_picker=default_picker,
)
result = wrapper._get_or_register_custom_op() result = wrapper._get_or_register_custom_op()
mock_register.assert_called_once() mock_register.assert_called_once()
...@@ -584,11 +776,10 @@ class TestKernelRegistry: ...@@ -584,11 +776,10 @@ class TestKernelRegistry:
def test_get_kernel_by_name_returns_kernel(self): def test_get_kernel_by_name_returns_kernel(self):
"""Test get_kernel_by_name returns registered kernel.""" """Test get_kernel_by_name returns registered kernel."""
wrapper = HelionKernelWrapper( with dummy_kernel_registry() as register:
raw_kernel_func=Mock(), wrapper = register(
op_name="test_kernel", "test_kernel", config_picker=lambda args, keys: "default"
fake_impl=Mock(), )(_add_kernel)
)
from vllm.kernels.helion.register import _REGISTERED_KERNELS from vllm.kernels.helion.register import _REGISTERED_KERNELS
...@@ -604,112 +795,87 @@ class TestKernelRegistry: ...@@ -604,112 +795,87 @@ class TestKernelRegistry:
def test_register_kernel_auto_generates_fake_impl(self): def test_register_kernel_auto_generates_fake_impl(self):
"""Test register_kernel auto-generates fake_impl when not provided.""" """Test register_kernel auto-generates fake_impl when not provided."""
with patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer: with (
dummy_kernel_registry() as register,
patch("vllm.kernels.helion.register.infer_fake_impl") as mock_infer,
):
mock_fake = Mock() mock_fake = Mock()
mock_infer.return_value = mock_fake mock_infer.return_value = mock_fake
wrapper = register(
config_picker=lambda args, keys: "default",
)(_add_kernel)
def original_kernel(x): mock_infer.assert_called_once_with(_add_kernel, None)
return x assert wrapper._fake_impl is mock_fake
wrapper = register_kernel(original_kernel)
mock_infer.assert_called_once_with(original_kernel, None)
assert wrapper._fake_impl is mock_fake
def test_register_kernel_creates_wrapper(self): def test_register_kernel_creates_wrapper(self):
"""Test register_kernel creates HelionKernelWrapper.""" """Test register_kernel creates HelionKernelWrapper."""
with dummy_kernel_registry() as register:
def test_kernel(x): result = register("test_name", config_picker=lambda args, keys: "default")(
return x _add_kernel
)
result = register_kernel("test_name")(test_kernel)
assert isinstance(result, HelionKernelWrapper) assert isinstance(result, HelionKernelWrapper)
assert result.op_name == "test_name" assert result.op_name == "test_name"
assert result.raw_kernel_func is test_kernel assert result.raw_kernel_func is _add_kernel
def test_register_kernel_auto_detects_name(self): def test_register_kernel_auto_detects_name(self):
"""Test register_kernel uses function name when no name provided.""" """Test register_kernel uses function name when no name provided."""
with dummy_kernel_registry() as register:
wrapper = register(config_picker=lambda args, keys: "default")(_add_kernel)
@register_kernel assert wrapper.op_name == "_add_kernel"
def my_test_kernel(x):
return x
assert my_test_kernel.op_name == "my_test_kernel"
def test_register_kernel_registers_in_global_registry(self): def test_register_kernel_registers_in_global_registry(self):
"""Test register_kernel adds wrapper to global registry.""" """Test register_kernel adds wrapper to global registry."""
with dummy_kernel_registry() as register:
@register_kernel wrapper = register(
def test_kernel(x): "test_kernel", config_picker=lambda args, keys: "default"
return x )(_add_kernel)
registered_kernels = get_registered_kernels() registered_kernels = get_registered_kernels()
assert "test_kernel" in registered_kernels assert "test_kernel" in registered_kernels
assert registered_kernels["test_kernel"] is test_kernel assert registered_kernels["test_kernel"] is wrapper
def test_register_kernel_passes_helion_settings(self): def test_register_kernel_passes_helion_settings(self):
"""Test register_kernel passes helion_settings to wrapper.""" """Test register_kernel passes helion_settings to wrapper."""
mock_settings = Mock() settings = helion.Settings()
mock_settings.to_dict.return_value = {"debug": True} settings.print_output_code = True
@register_kernel("test_name", helion_settings=mock_settings) with dummy_kernel_registry() as register:
def test_kernel(x): result = register(
return x "test_name",
config_picker=lambda args, keys: "default",
helion_settings=settings,
)(_add_kernel)
assert test_kernel.helion_settings is mock_settings assert result.helion_settings is settings
def test_register_kernel_supports_decorator_syntax(self): def test_register_kernel_supports_decorator_syntax(self):
"""Test register_kernel works with decorator arguments.""" """Test register_kernel works with decorator arguments."""
mock_fake = Mock() mock_fake = Mock()
wrapper = register_kernel("custom_name", fake_impl=mock_fake) with dummy_kernel_registry() as register:
result = register(
def test_kernel(x): "custom_name",
return x config_picker=lambda args, keys: "default",
fake_impl=mock_fake,
result = wrapper(test_kernel) )(_add_kernel)
assert result.op_name == "custom_name" assert result.op_name == "custom_name"
assert result._fake_impl is mock_fake assert result._fake_impl is mock_fake
def test_register_kernel_bare_decorator(self):
"""Test register_kernel works as bare decorator."""
@register_kernel
def test_kernel(x):
return x
assert isinstance(test_kernel, HelionKernelWrapper)
assert test_kernel.op_name == "test_kernel"
def test_registered_wrapper_can_register_config_picker(self):
"""Test that registered wrapper can register config picker."""
@register_kernel
def test_kernel(x):
return x
def my_picker(args, config_keys):
return "default"
result = test_kernel.register_config_picker(my_picker)
assert result is my_picker
assert test_kernel._config_picker is my_picker
def test_register_kernel_raises_on_duplicate_registration(self): def test_register_kernel_raises_on_duplicate_registration(self):
"""Test register_kernel raises error on duplicate names.""" """Test register_kernel raises error on duplicate names."""
with dummy_kernel_registry() as register:
register("duplicate_name", config_picker=lambda args, keys: "default")(
_add_kernel
)
@register_kernel("duplicate_name") with pytest.raises(ValueError, match="already registered"):
def kernel1(x): register("duplicate_name", config_picker=lambda args, keys: "default")(
return x _add_kernel
)
with pytest.raises(ValueError, match="already registered"):
@register_kernel("duplicate_name")
def kernel2(x):
return x
def test_register_kernel_rejects_autotuner_fn_in_settings(self): def test_register_kernel_rejects_autotuner_fn_in_settings(self):
"""Test register_kernel rejects conflicting autotuner_fn.""" """Test register_kernel rejects conflicting autotuner_fn."""
...@@ -718,7 +884,11 @@ class TestKernelRegistry: ...@@ -718,7 +884,11 @@ class TestKernelRegistry:
with pytest.raises(ValueError, match="uses a custom autotuner"): with pytest.raises(ValueError, match="uses a custom autotuner"):
@register_kernel("test", helion_settings=mock_settings) @register_kernel(
"test",
config_picker=lambda args, keys: "default",
helion_settings=mock_settings,
)
def test_kernel(x): def test_kernel(x):
return x return x
...@@ -727,11 +897,47 @@ class TestKernelRegistry: ...@@ -727,11 +897,47 @@ class TestKernelRegistry:
mock_settings = Mock() mock_settings = Mock()
mock_settings.to_dict.return_value = {"static_shapes": False} mock_settings.to_dict.return_value = {"static_shapes": False}
with patch("vllm.kernels.helion.register.logger") as mock_logger: with (
dummy_kernel_registry() as register,
patch("vllm.kernels.helion.register.logger") as mock_logger,
):
register(
"test",
config_picker=lambda args, keys: "default",
helion_settings=mock_settings,
)(_add_kernel)
@register_kernel("test", helion_settings=mock_settings) mock_logger.warning.assert_not_called()
def test_kernel(x):
return x
# Should not call warning def test_disabled_kernel_appears_in_registry(self):
mock_logger.warning.assert_not_called() """Test that a disabled wrapper is still in the global registry."""
def fake_impl(*args, **kwargs):
return torch.zeros_like(args[0])
mock_config_manager = Mock(spec=ConfigManager)
mock_config_manager.get_platform_configs = Mock(return_value={})
with (
patch(
"vllm.kernels.helion.config_manager.ConfigManager",
return_value=mock_config_manager,
),
patch(
"vllm.kernels.helion.utils.get_canonical_gpu_name",
return_value="nvidia_h200",
),
patch("vllm.kernels.helion.register.helion.kernel") as mock_kernel,
):
mock_kernel.return_value = Mock(return_value=_add_kernel)
wrapper = register_kernel(
"disabled_kernel",
config_picker=lambda args, keys: "default",
fake_impl=fake_impl,
)(_add_kernel)
assert wrapper._disabled is True
registered = get_registered_kernels()
assert "disabled_kernel" in registered
assert registered["disabled_kernel"] is wrapper
...@@ -22,39 +22,6 @@ from vllm.kernels.helion.register import register_kernel ...@@ -22,39 +22,6 @@ from vllm.kernels.helion.register import register_kernel
logger = init_logger(__name__) logger = init_logger(__name__)
@register_kernel  # type: ignore[misc]
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Work on a 2-D view: every leading dimension is flattened into rows and
    # the last dimension is split into two equal halves.
    in_shape = input.shape
    last_dim = hl.specialize(in_shape[-1])
    half = last_dim // 2
    out_shape = in_shape[:-1] + (half,)

    flat = input.view(-1, in_shape[-1])
    rows = flat.shape[0]

    # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
    out = torch.empty((rows, half), device=input.device, dtype=torch.float8_e4m3fn)

    first_half = flat[:, :half]
    second_half = flat[:, half:]

    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    for tile_m, tile_n in hl.tile([rows, half]):
        # silu(first half) * second half, computed per tile.
        gate_vals = first_half[tile_m, tile_n]
        activated = torch.nn.functional.silu(gate_vals)
        up_vals = second_half[tile_m, tile_n]
        product = activated * up_vals

        # Scale in float32, then cast down to the output dtype.
        product_f32 = product.to(torch.float32)
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val
        scaled = product_f32 * inv_scale
        out[tile_m, tile_n] = scaled.to(out.dtype)

    return out.view(out_shape)
@silu_mul_fp8.register_input_generator # type: ignore[misc]
def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]: def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336] intermediate_sizes = [2048, 2880, 4096, 8192, 11008, 14336]
...@@ -65,8 +32,6 @@ def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]: ...@@ -65,8 +32,6 @@ def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
inputs = {} inputs = {}
for num_tokens in num_tokens_list: for num_tokens in num_tokens_list:
for intermediate_size in intermediate_sizes: for intermediate_size in intermediate_sizes:
# Input tensor has shape (num_tokens, 2 * intermediate_size)
# because silu_mul splits it into two halves
input_tensor = torch.randn( input_tensor = torch.randn(
num_tokens, num_tokens,
2 * intermediate_size, 2 * intermediate_size,
...@@ -81,7 +46,6 @@ def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]: ...@@ -81,7 +46,6 @@ def generate_silu_mul_fp8_inputs() -> dict[str, tuple[Any, ...]]:
return inputs return inputs
@silu_mul_fp8.register_config_picker # type: ignore[misc]
def pick_silu_mul_fp8_config( def pick_silu_mul_fp8_config(
args: tuple[Any, ...], config_keys: list[str] args: tuple[Any, ...], config_keys: list[str]
) -> str | None: ) -> str | None:
...@@ -128,6 +92,41 @@ def pick_silu_mul_fp8_config( ...@@ -128,6 +92,41 @@ def pick_silu_mul_fp8_config(
return f"intermediate_{best_isize}_numtokens_{best_ntokens}" return f"intermediate_{best_isize}_numtokens_{best_ntokens}"
@register_kernel(
    config_picker=pick_silu_mul_fp8_config,
    input_generator=generate_silu_mul_fp8_inputs,
)
def silu_mul_fp8(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Fused SiLU-and-multiply with scaling to ``torch.float8_e4m3fn``.

    The last dimension of ``input`` is split into two equal halves ``a`` and
    ``b``. The result is ``silu(a) * b``, converted to float32, multiplied by
    ``1 / scale`` and cast to float8.

    Args:
        input: Tensor whose last dimension holds the two halves back to back.
        scale: Tensor with exactly one element.

    Returns:
        A float8 tensor shaped like ``input`` with the last dimension halved.
    """
    in_shape = input.shape
    last_dim = hl.specialize(in_shape[-1])
    half = last_dim // 2
    result_shape = in_shape[:-1] + (half,)

    # Collapse all leading dimensions so the kernel tiles over a 2-D problem.
    flat = input.view(-1, in_shape[-1])
    rows = flat.shape[0]

    # TODO(gmagogsfm): Support for more float8 subtypes (e4m3fnuz, e5m2) coming
    out = torch.empty(
        (rows, half), device=input.device, dtype=torch.float8_e4m3fn
    )

    first_half = flat[:, :half]
    second_half = flat[:, half:]

    assert scale.numel() == 1, "Scale must be a scalar Tensor"

    for tile_m, tile_n in hl.tile([rows, half]):
        gate = first_half[tile_m, tile_n]
        activated = torch.nn.functional.silu(gate)
        multiplier = second_half[tile_m, tile_n]
        product = activated * multiplier
        product_f32 = product.to(torch.float32)
        scale_val = hl.load(scale, [0])
        inv_scale = 1.0 / scale_val
        scaled = product_f32 * inv_scale
        out[tile_m, tile_n] = scaled.to(out.dtype)

    return out.view(result_shape)
def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: def silu_mul_fp8_baseline(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
output_shape = input.shape[:-1] + (input.shape[-1] // 2,) output_shape = input.shape[:-1] + (input.shape[-1] // 2,)
out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device) out = torch.empty(output_shape, dtype=torch.float8_e4m3fn, device=input.device)
......
...@@ -37,7 +37,7 @@ Key Classes ...@@ -37,7 +37,7 @@ Key Classes
""" """
from collections.abc import Callable from collections.abc import Callable
from typing import Any, cast, overload from typing import Any, cast
import torch import torch
from torch.library import Library from torch.library import Library
...@@ -95,7 +95,7 @@ def validate_helion_settings( ...@@ -95,7 +95,7 @@ def validate_helion_settings(
raise ValueError( raise ValueError(
f"HelionKernelWrapper for '{op_name}' uses a custom autotuner via " f"HelionKernelWrapper for '{op_name}' uses a custom autotuner via "
f"config picker. Remove 'autotuner_fn' from helion_settings and use " f"config picker. Remove 'autotuner_fn' from helion_settings and use "
f"@{op_name}.register_config_picker instead." f"register_kernel(..., config_picker=...) instead."
) )
if settings_dict.get("static_shapes") is True: if settings_dict.get("static_shapes") is True:
...@@ -169,7 +169,7 @@ class ConfiguredHelionKernel: ...@@ -169,7 +169,7 @@ class ConfiguredHelionKernel:
if self.config_picker is None: if self.config_picker is None:
raise RuntimeError( raise RuntimeError(
f"No config picker registered for kernel '{self.op_name}'. " f"No config picker registered for kernel '{self.op_name}'. "
f"Use @{self.op_name}.register_config_picker to register one." f"A config_picker must be provided to register_kernel()."
) )
# After None check, config_picker is guaranteed to be non-None # After None check, config_picker is guaranteed to be non-None
...@@ -215,7 +215,7 @@ class ConfiguredHelionKernel: ...@@ -215,7 +215,7 @@ class ConfiguredHelionKernel:
from vllm.kernels.helion.utils import get_canonical_gpu_name from vllm.kernels.helion.utils import get_canonical_gpu_name
self.platform = get_canonical_gpu_name() self.platform = get_canonical_gpu_name()
config_manager = ConfigManager.get_instance() config_manager = ConfigManager()
self.configs = config_manager.get_platform_configs(self.op_name, self.platform) self.configs = config_manager.get_platform_configs(self.op_name, self.platform)
if not self.configs: if not self.configs:
...@@ -253,7 +253,9 @@ class HelionKernelWrapper: ...@@ -253,7 +253,9 @@ class HelionKernelWrapper:
raw_kernel_func: Callable, raw_kernel_func: Callable,
op_name: str, op_name: str,
fake_impl: Callable, fake_impl: Callable,
config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
helion_settings: "helion.Settings | None" = None, helion_settings: "helion.Settings | None" = None,
input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None,
): ):
# Validate helion_settings doesn't conflict with our custom autotuner # Validate helion_settings doesn't conflict with our custom autotuner
validate_helion_settings(helion_settings, op_name) validate_helion_settings(helion_settings, op_name)
...@@ -262,23 +264,43 @@ class HelionKernelWrapper: ...@@ -262,23 +264,43 @@ class HelionKernelWrapper:
self.op_name = op_name self.op_name = op_name
self._fake_impl = fake_impl self._fake_impl = fake_impl
self.helion_settings = helion_settings self.helion_settings = helion_settings
self._config_picker: ( self._config_picker = config_picker
Callable[[tuple[Any, ...], list[str]], str | None] | None self._input_generator = input_generator
) = None
self._configured_kernel: ConfiguredHelionKernel | None = None self._configured_kernel: ConfiguredHelionKernel | None = None
self._input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None # TODO(@gmagogsfm): Remove this disable flag once integrated with vLLM IR,
# which handles op enablement/disablement.
self._disabled = False
self._disabled_reason: str | None = None
try:
if not _HOP_AVAILABLE:
self._get_or_register_custom_op()
else:
self.get_configured_op()
except ValueError as e:
self._disabled = True
self._disabled_reason = str(e)
logger.warning(
"Helion kernel '%s' is disabled: %s",
op_name,
self._disabled_reason,
)
def __call__(self, *args, **kwargs):
    """Run the kernel, choosing the dispatch path for the current context.

    Three paths exist: the torch custom-op fallback (no HOP support), HOP
    tracing (a proxy mode is active), and plain eager execution of the
    configured kernel.

    Raises:
        RuntimeError: If the kernel was marked disabled at construction.
    """
    if self._disabled:
        raise RuntimeError(
            f"Helion kernel '{self.op_name}' is disabled: {self._disabled_reason}"
        )

    # CustomOp fallback: on older PyTorch lacking HOP/EffectType support the
    # kernel is dispatched as a torch custom op, looked up by name under
    # torch.ops.vllm_helion (presumably registered during __init__ via
    # _get_or_register_custom_op — confirm against that method).
    if not _HOP_AVAILABLE:
        op = getattr(torch.ops.vllm_helion, self.op_name)
        return op(*args, **kwargs)

    # An enabled wrapper on the HOP path is expected to have built its
    # configured kernel already; a None here is an internal error.
    assert self._configured_kernel is not None, (
        f"Kernel '{self.op_name}' was not initialized. "
        "Please open an issue on GitHub."
    )

    # HOP tracing: record a HigherOrderOp in the FX graph
    if get_proxy_mode() is not None:
        return self._call_via_hop(args, kwargs)

    # Eager: run the configured kernel directly
    return self._configured_kernel(*args, **kwargs)
def _call_via_hop( def _call_via_hop(
self, self,
...@@ -346,42 +368,11 @@ class HelionKernelWrapper: ...@@ -346,42 +368,11 @@ class HelionKernelWrapper:
constant_args[name] = val constant_args[name] = val
return constant_args, tensor_args return constant_args, tensor_args
def register_config_picker(
self, picker_func: Callable[[tuple[Any, ...], list[str]], str | None]
) -> Callable[[tuple[Any, ...], list[str]], str | None]:
self._config_picker = picker_func
return picker_func
def register_input_generator(
    self, generator_func: Callable[[], dict[str, tuple[Any, ...]]]
) -> Callable[[], dict[str, tuple[Any, ...]]]:
    """Store the callable that produces inputs for autotuning and benchmarking.

    Args:
        generator_func: Zero-argument function returning ``dict[str, tuple]``.
            Each key is a configuration identifier (e.g. ``"4096"`` or
            ``"hidden_4096"``); each value is the tuple of arguments to
            pass to the kernel.

    Returns:
        ``generator_func`` itself, which allows decorator usage.

    Example:
        @kernel_wrapper.register_input_generator
        def generate_inputs():
            return {
                "4096": (torch.randn(4096, device="cuda"), 0.5),
                "8192": (torch.randn(8192, device="cuda"), 0.5),
            }
    """
    self._input_generator = generator_func
    return generator_func
def get_inputs(self) -> dict[str, tuple[Any, ...]]:
    """Return the inputs produced by the registered input generator.

    Returns:
        The generator's result: a mapping from configuration identifier to
        the tuple of arguments to pass to the kernel.

    Raises:
        NotImplementedError: If no input generator was supplied.
    """
    if self._input_generator is None:
        raise NotImplementedError(
            f"No input generator registered for kernel '{self.op_name}'. "
            "Use register_kernel(..., input_generator=...) to register one."
        )
    return self._input_generator()
...@@ -401,11 +392,10 @@ class HelionKernelWrapper: ...@@ -401,11 +392,10 @@ class HelionKernelWrapper:
return autotune_kernel.autotune(inputs) return autotune_kernel.autotune(inputs)
def get_configured_op(self) -> ConfiguredHelionKernel: def get_configured_op(self) -> ConfiguredHelionKernel:
assert self._config_picker is not None, ( if self._disabled:
f"No config picker registered for kernel '{self.op_name}'. " raise RuntimeError(
f"Use @{self.op_name}.register_config_picker to register one." f"Helion kernel '{self.op_name}' is disabled: {self._disabled_reason}"
) )
if self._configured_kernel is None: if self._configured_kernel is None:
self._configured_kernel = ConfiguredHelionKernel( self._configured_kernel = ConfiguredHelionKernel(
op_name=self.op_name, op_name=self.op_name,
...@@ -413,7 +403,6 @@ class HelionKernelWrapper: ...@@ -413,7 +403,6 @@ class HelionKernelWrapper:
raw_kernel_func=self.raw_kernel_func, raw_kernel_func=self.raw_kernel_func,
helion_settings=self.helion_settings, helion_settings=self.helion_settings,
) )
return self._configured_kernel return self._configured_kernel
def _get_or_register_custom_op(self) -> Any: def _get_or_register_custom_op(self) -> Any:
...@@ -466,45 +455,51 @@ def infer_fake_impl( ...@@ -466,45 +455,51 @@ def infer_fake_impl(
return helion_fake_kernel return helion_fake_kernel
# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
# wrapper = register_kernel(func) # Should return HelionKernelWrapper
# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@overload
def register_kernel( def register_kernel(
op_name_or_func: Callable, op_name: str | None = None,
*, *,
config_picker: Callable[[tuple[Any, ...], list[str]], str | None],
fake_impl: Callable | None = None, fake_impl: Callable | None = None,
helion_settings: "helion.Settings | None" = None, helion_settings: "helion.Settings | None" = None,
) -> HelionKernelWrapper: ... input_generator: Callable[[], dict[str, tuple[Any, ...]]] | None = None,
) -> Callable[[Callable], HelionKernelWrapper]:
"""Register a Helion kernel with pre-tuned config selection.
@overload
def register_kernel( Wraps the kernel function in a HelionKernelWrapper that eagerly builds
op_name_or_func: str | None = None, the configured kernel and (on older PyTorch) registers a custom op.
*,
fake_impl: Callable | None = None, Args:
helion_settings: "helion.Settings | None" = None, config_picker: Required. Function with signature
) -> Callable[[Callable], HelionKernelWrapper]: ... ``(args: tuple, config_keys: list[str]) -> str | None``
that picks the best config key from available options.
Return ``None`` to fall back to ``"default"``.
def register_kernel(
op_name_or_func: str | Callable | None = None, Example::
*,
fake_impl: Callable | None = None, def pick_config(args, config_keys):
helion_settings: "helion.Settings | None" = None, x = args[0]
) -> HelionKernelWrapper | Callable[[Callable], HelionKernelWrapper]: hidden_size = x.shape[-1]
""" batch_size = x.shape[0]
Decorator to register a Helion kernel function as a HelionKernelWrapper. for key in config_keys:
if key == f"hiddensize_{hidden_size}_batchsize_{batch_size}":
Wraps the raw kernel function in a HelionKernelWrapper and registers it return key
in the global kernel registry. Auto-generates fake_impl if not provided. return "default" if "default" in config_keys else None
input_generator: Optional. Function that returns
``dict[str, tuple]`` where each key is a configuration
identifier (e.g. ``"4096"``, ``"hidden_4096"``) and each
value is a tuple of arguments to pass to the kernel.
Example::
def generate_inputs():
return {
"4096": (torch.randn(4096, device="cuda"), 0.5),
"8192": (torch.randn(8192, device="cuda"), 0.5),
}
""" """
def decorator(kernel_func: Callable) -> HelionKernelWrapper: def decorator(kernel_func: Callable) -> HelionKernelWrapper:
op_name = op_name_or_func if isinstance(op_name_or_func, str) else None
final_op_name = op_name if op_name else kernel_func.__name__ final_op_name = op_name if op_name else kernel_func.__name__
if final_op_name in _REGISTERED_KERNELS: if final_op_name in _REGISTERED_KERNELS:
...@@ -525,7 +520,9 @@ def register_kernel( ...@@ -525,7 +520,9 @@ def register_kernel(
raw_kernel_func=kernel_func, raw_kernel_func=kernel_func,
op_name=final_op_name, op_name=final_op_name,
fake_impl=final_fake_impl, fake_impl=final_fake_impl,
config_picker=config_picker,
helion_settings=helion_settings, helion_settings=helion_settings,
input_generator=input_generator,
) )
_REGISTERED_KERNELS[final_op_name] = kernel_wrapper _REGISTERED_KERNELS[final_op_name] = kernel_wrapper
...@@ -537,9 +534,4 @@ def register_kernel( ...@@ -537,9 +534,4 @@ def register_kernel(
return kernel_wrapper return kernel_wrapper
if callable(op_name_or_func) and not isinstance(op_name_or_func, str): return decorator
# Bare decorator usage: @register_kernel
return decorator(op_name_or_func)
else:
# Decorator with arguments: @register_kernel(...)
return decorator
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment