sglang · Commit 8ae9d4bb (Unverified)

Authored by b8zhong on Oct 23, 2025; committed by GitHub on Oct 23, 2025.
Revert "[ROCm] Remove vLLM rope dependency & use AITER impl" (#12028)
Parent: 1c304aa9
Showing 2 changed files, with 0 additions and 353 deletions:

- python/sglang/srt/layers/rotary_embedding.py: +0 −120
- test/srt/test_rope_rocm.py: +0 −233
python/sglang/srt/layers/rotary_embedding.py
@@ -124,23 +124,6 @@ class RotaryEmbedding(CustomOp):

```python
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)
        self._hip_cached_cos: Optional[torch.Tensor] = None
        self._hip_cached_sin: Optional[torch.Tensor] = None
        if _use_aiter:
            half_rotary = cache.shape[-1] // 2
            cos_cache = (
                cache[:, :half_rotary]
                .contiguous()
                .view(self.max_position_embeddings, 1, 1, half_rotary)
            )
            sin_cache = (
                cache[:, half_rotary:]
                .contiguous()
                .view(self.max_position_embeddings, 1, 1, half_rotary)
            )
            self.register_buffer("_hip_cos_cache", cos_cache, persistent=False)
            self.register_buffer("_hip_sin_cache", sin_cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Compute the inverse frequency."""
        # NOTE(woosuk): To exactly match the HF implementation, we need to
        ...
```
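For context, `cache` here is the concatenated cos/sin table of shape `(max_position_embeddings, rotary_dim)`, with cosines in the first half of the last dimension and sines in the second; the removed initialization split it into two `(max_pos, 1, 1, rotary_dim // 2)` buffers for AITER's cached-position kernels. A minimal standalone sketch of that split, using hypothetical dimensions rather than anything taken from this commit:

```python
import torch

# Hypothetical dimensions, for illustration only.
max_pos, rotary_dim = 2048, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)  # (max_pos, rotary_dim // 2)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)         # (max_pos, rotary_dim)

# The same split the removed __init__ code performed:
half_rotary = cache.shape[-1] // 2
cos_cache = cache[:, :half_rotary].contiguous().view(max_pos, 1, 1, half_rotary)
sin_cache = cache[:, half_rotary:].contiguous().view(max_pos, 1, 1, half_rotary)
assert cos_cache.shape == sin_cache.shape == (max_pos, 1, 1, half_rotary)
```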
@@ -201,109 +184,6 @@ class RotaryEmbedding(CustomOp):

```python
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor],
        offsets: Optional[torch.Tensor] = None,
        fused_set_kv_buffer_arg: Optional["FusedSetKVBufferArg"] = None,
        *,
        is_nope_first: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if not _use_aiter:
            return self.forward_native(
                positions, query, key, offsets, fused_set_kv_buffer_arg
            )
        if fused_set_kv_buffer_arg is not None:
            raise NotImplementedError(
                "fused_set_kv_buffer_arg is not supported for HIP path"
            )

        import aiter as ops

        if not hasattr(self, "_hip_cos_cache") or not hasattr(self, "_hip_sin_cache"):
            raise RuntimeError("HIP caches not initialised")

        cos = self._hip_cached_cos
        sin = self._hip_cached_sin
        if cos is None or cos.device != query.device or cos.dtype != query.dtype:
            cos = self._hip_cos_cache.to(query.device, dtype=query.dtype)
            sin = self._hip_sin_cache.to(query.device, dtype=query.dtype)
            self._hip_cached_cos = cos
            self._hip_cached_sin = sin

        rotate_style = 0 if self.is_neox_style else 1

        num_tokens = positions.numel()
        query_shape = query.shape
        query = query.view(1, num_tokens, -1, self.head_size)
        key_shape = key.shape if key is not None else None
        if key is not None:
            key = key.view(1, num_tokens, -1, self.head_size)

        positions = positions.view(*query.shape[:2])
        if offsets is not None:
            offsets = offsets.view(*query.shape[:2])

        if not is_nope_first:
            query_rot = query[..., : self.rotary_dim]
            key_rot = key[..., : self.rotary_dim] if key is not None else None
        else:
            query_rot = query[..., -self.rotary_dim :]
            key_rot = key[..., -self.rotary_dim :] if key is not None else None

        if key_rot is None:
            if offsets is None:
                ops.rope_cached_positions_fwd_inplace(
                    query_rot,
                    cos,
                    sin,
                    positions,
                    rotate_style,
                    reuse_freqs_front_part=True,
                    nope_first=is_nope_first,
                )
            else:
                ops.rope_cached_positions_offsets_fwd_inplace(
                    query_rot,
                    cos,
                    sin,
                    positions,
                    offsets,
                    rotate_style,
                    reuse_freqs_front_part=True,
                    nope_first=is_nope_first,
                )
            return query.view(query_shape), None

        if offsets is None:
            ops.rope_cached_positions_2c_fwd_inplace(
                query_rot,
                key_rot,
                cos,
                sin,
                positions,
                rotate_style,
                reuse_freqs_front_part=True,
                nope_first=is_nope_first,
            )
        else:
            ops.rope_cached_positions_offsets_2c_fwd_inplace(
                query_rot,
                key_rot,
                cos,
                sin,
                positions,
                offsets,
                rotate_style,
                reuse_freqs_front_part=True,
                nope_first=is_nope_first,
            )
        return query.view(query_shape), key.view(key_shape) if key is not None else None

    def forward_npu(
        self,
        positions: torch.Tensor,
        ...
```
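In `forward_hip`, `rotate_style = 0 if self.is_neox_style else 1` selects between the two rotation layouts the AITER kernels support. As a hedged reference (the helper names below are illustrative, not from this diff), the two styles correspond to these pure-PyTorch rotations:

```python
import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # rotate_style == 0: pair element i with element i + rotary_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # rotate_style == 1: pair adjacent elements (0, 1), (2, 3), ...
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

# With a cached table, applying RoPE then reduces (up to cos/sin broadcasting
# details) to x * cos + rotate(x) * sin, which is what the in-place
# rope_cached_positions_* kernels compute on GPU.
```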
test/srt/test_rope_rocm.py

@@ -111,239 +111,6 @@ class TestRotaryEmbeddingAITer(CustomTestCase):

```python
            with self.subTest(case=case):
                self._run_case_aiter(*case)

    def test_ops_equivalence_basic(self) -> None:
        import aiter as ops
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (128, 64, 2048, 10000, True, torch.bfloat16, "cuda", 2, 32, 4, 2)

        rope = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        num_tokens = positions.numel()
        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

        q_ref, k_ref = rope.forward_hip(positions.clone(), q2d.clone(), k2d.clone())

        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
        cos = rope.cos_cache.to(device=device, dtype=dtype)
        sin = rope.sin_cache.to(device=device, dtype=dtype)
        pos_b_s = positions.view(1, num_tokens)
        rotate_style = 0 if is_neox else 1

        ops.rope_cached_positions_2c_fwd_inplace(
            q_sbhd,
            k_sbhd,
            cos,
            sin,
            pos_b_s,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=False,
        )

        self.assertTrue(q_ref.shape == q2d.shape)
        self.assertTrue(k_ref.shape == k2d.shape)
        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)

    def test_ops_equivalence_nope_first(self) -> None:
        import aiter as ops
        from aiter.rotary_embedding import RotaryEmbedding as AiterRotaryEmbedding

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (128, 64, 2048, 10000, True, torch.bfloat16, "cuda", 1, 16, 2, 2)

        rope = AiterRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        num_tokens = positions.numel()
        q2d = torch.randn(num_tokens, num_q * head_size, dtype=dtype, device=device)
        k2d = torch.randn(num_tokens, num_kv * head_size, dtype=dtype, device=device)

        q_ref, k_ref = rope.forward_hip(
            positions.clone(), q2d.clone(), k2d.clone(), is_nope_first=True
        )

        q_sbhd = q2d.view(1, num_tokens, num_q, head_size)
        k_sbhd = k2d.view(1, num_tokens, num_kv, head_size)
        cos = rope.cos_cache.to(device=device, dtype=dtype)
        sin = rope.sin_cache.to(device=device, dtype=dtype)
        pos_b_s = positions.view(1, num_tokens)
        rotate_style = 0 if is_neox else 1

        q_rot = q_sbhd[..., -rotary_dim:]
        k_rot = k_sbhd[..., -rotary_dim:]
        ops.rope_cached_positions_2c_fwd_inplace(
            q_rot,
            k_rot,
            cos,
            sin,
            pos_b_s,
            rotate_style,
            reuse_freqs_front_part=True,
            nope_first=True,
        )

        torch.testing.assert_close(q_ref, q_sbhd.view_as(q2d), atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_sbhd.view_as(k2d), atol=1e-2, rtol=1e-2)
```
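The `atol=1e-2, rtol=1e-2` tolerances look loose but are roughly right for bfloat16, whose machine epsilon is 2⁻⁷; this is reasoning from dtype precision, not anything stated in the commit:

```python
import torch

# bfloat16 keeps 8 significand bits, so values near 1.0 are spaced ~2**-7.
# Two RoPE implementations that round intermediates differently can thus
# legitimately disagree by around 1e-2 on unit-scale activations.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
```

The remaining two tests compare sglang's `forward_hip` wrapper against `forward_native` end-to-end rather than against the raw AITER ops.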
```python
    def test_sglang_rotary_embedding_forward_hip_matches_native(self) -> None:
        from sglang.srt.layers.rotary_embedding import (
            RotaryEmbedding as SglRotaryEmbedding,
        )

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (128, 64, 2048, 10000, True, torch.bfloat16, "cuda", 2, 64, 4, 2)

        rope = SglRotaryEmbedding(
            head_size, rotary_dim, max_pos, base, is_neox, dtype
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)

        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)

    def test_llama3_rotary_embedding_forward_hip_matches_native(self) -> None:
        from sglang.srt.layers.rotary_embedding import get_rope as sgl_get_rope

        (
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            dtype,
            device,
            bs,
            seq_len,
            num_q,
            num_kv,
        ) = (128, 128, 2048, 10000, True, torch.bfloat16, "cuda", 2, 64, 4, 2)

        rope = sgl_get_rope(
            head_size,
            rotary_dim,
            max_pos,
            base,
            is_neox,
            rope_scaling={
                "rope_type": "llama3",
                "factor": 1.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 1.0,
                "original_max_position_embeddings": max_pos,
            },
            dtype=dtype,
        ).to(device)
        positions = torch.arange(seq_len, device=device).repeat(bs)
        q = torch.randn(bs * seq_len, num_q * head_size, dtype=dtype, device=device)
        k = torch.randn(bs * seq_len, num_kv * head_size, dtype=dtype, device=device)

        q_ref, k_ref = rope.forward_native(positions.clone(), q.clone(), k.clone())
        q_hip, k_hip = rope.forward_hip(positions.clone(), q.clone(), k.clone())

        torch.testing.assert_close(q_ref, q_hip, atol=1e-2, rtol=1e-2)
        torch.testing.assert_close(k_ref, k_hip, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    unittest.main()
```
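With the file checked out at this revision, the suite runs as a normal unittest module; a minimal runner sketch, assuming a ROCm build of sglang with `aiter` installed and a GPU visible as `"cuda"`:

```python
import unittest

# Discover and run only this file's tests (path as shown in the diff above).
suite = unittest.defaultTestLoader.discover("test/srt", pattern="test_rope_rocm.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```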