Commit 651e756b authored by zhuwenwen's avatar zhuwenwen
Browse files

the prefix cache interface implemented using fa on kme

parent dc54fefe
...@@ -16,16 +16,13 @@ from vllm.utils import cuda_device_count_stateless ...@@ -16,16 +16,13 @@ from vllm.utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
from vllm.utils import is_kme, SUPPORT_TC from vllm.utils import SUPPORT_TC
if not SUPPORT_TC: if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0' os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0' os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0' os.environ['VLLM_USE_FLASH_MLA'] = '0'
if is_kme:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
...@@ -299,8 +296,6 @@ class RocmPlatform(Platform): ...@@ -299,8 +296,6 @@ class RocmPlatform(Platform):
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
logger.info("%s is not supported in AMD GPUs.", selected_backend) logger.info("%s is not supported in AMD GPUs.", selected_backend)
if is_kme:
os.environ['VLLM_USE_TRITON_FLASH_ATTN'] = '1'
logger.info("Using ROCmFlashAttention backend.") logger.info("Using ROCmFlashAttention backend.")
return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501 return "vllm.attention.backends.rocm_flash_attn.ROCmFlashAttentionBackend" # noqa: E501
......
...@@ -85,7 +85,6 @@ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 ...@@ -85,7 +85,6 @@ POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
is_kme = any(arch in GPU_ARCH for arch in ["gfx928"])
SUPPORT_TC = any(arch in GPU_ARCH for arch in ["gfx928", "gfx936"]) SUPPORT_TC = any(arch in GPU_ARCH for arch in ["gfx928", "gfx936"])
def _generate_random_int8( def _generate_random_int8(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment