Unverified Commit 3df619ac authored by Jiangyun Zhu's avatar Jiangyun Zhu Committed by GitHub
Browse files

[CI] fix `test_concat_and_cache_mla_rope_fused` (#32117)


Signed-off-by: default avatarzjy0516 <riverclouds.zhu@qq.com>
parent d74132ca
...@@ -13,7 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol ...@@ -13,7 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float]) @pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
...@@ -31,6 +31,7 @@ from vllm.platforms import current_platform ...@@ -31,6 +31,7 @@ from vllm.platforms import current_platform
) )
@torch.inference_mode() @torch.inference_mode()
def test_concat_and_cache_mla_rope_fused( def test_concat_and_cache_mla_rope_fused(
default_vllm_config,
dtype: torch.dtype, dtype: torch.dtype,
is_neox_style: bool, is_neox_style: bool,
seq_len: int, seq_len: int,
...@@ -45,7 +46,7 @@ def test_concat_and_cache_mla_rope_fused( ...@@ -45,7 +46,7 @@ def test_concat_and_cache_mla_rope_fused(
max_position: int = 8192, max_position: int = 8192,
base: float = 10000, base: float = 10000,
) -> None: ) -> None:
current_platform.seed_everything(seed) set_random_seed(seed)
torch.set_default_device(device) torch.set_default_device(device)
rope = RotaryEmbedding( rope = RotaryEmbedding(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment