Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23a1946e
Unverified
Commit
23a1946e
authored
Dec 19, 2025
by
Shanshan Shen
Committed by
GitHub
Dec 19, 2025
Browse files
[CustomOp][Refactor] Extract common methods for ApplyRotaryEmb CustomOp (#31021)
Signed-off-by:
shen-shanshan
<
467638484@qq.com
>
parent
b5545d9d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
28 deletions
+35
-28
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/common.py
+35
-28
No files found.
vllm/model_executor/layers/rotary_embedding/common.py
View file @
23a1946e
...
@@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
...
@@ -178,6 +178,37 @@ class ApplyRotaryEmb(CustomOp):
output
=
output
.
to
(
origin_dtype
)
output
=
output
.
to
(
origin_dtype
)
return
output
return
output
def
_pre_process
(
self
,
x
:
torch
.
Tensor
,
cos
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Size
,
torch
.
dtype
]:
origin_shape
=
x
.
shape
if
len
(
origin_shape
)
==
3
:
# x: [seq_len, num_heads, head_size]
x
=
x
.
unsqueeze
(
0
)
origin_dtype
=
x
.
dtype
if
self
.
enable_fp32_compute
:
x
=
x
.
float
()
cos
=
cos
.
float
()
sin
=
sin
.
float
()
return
x
,
cos
,
sin
,
origin_shape
,
origin_dtype
def
_post_process
(
self
,
output
:
torch
.
Tensor
,
origin_shape
:
torch
.
Size
,
origin_dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
if
len
(
origin_shape
)
==
3
:
output
=
output
.
squeeze
(
0
)
if
self
.
enable_fp32_compute
:
output
=
output
.
to
(
origin_dtype
)
return
output
def
forward_native
(
def
forward_native
(
self
,
self
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
@@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
...
@@ -197,16 +228,7 @@ class ApplyRotaryEmb(CustomOp):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
from
vllm.vllm_flash_attn.layers.rotary
import
apply_rotary_emb
origin_dtype
=
x
.
dtype
x
,
cos
,
sin
,
origin_shape
,
origin_dtype
=
self
.
_pre_process
(
x
,
cos
,
sin
)
if
self
.
enable_fp32_compute
:
x
=
x
.
float
()
cos
=
cos
.
float
()
sin
=
sin
.
float
()
origin_shape
=
x
.
shape
if
len
(
origin_shape
)
==
3
:
# x: [seq_len, num_heads, head_size]
x
=
x
.
unsqueeze
(
0
)
"""
"""
Arguments of apply_rotary_emb() in vllm_flash_attn:
Arguments of apply_rotary_emb() in vllm_flash_attn:
...
@@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
...
@@ -218,10 +240,7 @@ class ApplyRotaryEmb(CustomOp):
interleaved
=
not
self
.
is_neox_style
interleaved
=
not
self
.
is_neox_style
output
=
apply_rotary_emb
(
x
,
cos
,
sin
,
interleaved
)
output
=
apply_rotary_emb
(
x
,
cos
,
sin
,
interleaved
)
if
len
(
origin_shape
)
==
3
:
output
=
self
.
_post_process
(
output
,
origin_shape
,
origin_dtype
)
output
=
output
.
squeeze
(
0
)
if
self
.
enable_fp32_compute
:
output
=
output
.
to
(
origin_dtype
)
return
output
return
output
def
forward_hip
(
def
forward_hip
(
...
@@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
...
@@ -231,16 +250,7 @@ class ApplyRotaryEmb(CustomOp):
sin
:
torch
.
Tensor
,
sin
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
apply_rotary_emb_flash_attn
is
not
None
:
if
self
.
apply_rotary_emb_flash_attn
is
not
None
:
origin_dtype
=
x
.
dtype
x
,
cos
,
sin
,
origin_shape
,
origin_dtype
=
self
.
_pre_process
(
x
,
cos
,
sin
)
if
self
.
enable_fp32_compute
:
x
=
x
.
float
()
cos
=
cos
.
float
()
sin
=
sin
.
float
()
origin_shape
=
x
.
shape
if
len
(
origin_shape
)
==
3
:
# x: [seq_len, num_heads, head_size]
x
=
x
.
unsqueeze
(
0
)
"""
"""
Arguments of apply_rotary() in flash_attn:
Arguments of apply_rotary() in flash_attn:
...
@@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
...
@@ -254,10 +264,7 @@ class ApplyRotaryEmb(CustomOp):
x
,
cos
,
sin
,
interleaved
=
interleaved
x
,
cos
,
sin
,
interleaved
=
interleaved
).
type_as
(
x
)
).
type_as
(
x
)
if
len
(
origin_shape
)
==
3
:
output
=
self
.
_post_process
(
output
,
origin_shape
,
origin_dtype
)
output
=
output
.
squeeze
(
0
)
if
self
.
enable_fp32_compute
:
output
=
output
.
to
(
origin_dtype
)
else
:
else
:
# Falling back to PyTorch native implementation.
# Falling back to PyTorch native implementation.
output
=
self
.
forward_native
(
x
,
cos
,
sin
)
output
=
self
.
forward_native
(
x
,
cos
,
sin
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment