Unverified Commit 11bbf86f authored by Matt's avatar Matt Committed by GitHub
Browse files

[CI][Hardware][AMD] Fix test_rotary_embedding_mla_cache_fused (#32408)


Signed-off-by: default avatarMatthew Wong <Matthew.Wong2@amd.com>
parent 3c8740aa
...@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol ...@@ -13,6 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import set_random_seed
...@@ -68,9 +69,17 @@ def test_concat_and_cache_mla_rope_fused( ...@@ -68,9 +69,17 @@ def test_concat_and_cache_mla_rope_fused(
k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device) k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device) kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
# NOTE(woosuk): The reference implementation should be executed first if current_platform.is_rocm():
# because the custom kernel is in-place. # We use forward_hip for the same numerics as the fused custom kernel on ROCm
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe) # when dtype is FP16. The torch-native implementation implicitly upcasts
# FP16 x FP16 multiplications to FP32 before downcasting them, which leads
# to notable output divergences.
# Clone the tensors because the implementation modifies them in-place
ref_q_pe, ref_k_pe = rope.forward_hip(positions, query.clone(), k_pe.clone())
else:
# NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place.
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
assert ref_k_pe is not None assert ref_k_pe is not None
ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device) ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment