Unverified commit 5e5a7eb1, authored by rasmith, committed by GitHub
Browse files

[CI/Build] Make test_attention_selector.py run tests on correct platform (#29064)


Signed-off-by: Randall Smith <ransmith@amd.com>
Signed-off-by: rasmith <Randall.Smith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 3d84ef90
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import torch import torch
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms import current_platform
from vllm.platforms.cpu import CpuPlatform from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.rocm import RocmPlatform from vllm.platforms.rocm import RocmPlatform
...@@ -47,9 +48,11 @@ DEVICE_MLA_BLOCK_SIZES = { ...@@ -47,9 +48,11 @@ DEVICE_MLA_BLOCK_SIZES = {
def generate_params(): def generate_params():
is_rocm = current_platform.is_rocm()
params = [] params = []
device_list = ["cuda", "cpu"] if not is_rocm else ["hip", "cpu"]
for use_mla in [True, False]: for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]: for device in device_list:
backends = ( backends = (
DEVICE_MLA_BACKENDS[device] DEVICE_MLA_BACKENDS[device]
if use_mla if use_mla
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment