Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d5c41db3
Unverified
Commit
d5c41db3
authored
Jan 30, 2026
by
Yanan Cao
Committed by
GitHub
Jan 31, 2026
Browse files
[Kernel] [Helion] [3/N] Helion kernel registry (#33203)
Signed-off-by:
Yanan Cao
<
gmagogsfm@gmail.com
>
parent
1618e254
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
319 additions
and
3 deletions
+319
-3
tests/kernels/helion/test_register.py
tests/kernels/helion/test_register.py
+191
-0
vllm/kernels/helion/__init__.py
vllm/kernels/helion/__init__.py
+6
-0
vllm/kernels/helion/register.py
vllm/kernels/helion/register.py
+122
-3
No files found.
tests/kernels/helion/test_register.py
View file @
d5c41db3
...
...
@@ -27,6 +27,9 @@ from vllm.kernels.helion.config_manager import ConfigManager
from
vllm.kernels.helion.register
import
(
ConfiguredHelionKernel
,
HelionKernelWrapper
,
get_kernel_by_name
,
get_registered_kernels
,
register_kernel
,
validate_helion_settings
,
)
...
...
@@ -545,3 +548,191 @@ class TestHelionKernelWrapper:
assert
result
is
new_op
# Check that op_func is the decorated kernel, not ConfiguredHelionKernel
assert
mock_register
.
call_args
[
1
][
"op_func"
]
is
mock_decorated
class
TestKernelRegistry
:
"""Test suite for kernel registry functionality."""
def
setup_method
(
self
):
"""Clear the registry before each test."""
from
vllm.kernels.helion.register
import
_REGISTERED_KERNELS
_REGISTERED_KERNELS
.
clear
()
def
test_get_registered_kernels_returns_copy
(
self
):
"""Test get_registered_kernels returns copy of registry."""
result1
=
get_registered_kernels
()
result2
=
get_registered_kernels
()
# Should be separate objects
assert
result1
is
not
result2
# Should have same content
assert
result1
==
result2
def
test_get_kernel_by_name_returns_kernel
(
self
):
"""Test get_kernel_by_name returns registered kernel."""
wrapper
=
HelionKernelWrapper
(
raw_kernel_func
=
Mock
(),
op_name
=
"test_kernel"
,
fake_impl
=
Mock
(),
)
from
vllm.kernels.helion.register
import
_REGISTERED_KERNELS
_REGISTERED_KERNELS
[
"test_kernel"
]
=
wrapper
result
=
get_kernel_by_name
(
"test_kernel"
)
assert
result
is
wrapper
def
test_get_kernel_by_name_returns_none_for_missing
(
self
):
"""Test get_kernel_by_name returns None for missing kernel."""
result
=
get_kernel_by_name
(
"nonexistent"
)
assert
result
is
None
def
test_register_kernel_auto_generates_fake_impl
(
self
):
"""Test register_kernel auto-generates fake_impl when not provided."""
with
patch
(
"vllm.kernels.helion.register.infer_fake_impl"
)
as
mock_infer
:
mock_fake
=
Mock
()
mock_infer
.
return_value
=
mock_fake
def
original_kernel
(
x
):
return
x
wrapper
=
register_kernel
(
original_kernel
)
mock_infer
.
assert_called_once_with
(
original_kernel
,
None
)
assert
wrapper
.
_fake_impl
is
mock_fake
def
test_register_kernel_creates_wrapper
(
self
):
"""Test register_kernel creates HelionKernelWrapper."""
def
test_kernel
(
x
):
return
x
result
=
register_kernel
(
"test_name"
)(
test_kernel
)
assert
isinstance
(
result
,
HelionKernelWrapper
)
assert
result
.
op_name
==
"test_name"
assert
result
.
raw_kernel_func
is
test_kernel
def
test_register_kernel_auto_detects_name
(
self
):
"""Test register_kernel uses function name when no name provided."""
@
register_kernel
def
my_test_kernel
(
x
):
return
x
assert
my_test_kernel
.
op_name
==
"my_test_kernel"
def
test_register_kernel_registers_in_global_registry
(
self
):
"""Test register_kernel adds wrapper to global registry."""
@
register_kernel
def
test_kernel
(
x
):
return
x
registered_kernels
=
get_registered_kernels
()
assert
"test_kernel"
in
registered_kernels
assert
registered_kernels
[
"test_kernel"
]
is
test_kernel
def
test_register_kernel_passes_helion_settings
(
self
):
"""Test register_kernel passes helion_settings to wrapper."""
mock_settings
=
Mock
()
mock_settings
.
to_dict
.
return_value
=
{
"debug"
:
True
}
@
register_kernel
(
"test_name"
,
helion_settings
=
mock_settings
)
def
test_kernel
(
x
):
return
x
assert
test_kernel
.
helion_settings
is
mock_settings
def
test_register_kernel_supports_decorator_syntax
(
self
):
"""Test register_kernel works with decorator arguments."""
mock_fake
=
Mock
()
wrapper
=
register_kernel
(
"custom_name"
,
fake_impl
=
mock_fake
)
def
test_kernel
(
x
):
return
x
result
=
wrapper
(
test_kernel
)
assert
result
.
op_name
==
"custom_name"
assert
result
.
_fake_impl
is
mock_fake
def
test_register_kernel_bare_decorator
(
self
):
"""Test register_kernel works as bare decorator."""
@
register_kernel
def
test_kernel
(
x
):
return
x
assert
isinstance
(
test_kernel
,
HelionKernelWrapper
)
assert
test_kernel
.
op_name
==
"test_kernel"
def
test_registered_wrapper_can_register_config_picker
(
self
):
"""Test that registered wrapper can register config picker."""
@
register_kernel
def
test_kernel
(
x
):
return
x
def
my_picker
(
args
,
config_keys
):
return
"default"
result
=
test_kernel
.
register_config_picker
(
my_picker
)
assert
result
is
my_picker
assert
test_kernel
.
_config_picker
is
my_picker
def
test_register_kernel_raises_on_duplicate_registration
(
self
):
"""Test register_kernel raises error on duplicate names."""
@
register_kernel
(
"duplicate_name"
)
def
kernel1
(
x
):
return
x
with
pytest
.
raises
(
ValueError
,
match
=
"already registered"
):
@
register_kernel
(
"duplicate_name"
)
def
kernel2
(
x
):
return
x
def
test_register_kernel_rejects_autotuner_fn_in_settings
(
self
):
"""Test register_kernel rejects conflicting autotuner_fn."""
mock_settings
=
Mock
()
mock_settings
.
to_dict
.
return_value
=
{
"autotuner_fn"
:
Mock
()}
with
pytest
.
raises
(
ValueError
,
match
=
"uses a custom autotuner"
):
@
register_kernel
(
"test"
,
helion_settings
=
mock_settings
)
def
test_kernel
(
x
):
return
x
def
test_register_kernel_warns_with_static_shapes_true
(
self
):
"""Test register_kernel warns when static_shapes=True."""
mock_settings
=
Mock
()
mock_settings
.
to_dict
.
return_value
=
{
"static_shapes"
:
True
}
with
patch
(
"vllm.kernels.helion.register.logger"
)
as
mock_logger
:
@
register_kernel
(
"test"
,
helion_settings
=
mock_settings
)
def
test_kernel
(
x
):
return
x
mock_logger
.
warning
.
assert_called_once
()
assert
"static_shapes=True"
in
mock_logger
.
warning
.
call_args
[
0
][
0
]
def
test_register_kernel_no_warning_with_static_shapes_false
(
self
):
"""Test register_kernel doesn't warn with static_shapes=False."""
mock_settings
=
Mock
()
mock_settings
.
to_dict
.
return_value
=
{
"static_shapes"
:
False
}
with
patch
(
"vllm.kernels.helion.register.logger"
)
as
mock_logger
:
@
register_kernel
(
"test"
,
helion_settings
=
mock_settings
)
def
test_kernel
(
x
):
return
x
# Should not call warning
mock_logger
.
warning
.
assert_not_called
()
vllm/kernels/helion/__init__.py
View file @
d5c41db3
...
...
@@ -9,6 +9,9 @@ from vllm.kernels.helion.config_manager import (
from
vllm.kernels.helion.register
import
(
ConfiguredHelionKernel
,
HelionKernelWrapper
,
get_kernel_by_name
,
get_registered_kernels
,
register_kernel
,
vllm_helion_lib
,
)
from
vllm.kernels.helion.utils
import
canonicalize_gpu_name
,
get_canonical_gpu_name
...
...
@@ -20,6 +23,9 @@ __all__ = [
# Kernel registration
"ConfiguredHelionKernel"
,
"HelionKernelWrapper"
,
"get_kernel_by_name"
,
"get_registered_kernels"
,
"register_kernel"
,
"vllm_helion_lib"
,
# Utilities
"canonicalize_gpu_name"
,
...
...
vllm/kernels/helion/register.py
View file @
d5c41db3
...
...
@@ -37,7 +37,7 @@ Key Classes
"""
from
collections.abc
import
Callable
from
typing
import
Any
from
typing
import
Any
,
cast
,
overload
import
torch
from
torch.library
import
Library
...
...
@@ -114,7 +114,7 @@ class ConfiguredHelionKernel:
def
__init__
(
self
,
op_name
:
str
,
config_picker
:
Callable
[[
tuple
[
Any
,
...],
list
[
str
]],
str
|
None
],
config_picker
:
Callable
[[
tuple
[
Any
,
...],
list
[
str
]],
str
|
None
]
|
None
,
raw_kernel_func
:
Callable
,
helion_settings
:
"helion.Settings | None"
=
None
,
):
...
...
@@ -140,9 +140,16 @@ class ConfiguredHelionKernel:
f
"Use @
{
self
.
op_name
}
.register_config_picker to register one."
)
# After None check, config_picker is guaranteed to be non-None
assert
self
.
config_picker
is
not
None
def
key_computer
(
*
args
):
config_keys
=
list
(
self
.
configs
.
keys
())
selected_key
=
self
.
config_picker
(
args
,
config_keys
)
# Cast is safe because we checked for None above
config_picker
=
cast
(
Callable
[[
tuple
[
Any
,
...],
list
[
str
]],
str
|
None
],
self
.
config_picker
)
selected_key
=
config_picker
(
args
,
config_keys
)
if
selected_key
:
return
selected_key
return
"default"
if
"default"
in
self
.
configs
else
None
...
...
@@ -272,3 +279,115 @@ class HelionKernelWrapper:
target_lib
=
vllm_helion_lib
,
)
return
getattr
(
torch
.
ops
.
vllm_helion
,
self
.
op_name
)
# Global registry for tracking all registered HelionKernelWrapper instances
_REGISTERED_KERNELS
:
dict
[
str
,
HelionKernelWrapper
]
=
{}
def
get_registered_kernels
()
->
dict
[
str
,
HelionKernelWrapper
]:
return
_REGISTERED_KERNELS
.
copy
()
def
get_kernel_by_name
(
kernel_name
:
str
)
->
HelionKernelWrapper
|
None
:
return
_REGISTERED_KERNELS
.
get
(
kernel_name
)
def
infer_fake_impl
(
kernel_func
:
Callable
,
helion_settings
:
"helion.Settings | None"
=
None
,
)
->
Callable
:
def
helion_fake_kernel
(
*
args
,
**
kwargs
):
kernel_kwargs
=
{}
if
helion_settings
:
kernel_kwargs
.
update
(
helion_settings
.
to_dict
())
temp_decorated_kernel
=
helion
.
kernel
(
**
kernel_kwargs
)(
kernel_func
)
# Bind with args to get config_spec, then get a valid default config
bound
=
temp_decorated_kernel
.
bind
(
args
)
default_config
=
bound
.
config_spec
.
default_config
()
compiled_runner
=
bound
.
compile_config
(
default_config
)
return
compiled_runner
(
*
args
,
**
kwargs
,
_launcher
=
lambda
*
a
,
**
kw
:
None
)
return
helion_fake_kernel
# Overloads are necessary for proper mypy type inference.
# Without overloads, the union return type HelionKernelWrapper | Callable[...]
# causes mypy to complain about missing attributes when tests do:
# wrapper = register_kernel(func) # Should return HelionKernelWrapper
# wrapper._fake_impl # mypy error: "Callable has no attribute _fake_impl"
# The overloads tell mypy the exact return type based on the argument pattern.
@
overload
def
register_kernel
(
op_name_or_func
:
Callable
,
*
,
fake_impl
:
Callable
|
None
=
None
,
helion_settings
:
"helion.Settings | None"
=
None
,
)
->
HelionKernelWrapper
:
...
@
overload
def
register_kernel
(
op_name_or_func
:
str
|
None
=
None
,
*
,
fake_impl
:
Callable
|
None
=
None
,
helion_settings
:
"helion.Settings | None"
=
None
,
)
->
Callable
[[
Callable
],
HelionKernelWrapper
]:
...
def
register_kernel
(
op_name_or_func
:
str
|
Callable
|
None
=
None
,
*
,
fake_impl
:
Callable
|
None
=
None
,
helion_settings
:
"helion.Settings | None"
=
None
,
)
->
HelionKernelWrapper
|
Callable
[[
Callable
],
HelionKernelWrapper
]:
"""
Decorator to register a Helion kernel function as a HelionKernelWrapper.
Wraps the raw kernel function in a HelionKernelWrapper and registers it
in the global kernel registry. Auto-generates fake_impl if not provided.
"""
def
decorator
(
kernel_func
:
Callable
)
->
HelionKernelWrapper
:
op_name
=
op_name_or_func
if
isinstance
(
op_name_or_func
,
str
)
else
None
final_op_name
=
op_name
if
op_name
else
kernel_func
.
__name__
if
final_op_name
in
_REGISTERED_KERNELS
:
raise
ValueError
(
f
"Helion kernel '
{
final_op_name
}
' is already registered. "
f
"Use a different op_name or check for duplicate registrations."
)
final_fake_impl
=
fake_impl
if
final_fake_impl
is
None
:
final_fake_impl
=
infer_fake_impl
(
kernel_func
,
helion_settings
)
logger
.
debug
(
"Auto-generated fake_impl for Helion kernel '%s'"
,
kernel_func
.
__name__
,
)
kernel_wrapper
=
HelionKernelWrapper
(
raw_kernel_func
=
kernel_func
,
op_name
=
final_op_name
,
fake_impl
=
final_fake_impl
,
helion_settings
=
helion_settings
,
)
_REGISTERED_KERNELS
[
final_op_name
]
=
kernel_wrapper
logger
.
info
(
"Registered Helion kernel '%s' as HelionKernelWrapper"
,
kernel_func
.
__name__
,
)
return
kernel_wrapper
if
callable
(
op_name_or_func
)
and
not
isinstance
(
op_name_or_func
,
str
):
# Bare decorator usage: @register_kernel
return
decorator
(
op_name_or_func
)
else
:
# Decorator with arguments: @register_kernel(...)
return
decorator
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment