Unverified Commit 50b788a1 authored by Zhewen Li's avatar Zhewen Li Committed by GitHub
Browse files

[CI/Build] Fix AMD CI: test_cpu_gpu.py (#27388)


Signed-off-by: default avatarzhewenli <zhewenli@meta.com>
parent fc059c70
...@@ -8,11 +8,20 @@ import torch ...@@ -8,11 +8,20 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
BACKENDS_TO_TEST = [FlashAttentionBackend]
if not current_platform.is_rocm():
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
BACKENDS_TO_TEST.append(FlashInferBackend)
from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
BACKENDS_TO_TEST.append(FlashAttnMLABackend)
NUM_GPU_BLOCKS = [64] NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256] NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16] GPU_BLOCK_SIZES = [16]
...@@ -55,8 +64,8 @@ def test_transfer( ...@@ -55,8 +64,8 @@ def test_transfer(
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
# create per-layer GPU KV caches # create per-layer GPU KV caches based on available attn_backends
attn_backends_list = [FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend] attn_backends_list = BACKENDS_TO_TEST
gpu_caches = {} gpu_caches = {}
attn_backends = {} attn_backends = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment