Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9c97a1c3
Unverified
Commit
9c97a1c3
authored
Aug 11, 2025
by
vllmellm
Committed by
GitHub
Aug 10, 2025
Browse files
[ROCm][AITER] Support AITER Rope ops in RotaryEmbedding Module. (#22521)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
f919d4cb
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
204 additions
and
10 deletions
+204
-10
vllm/model_executor/layers/rotary_embedding/base.py
vllm/model_executor/layers/rotary_embedding/base.py
+71
-0
vllm/model_executor/layers/rotary_embedding/common.py
vllm/model_executor/layers/rotary_embedding/common.py
+2
-2
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
...executor/layers/rotary_embedding/deepseek_scaling_rope.py
+4
-8
vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
...l_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
+127
-0
No files found.
vllm/model_executor/layers/rotary_embedding/base.py
View file @
9c97a1c3
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
.common
import
apply_rotary_emb_dispatch
,
apply_rotary_emb_torch
from
.common
import
apply_rotary_emb_dispatch
,
apply_rotary_emb_torch
from
.rocm_aiter_rope_ops
import
is_rocm_rotary_embedding_enabled
@
CustomOp
.
register
(
"rotary_embedding"
)
@
CustomOp
.
register
(
"rotary_embedding"
)
...
@@ -35,6 +36,7 @@ class RotaryEmbedding(CustomOp):
...
@@ -35,6 +36,7 @@ class RotaryEmbedding(CustomOp):
cache
=
cache
.
to
(
dtype
)
cache
=
cache
.
to
(
dtype
)
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
cos_sin_cache
:
torch
.
Tensor
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
register_buffer
(
"cos_sin_cache"
,
cache
,
persistent
=
False
)
self
.
is_rocm_aiter_enabled
=
is_rocm_rotary_embedding_enabled
()
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
float
)
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
"""Compute the inverse frequency."""
...
@@ -119,6 +121,75 @@ class RotaryEmbedding(CustomOp):
...
@@ -119,6 +121,75 @@ class RotaryEmbedding(CustomOp):
self
.
cos_sin_cache
,
self
.
is_neox_style
)
self
.
cos_sin_cache
,
self
.
is_neox_style
)
return
query
,
key
return
query
,
key
def forward_hip(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
    is_nope_first=False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """HIP (ROCm) entry point for the rotary embedding.

    Uses ``forward_hip_rocm_aiter`` when ``self.is_rocm_aiter_enabled``
    is truthy and ``forward_native`` otherwise.

    Args:
        positions: Position tensor forwarded to the chosen implementation.
        query: Query tensor forwarded to the chosen implementation.
        key: Optional key tensor forwarded to the chosen implementation.
        offsets: Optional offsets forwarded to the chosen implementation.
        is_nope_first: Forwarded to the AITER path only; the native
            fallback does not receive it.

    Returns:
        Whatever the selected implementation returns (a query/key pair).
    """
    # The AITER package currently provides the only rotary-embedding ops
    # supported on the HIP forward path; everything else goes native.
    if not self.is_rocm_aiter_enabled:
        return self.forward_native(positions, query, key, offsets)
    return self.forward_hip_rocm_aiter(positions, query, key, offsets,
                                       is_nope_first)
def forward_hip_rocm_aiter(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
    is_nope_first: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply rotary embedding via the AITER ops under ``torch.ops.vllm``.

    ``query`` (and ``key`` when given) are viewed as
    ``(1, num_tokens, -1, head_size)`` with
    ``num_tokens = positions.numel()``. Only the rotary slice of the last
    dimension is handed to the op: the leading ``rotary_dim`` entries by
    default (rope part first, nope part after), or the trailing
    ``rotary_dim`` entries when ``is_nope_first`` is set (nope part first).

    Args:
        positions: Token positions; viewed as ``(1, num_tokens)``.
        query: Query tensor; must be viewable as described above.
        key: Optional key tensor, handled the same way as ``query``.
        offsets: Optional offsets; viewed as ``(1, num_tokens)`` if given.
        is_nope_first: Selects which end of the head dimension is rotated
            and is also forwarded to the op.

    Returns:
        ``(query, key)`` viewed back to their incoming shapes; ``key`` is
        ``None`` when no key was passed. The ops' return values are not
        used, so the result is expected to be written into the slices
        passed to them.
    """
    # Keep the cached cos/sin table on the query's device and dtype.
    cache_matches_query = (self.cos_sin_cache.device == query.device
                           and self.cos_sin_cache.dtype == query.dtype)
    if not cache_matches_query:
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)

    # Split the cache into its cos and sin halves and insert two singleton
    # axes before the last dimension of each.
    cos, sin = (half.unsqueeze(-2).unsqueeze(-2)
                for half in self.cos_sin_cache.chunk(2, dim=-1))

    # Style selector handed to the op: 0 for neox-style, 1 otherwise.
    rotate_style = 0 if self.is_neox_style else 1

    num_tokens = positions.numel()

    incoming_query_shape = query.shape
    query = query.view(1, num_tokens, -1, self.head_size)

    leading_dims = query.shape[:2]
    positions = positions.view(*leading_dims)
    if offsets is not None:
        offsets = offsets.view(*leading_dims)

    # Which part of the head dimension carries the rotary components.
    if is_nope_first:
        rotary_part = slice(-self.rotary_dim, None)
    else:
        rotary_part = slice(None, self.rotary_dim)

    # Slicing yields a view, so writes by the op land in ``query``.
    query_rot = query[..., rotary_part]

    if key is None:
        torch.ops.vllm.rocm_aiter_rotary_emb_without_key_forward_hip(
            positions, sin, cos, query_rot, offsets, rotate_style,
            is_nope_first)
        return query.view(incoming_query_shape), None

    incoming_key_shape = key.shape
    key = key.view(1, num_tokens, -1, self.head_size)
    key_rot = key[..., rotary_part]

    torch.ops.vllm.rocm_aiter_rotary_emb_with_key_forward_hip(
        positions, sin, cos, query_rot, key_rot, offsets, rotate_style,
        is_nope_first)
    return query.view(incoming_query_shape), key.view(incoming_key_shape)
def
forward_xpu
(
def
forward_xpu
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/rotary_embedding/common.py
View file @
9c97a1c3
...
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
...
@@ -99,7 +99,7 @@ def yarn_linear_ramp_mask(low: float, high: float, dim: int,
return
ramp_func
return
ramp_func
def
yarn_get_mscale
(
scale
:
float
=
1
)
->
float
:
def
yarn_get_mscale
(
scale
:
float
=
1
,
mscale
:
float
=
1
)
->
float
:
if
scale
<=
1
:
if
scale
<=
1
:
return
1.0
return
1.0
return
0.1
*
math
.
log
(
scale
)
+
1.0
return
0.1
*
mscale
*
math
.
log
(
scale
)
+
1.0
vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
View file @
9c97a1c3
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
math
from
typing
import
Optional
from
typing
import
Optional
import
torch
import
torch
...
@@ -10,13 +9,7 @@ from vllm.platforms import current_platform
...
@@ -10,13 +9,7 @@ from vllm.platforms import current_platform
from
.base
import
RotaryEmbedding
from
.base
import
RotaryEmbedding
from
.common
import
(
rotate_gptj
,
rotate_neox
,
yarn_find_correction_range
,
from
.common
import
(
rotate_gptj
,
rotate_neox
,
yarn_find_correction_range
,
yarn_linear_ramp_mask
)
yarn_get_mscale
,
yarn_linear_ramp_mask
)
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor.

    Args:
        scale: Scaling factor; values of 1 or less yield ``1.0``.
        mscale: Multiplier applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise
        ``0.1 * mscale * log(scale) + 1.0`` (natural logarithm).
    """
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
class
DeepseekScalingRotaryEmbedding
(
RotaryEmbedding
):
...
@@ -96,6 +89,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -96,6 +89,9 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
offsets
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""PyTorch-native implementation equivalent to forward()."""
"""PyTorch-native implementation equivalent to forward()."""
if
self
.
is_rocm_aiter_enabled
:
return
self
.
forward_hip_rocm_aiter
(
positions
,
query
,
key
,
offsets
)
assert
key
is
not
None
assert
key
is
not
None
query_rot
=
query
[...,
:
self
.
rotary_dim
]
query_rot
=
query
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
key_rot
=
key
[...,
:
self
.
rotary_dim
]
...
...
vllm/model_executor/layers/rotary_embedding/rocm_aiter_rope_ops.py
0 → 100644
View file @
9c97a1c3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
def is_rocm_rotary_embedding_enabled() -> bool:
    """Tell whether the AITER rotary-embedding path is switched on.

    Returns:
        The result of ``current_platform.is_rocm()`` when it is falsy;
        otherwise the value of ``envs.VLLM_ROCM_USE_AITER``.
    """
    on_rocm = current_platform.is_rocm()
    if not on_rocm:
        # Not on ROCm: the AITER setting is not consulted at all.
        return on_rocm
    return envs.VLLM_ROCM_USE_AITER
def rocm_aiter_rotary_emb_without_key_forward_hip_impl(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """Run AITER's in-place cached-position RoPE op on ``query`` only.

    Calls ``aiter.rope_cached_positions_offsets_fwd_inplace`` when
    ``offsets`` is given and ``aiter.rope_cached_positions_fwd_inplace``
    otherwise. Nothing is returned; the AITER functions used are the
    ``*_inplace`` variants.

    Args:
        positions: Position tensor forwarded to AITER.
        sin: Sine table forwarded to AITER.
        cos: Cosine table forwarded to AITER.
        query: Tensor handed to AITER as the one to rotate.
        offsets: Optional offsets; selects the ``offsets`` variant.
        rotate_style: Integer style selector forwarded unchanged.
        is_nope_first: Forwarded as AITER's ``nope_first`` argument.
    """
    # Imported lazily so the module can be loaded without AITER installed.
    import aiter

    # Note the argument order: this wrapper takes (sin, cos) while the
    # AITER calls below take (cos, sin).
    # NOTE(review): reuse_freqs_front_part=True presumably tells AITER the
    # cos/sin tables cover only the front half of the rotary dim - confirm.
    if offsets is not None:
        aiter.rope_cached_positions_offsets_fwd_inplace(
            query,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
        return

    aiter.rope_cached_positions_fwd_inplace(
        query,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=is_nope_first,
    )
def rocm_aiter_rotary_emb_with_key_forward_hip_impl(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """Run AITER's in-place cached-position RoPE op on ``query`` and ``key``.

    Calls ``aiter.rope_cached_positions_offsets_2c_fwd_inplace`` when
    ``offsets`` is given and ``aiter.rope_cached_positions_2c_fwd_inplace``
    otherwise. Nothing is returned; the AITER functions used are the
    ``*_inplace`` variants.

    Args:
        positions: Position tensor forwarded to AITER.
        sin: Sine table forwarded to AITER.
        cos: Cosine table forwarded to AITER.
        query: First tensor handed to AITER for rotation.
        key: Second tensor handed to AITER for rotation.
        offsets: Optional offsets; selects the ``offsets`` variant.
        rotate_style: Integer style selector forwarded unchanged.
        is_nope_first: Forwarded as AITER's ``nope_first`` argument.
    """
    # Imported lazily so the module can be loaded without AITER installed.
    import aiter

    # Note the argument order: this wrapper takes (sin, cos) while the
    # AITER calls below take (cos, sin).
    # NOTE(review): reuse_freqs_front_part=True presumably tells AITER the
    # cos/sin tables cover only the front half of the rotary dim - confirm.
    if offsets is not None:
        aiter.rope_cached_positions_offsets_2c_fwd_inplace(
            query,
            key,
            cos,
            sin,
            positions,
            offsets,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=is_nope_first,
        )
        return

    aiter.rope_cached_positions_2c_fwd_inplace(
        query,
        key,
        cos,
        sin,
        positions,
        rotate_style,
        reuse_freqs_front_part=True,
        nope_first=is_nope_first,
    )
def rocm_aiter_rotary_emb_with_key_forward_hip_fake(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """No-op stand-in for the query+key AITER RoPE op.

    Supplied as ``fake_impl`` when the op is registered. It has the same
    signature as the real implementation, touches none of its arguments
    and returns ``None``.
    """
    return None
def rocm_aiter_rotary_emb_without_key_forward_hip_fake(
    positions: torch.Tensor,
    sin: torch.Tensor,
    cos: torch.Tensor,
    query: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
    rotate_style: int = 0,
    is_nope_first: bool = False,
) -> None:
    """No-op stand-in for the query-only AITER RoPE op.

    Supplied as ``fake_impl`` when the op is registered. It has the same
    signature as the real implementation, touches none of its arguments
    and returns ``None``.
    """
    return None
# Register the AITER RoPE wrappers as custom ops, but only when the
# ROCm + AITER path is enabled. Each op declares the tensor arguments it
# mutates and is paired with its no-op fake implementation.
if is_rocm_rotary_embedding_enabled():
    # Query + key variant.
    direct_register_custom_op(
        op_name="rocm_aiter_rotary_emb_with_key_forward_hip",
        op_func=rocm_aiter_rotary_emb_with_key_forward_hip_impl,
        fake_impl=rocm_aiter_rotary_emb_with_key_forward_hip_fake,
        mutates_args=["key", "query"],
        dispatch_key=current_platform.dispatch_key,
    )

    # Query-only variant.
    direct_register_custom_op(
        op_name="rocm_aiter_rotary_emb_without_key_forward_hip",
        op_func=rocm_aiter_rotary_emb_without_key_forward_hip_impl,
        fake_impl=rocm_aiter_rotary_emb_without_key_forward_hip_fake,
        mutates_args=["query"],
        dispatch_key=current_platform.dispatch_key,
    )
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment