Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
11bbf86f
Unverified
Commit
11bbf86f
authored
Jan 19, 2026
by
Matt
Committed by
GitHub
Jan 19, 2026
Browse files
[CI][Hardware][AMD] Fix test_rotary_embedding_mla_cache_fused (#32408)
Signed-off-by:
Matthew Wong
<
Matthew.Wong2@amd.com
>
parent
3c8740aa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
3 deletions
+12
-3
tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
+12
-3
No files found.
tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
View file @
11bbf86f
...
@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
...
@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
tests.kernels.utils
import
DEFAULT_OPCHECK_TEST_UTILS
,
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils.torch_utils
import
set_random_seed
from
vllm.utils.torch_utils
import
set_random_seed
...
@@ -68,6 +69,14 @@ def test_concat_and_cache_mla_rope_fused(
...
@@ -68,6 +69,14 @@ def test_concat_and_cache_mla_rope_fused(
k_pe
=
torch
.
flatten
(
key
[...,
:
qk_rope_head_dim
],
start_dim
=
1
).
to
(
device
=
device
)
k_pe
=
torch
.
flatten
(
key
[...,
:
qk_rope_head_dim
],
start_dim
=
1
).
to
(
device
=
device
)
kv_c
=
torch
.
flatten
(
key
[...,
qk_rope_head_dim
:],
start_dim
=
1
).
to
(
device
=
device
)
kv_c
=
torch
.
flatten
(
key
[...,
qk_rope_head_dim
:],
start_dim
=
1
).
to
(
device
=
device
)
if
current_platform
.
is_rocm
():
# We use forward_hip for the same numerics as the fused custom kernel on ROCm
# when dtype is FP16. The torch-native implementation implicitly upcasts
# FP16 x FP16 multiplications to FP32 before downcasting them, which leads
# to notable output divergences.
# Clone the tensors because the implementation modifies them in-place
ref_q_pe
,
ref_k_pe
=
rope
.
forward_hip
(
positions
,
query
.
clone
(),
k_pe
.
clone
())
else
:
# NOTE(woosuk): The reference implementation should be executed first
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
# because the custom kernel is in-place.
ref_q_pe
,
ref_k_pe
=
rope
.
forward_native
(
positions
,
query
,
k_pe
)
ref_q_pe
,
ref_k_pe
=
rope
.
forward_native
(
positions
,
query
,
k_pe
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment