Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1ac66942
Unverified
Commit
1ac66942
authored
Mar 31, 2026
by
zhangyiming
Committed by
GitHub
Mar 31, 2026
Browse files
[OOT] Add OOT support for linear kernel. (#37989)
Signed-off-by:
menogrey
<
1299267905@qq.com
>
parent
6cc7abdc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
0 deletions
+76
-0
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
...s/kernels/quantization/test_scaled_mm_kernel_selection.py
+42
-0
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+34
-0
No files found.
tests/kernels/quantization/test_scaled_mm_kernel_selection.py
View file @
1ac66942
...
...
@@ -7,15 +7,21 @@ Run `pytest tests/kernels/quantization/test_scaled_mm_kernel_selection.py`.
import
inspect
from
abc
import
ABC
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.model_executor.kernels.linear
import
(
AiterInt8ScaledMMLinearKernel
,
CPUInt8ScaledMMLinearKernel
,
Int8ScaledMMLinearKernel
,
Int8ScaledMMLinearLayerConfig
,
ScaledMMLinearKernel
,
init_int8_linear_kernel
,
register_linear_kernel
,
)
from
vllm.platforms
import
PlatformEnum
pytestmark
=
pytest
.
mark
.
cpu_test
...
...
@@ -85,3 +91,39 @@ def test_cpu_kernel_accepts_all_configs():
assert
can_impl
,
(
f
"CPUInt8ScaledMMLinearKernel should accept config
{
config
}
:
{
reason
}
"
)
class
OOTInt8ScaledMMLinearKernel
(
Int8ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
Int8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
pass
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
pass
@
patch
(
"vllm.model_executor.kernels.linear.current_platform"
)
def
test_register_oot_linear_kernel
(
platform_mock
):
"""Test that the linear kernel registration works correctly."""
platform_mock
.
_enum
=
PlatformEnum
.
OOT
register_linear_kernel
(
OOTInt8ScaledMMLinearKernel
,
PlatformEnum
.
OOT
,
"int8"
)
kernel
=
init_int8_linear_kernel
(
True
,
True
,
True
,
"module"
)
assert
isinstance
(
kernel
,
OOTInt8ScaledMMLinearKernel
),
(
"init_int8_linear_kernel should return an instance of the registered kernel"
)
vllm/model_executor/kernels/linear/__init__.py
View file @
1ac66942
...
...
@@ -367,10 +367,44 @@ def choose_mp_linear_kernel(
)
def
register_linear_kernel
(
kernel_class
:
type
,
platform
:
PlatformEnum
,
kernel_type
:
str
=
"mp"
,
)
->
None
:
"""
Register a new linear kernel class to be considered in kernel selection.
Args:
kernel_class (type): The kernel class to register.
platform (PlatformEnum): The platform for which this kernel is applicable.
kernel_type (str): The type of the kernel, either "mp", "int8", or "fp8".
Defaults to "mp".
Raises:
ValueError: If the kernel_type is not recognized.
"""
if
kernel_type
==
"mp"
:
if
platform
not
in
_POSSIBLE_KERNELS
:
_POSSIBLE_KERNELS
[
platform
]
=
[]
_POSSIBLE_KERNELS
[
platform
].
append
(
kernel_class
)
elif
kernel_type
==
"int8"
:
if
platform
not
in
_POSSIBLE_INT8_KERNELS
:
_POSSIBLE_INT8_KERNELS
[
platform
]
=
[]
_POSSIBLE_INT8_KERNELS
[
platform
].
append
(
kernel_class
)
elif
kernel_type
==
"fp8"
:
if
platform
not
in
_POSSIBLE_FP8_KERNELS
:
_POSSIBLE_FP8_KERNELS
[
platform
]
=
[]
_POSSIBLE_FP8_KERNELS
[
platform
].
append
(
kernel_class
)
else
:
raise
ValueError
(
f
"Unrecognized kernel type:
{
kernel_type
}
"
)
__all__
=
[
"init_fp8_linear_kernel"
,
"init_int8_linear_kernel"
,
"choose_mp_linear_kernel"
,
"register_linear_kernel"
,
"FP8ScaledMMLinearKernel"
,
"Int8ScaledMMLinearKernel"
,
"ScaledMMLinearKernel"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment