Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7e8977fc
Unverified
Commit
7e8977fc
authored
Jun 20, 2025
by
Chendi.Xue
Committed by
GitHub
Jun 20, 2025
Browse files
[custom_op][vllm-plugin] update custom_op class to use op_registry (#19164)
Signed-off-by:
Chendi.Xue
<
chendi.xue@intel.com
>
parent
f1e840e8
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
120 additions
and
6 deletions
+120
-6
tests/plugins/vllm_add_dummy_platform/setup.py
tests/plugins/vllm_add_dummy_platform/setup.py
+3
-1
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
...lm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
+4
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
...atform/vllm_add_dummy_platform/dummy_attention_backend.py
+3
-2
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py
...ummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py
+20
-0
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
..._dummy_platform/vllm_add_dummy_platform/dummy_platform.py
+20
-3
tests/plugins_tests/test_platform_plugins.py
tests/plugins_tests/test_platform_plugins.py
+14
-0
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+56
-0
No files found.
tests/plugins/vllm_add_dummy_platform/setup.py
View file @
7e8977fc
...
@@ -10,5 +10,7 @@ setup(
...
@@ -10,5 +10,7 @@ setup(
entry_points
=
{
entry_points
=
{
'vllm.platform_plugins'
:
[
'vllm.platform_plugins'
:
[
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"
# noqa
"dummy_platform_plugin = vllm_add_dummy_platform:dummy_platform_plugin"
# noqa
]
],
"vllm.general_plugins"
:
[
"dummy_custom_ops = vllm_add_dummy_platform:register_ops"
],
})
})
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/__init__.py
View file @
7e8977fc
...
@@ -6,3 +6,7 @@ from typing import Optional
...
@@ -6,3 +6,7 @@ from typing import Optional
def
dummy_platform_plugin
()
->
Optional
[
str
]:
def
dummy_platform_plugin
()
->
Optional
[
str
]:
return
"vllm_add_dummy_platform.dummy_platform.DummyPlatform"
return
"vllm_add_dummy_platform.dummy_platform.DummyPlatform"
def
register_ops
():
import
vllm_add_dummy_platform.dummy_custom_ops
# noqa
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_attention_backend.py
View file @
7e8977fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
vllm.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.attention.backends.placeholder_attn
import
(
PlaceholderAttentionBackend
)
class
DummyAttentionBackend
(
F
la
sh
AttentionBackend
):
class
DummyAttentionBackend
(
P
la
ceholder
AttentionBackend
):
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
...
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_custom_ops.py
0 → 100644
View file @
7e8977fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
# Register CustomRotaryEmbedding to CustomOP.
@
RotaryEmbedding
.
register_oot
class
DummyRotaryEmbedding
(
RotaryEmbedding
):
"""Original rotary positional embedding."""
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
addition_config
=
True
def
forward_oot
(
self
,
*
args
,
**
kwargs
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
return
super
().
forward_oot
(
*
args
,
**
kwargs
)
tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py
View file @
7e8977fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
TYPE_CHECKING
from
vllm.platforms.
cuda
import
Cuda
Platform
from
vllm.platforms.
interface
import
Platform
,
Platform
Enum
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
else
:
VllmConfig
=
None
from
vllm
import
envs
class
DummyPlatform
(
CudaPlatform
):
class
DummyPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OOT
device_name
=
"DummyDevice"
device_name
=
"DummyDevice"
device_type
:
str
=
"privateuseone"
dispatch_key
:
str
=
"PrivateUse1"
@
classmethod
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
if
envs
.
VLLM_USE_V1
:
compilation_config
=
vllm_config
.
compilation_config
# Activate custom ops for v1.
compilation_config
.
custom_ops
=
[
"all"
]
def
get_attn_backend_cls
(
self
,
backend_name
,
head_size
,
dtype
,
def
get_attn_backend_cls
(
self
,
backend_name
,
head_size
,
dtype
,
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
):
kv_cache_dtype
,
block_size
,
use_v1
,
use_mla
):
...
...
tests/plugins_tests/test_platform_plugins.py
View file @
7e8977fc
...
@@ -5,6 +5,7 @@ import pytest
...
@@ -5,6 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.attention.selector
import
get_attn_backend
from
vllm.attention.selector
import
get_attn_backend
from
vllm.plugins
import
load_general_plugins
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
from
vllm.utils
import
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
...
@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
...
@@ -32,3 +33,16 @@ def test_oot_attention_backend(monkeypatch: pytest.MonkeyPatch):
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
)
m
.
setenv
(
STR_BACKEND_ENV_VAR
,
STR_INVALID_VAL
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"auto"
,
16
,
False
)
backend
=
get_attn_backend
(
16
,
torch
.
float16
,
"auto"
,
16
,
False
)
assert
backend
.
get_name
()
==
"Dummy_Backend"
assert
backend
.
get_name
()
==
"Dummy_Backend"
def
test_oot_custom_op
(
monkeypatch
:
pytest
.
MonkeyPatch
):
# simulate workload by running an example
load_general_plugins
()
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
layer
=
RotaryEmbedding
(
16
,
16
,
16
,
16
,
True
,
torch
.
float16
)
assert
layer
.
__class__
.
__name__
==
"DummyRotaryEmbedding"
,
(
f
"Expected DummyRotaryEmbedding, got
{
layer
.
__class__
.
__name__
}
, "
"possibly because the custom op is not registered correctly."
)
assert
hasattr
(
layer
,
"addition_config"
),
(
"Expected DummyRotaryEmbedding to have an 'addition_config' attribute, "
"which is set by the custom op."
)
vllm/model_executor/custom_op.py
View file @
7e8977fc
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
...
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
...
@@ -16,6 +18,24 @@ class CustomOp(nn.Module):
Dispatches the forward method to the appropriate backend.
Dispatches the forward method to the appropriate backend.
"""
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
try
:
op_name
=
cls
.
__name__
except
AttributeError
:
raise
TypeError
(
f
"Cannot instantiate '
{
cls
.
__name__
}
': its 'name' attribute "
f
"was not set, possibly because it was not decorated with "
f
"@CustomOp.register, or it's the CustomOp base class itself."
)
from
None
if
op_name
not
in
cls
.
op_registry_oot
:
op_cls_to_instantiate
=
cls
else
:
op_cls_to_instantiate
=
cls
.
op_registry_oot
[
op_name
]
logger
.
debug
(
"Instantiating custom op: %s using %s"
,
op_name
,
str
(
op_cls_to_instantiate
))
return
super
().
__new__
(
op_cls_to_instantiate
)
def
__init__
(
self
):
def
__init__
(
self
):
super
().
__init__
()
super
().
__init__
()
self
.
_forward_method
=
self
.
dispatch_forward
()
self
.
_forward_method
=
self
.
dispatch_forward
()
...
@@ -138,6 +158,7 @@ class CustomOp(nn.Module):
...
@@ -138,6 +158,7 @@ class CustomOp(nn.Module):
# - MyOp.enabled()
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
# - op_registry["my_op"].enabled()
op_registry
:
dict
[
str
,
type
[
'CustomOp'
]]
=
{}
op_registry
:
dict
[
str
,
type
[
'CustomOp'
]]
=
{}
op_registry_oot
:
dict
[
str
,
type
[
'CustomOp'
]]
=
{}
# Decorator to register custom ops.
# Decorator to register custom ops.
@
classmethod
@
classmethod
...
@@ -150,3 +171,38 @@ class CustomOp(nn.Module):
...
@@ -150,3 +171,38 @@ class CustomOp(nn.Module):
return
op_cls
return
op_cls
return
decorator
return
decorator
# Decorator to register out-of-tree(oot) custom ops.
# For OOT custom ops:
# if in-tree layer class is registered with an oot_custom_op layer,
# the oot_custom_op layer will be used instead.
# Example:
# - @UnquantizedFusedMoEMethod.register_oot
# class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
# or
# - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
@
classmethod
def
register_oot
(
cls
,
_decorated_op_cls
=
None
,
name
:
Optional
[
str
]
=
None
):
def
decorator
(
op_cls
):
reg_name
=
name
if
name
is
not
None
else
cls
.
__name__
assert
reg_name
not
in
cls
.
op_registry_oot
,
\
f
"Duplicate op name:
{
reg_name
}
"
op_cls
.
name
=
reg_name
cls
.
op_registry_oot
[
reg_name
]
=
op_cls
return
op_cls
if
_decorated_op_cls
is
None
:
# Called with parentheses: @CustomOP.register_oot()
# or @CustomOP.register_oot(name="...")
# So, _decorated_op_cls is None.
# We return the actual decorator function.
return
decorator
elif
isinstance
(
_decorated_op_cls
,
type
):
# Check if it's a class
# Called without parentheses: @CustomOP.register_oot
# The first argument is the class itself.
# We call the 'decorator' function immediately with the class.
return
decorator
(
_decorated_op_cls
)
else
:
# Handle other unexpected cases if necessary
raise
TypeError
(
"Decorator can only be applied to classes."
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment