Commit 6f1db287 authored by zhuwenwen's avatar zhuwenwen
Browse files

support prefix cache on kme

fix the error in test_moe caused by moe align not supporting 511
multi-modal switching to torch implementation on z100l&k100
parent 384b6bd9
......@@ -734,7 +734,7 @@ def test_moe_align_block_size_opcheck():
@pytest.mark.parametrize("m", [1, 33, 64, 222])
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("k", [128, 512, 1024])
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
......
......@@ -12,6 +12,7 @@ from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from vllm.utils import SUPPORT_TC
logger = init_logger(__name__)
......@@ -82,6 +83,8 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
if current_platform.is_cuda() or current_platform.is_rocm():
if not SUPPORT_TC:
selected_backend = _Backend.TORCH_SDPA
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
......
......@@ -16,14 +16,17 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import SUPPORT_TC
from vllm.utils import is_kme, SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0'
if is_kme:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
......@@ -301,6 +304,8 @@ class RocmPlatform(Platform):
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
if is_kme:
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
......
......@@ -85,6 +85,7 @@ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
gpuname = torch.cuda.get_device_properties(torch.cuda.current_device()).name
is_kme = gpuname.startswith('K100_AI') or gpuname.startswith('K500SM_AI')
SUPPORT_TC = gpuname.startswith('K100_AI') or gpuname.startswith('K500SM_AI') or gpuname.startswith('BW')
def _generate_random_int8(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment