Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1e911dbd
Commit
1e911dbd
authored
Sep 30, 2025
by
zhuwenwen
Browse files
[kernels] add rotary_embedding_deepseek_fuse
off rotary_embedding_deepseek_fuse
parent
63f1c793
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
6 deletions
+32
-6
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+4
-3
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+28
-3
No files found.
vllm/model_executor/layers/layernorm.py
View file @
1e911dbd
...
...
@@ -135,9 +135,10 @@ if current_platform.is_rocm():
def
dispatch_rocm_rmsnorm_func
(
with_fused_add
:
bool
,
dtype
:
torch
.
dtype
):
use_aiter
=
is_rocm_aiter_rmsnorm_enabled
()
and
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
# use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
# torch.float16, torch.bfloat16
# ]
use_aiter
=
False
if
use_aiter
and
with_fused_add
:
return
torch
.
ops
.
vllm
.
rocm_aiter_rmsnorm2d_fwd_with_add
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
1e911dbd
...
...
@@ -37,6 +37,9 @@ from transformers import PretrainedConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.platforms
import
current_platform
import
vllm.envs
as
envs
from
vllm.utils
import
direct_register_custom_op
if
current_platform
.
is_cuda
():
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
...
...
@@ -700,7 +703,7 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
sin
=
freqs
.
sin
()
*
mscale
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
def
forward
(
self
,
positions
:
torch
.
Tensor
,
...
...
@@ -841,6 +844,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
sin
=
(
freqs
.
sin
()
*
self
.
mscale
)
cache
=
torch
.
cat
((
cos
,
sin
),
dim
=-
1
)
return
cache
def rotary_embedding_deepseek_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
) -> None:
    """Invoke lightop's fused DeepSeek rotary-embedding kernel.

    Thin wrapper that forwards every argument, unchanged and in the same
    order, to ``lightop.op.rotary_embedding_deepseek_fuse``. Nothing is
    returned.

    NOTE(review): since there is no return value, the kernel presumably
    writes the rotated values into ``query`` and ``key`` in place --
    confirm against the lightop implementation.
    """
    # Function-scope import: lightop is only needed once this op is
    # actually called, not when the surrounding module is imported.
    from lightop import op as lightop_ops

    fused_kernel = lightop_ops.rotary_embedding_deepseek_fuse
    fused_kernel(positions, query, key, head_size, cos_sin_cache,
                 is_neox_style)
def rotary_embedding_deepseek_fuse_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox_style: bool,
) -> None:
    """Do-nothing counterpart of ``rotary_embedding_deepseek_fuse``.

    Accepts the same arguments as the real op, performs no computation,
    touches none of its inputs, and returns ``None``. It is the function
    passed as ``fake_impl`` when the op is registered.
    """
    return None
# Register the fused kernel as a vLLM custom op; the caller reaches it as
# torch.ops.vllm.rotary_embedding_deepseek_fuse.
#
# The op returns None and its caller goes on to return `query` and `key`,
# so its results can only be delivered by writing into those two tensors.
# They are therefore declared as mutated: with an empty `mutates_args` the
# generated schema describes an op with no outputs and no side effects,
# which torch.compile is free to reorder or drop.
# NOTE(review): confirm against lightop that `query` and `key` are the only
# arguments written in place.
direct_register_custom_op(
    op_name="rotary_embedding_deepseek_fuse",
    op_func=rotary_embedding_deepseek_fuse,
    mutates_args=["query", "key"],
    fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def
forward
(
self
,
...
...
@@ -880,8 +901,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE
=
BLOCK_SIZE
,
num_warps
=
1
)
call
(
query
)
call
(
key
)
# if envs.VLLM_USE_LIGHTOP:
if
False
:
torch
.
ops
.
vllm
.
rotary_embedding_deepseek_fuse
(
positions
,
query
,
key
,
self
.
head_size
,
self
.
cos_sin_cache
,
self
.
is_neox_style
)
else
:
call
(
query
)
call
(
key
)
return
query
,
key
else
:
query_rot
=
query
[...,
:
self
.
rotary_dim
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment