Commit 3df619ac (unverified), authored by Jiangyun Zhu, committed by GitHub
Browse files

[CI] fix `test_concat_and_cache_mla_rope_fused` (#32117)


Signed-off-by: zjy0516 <riverclouds.zhu@qq.com>
parent d74132ca
......@@ -13,7 +13,7 @@ from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
......@@ -31,6 +31,7 @@ from vllm.platforms import current_platform
)
@torch.inference_mode()
def test_concat_and_cache_mla_rope_fused(
default_vllm_config,
dtype: torch.dtype,
is_neox_style: bool,
seq_len: int,
......@@ -45,7 +46,7 @@ def test_concat_and_cache_mla_rope_fused(
max_position: int = 8192,
base: float = 10000,
) -> None:
current_platform.seed_everything(seed)
set_random_seed(seed)
torch.set_default_device(device)
rope = RotaryEmbedding(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment