Commit 1faa2c78 authored by zhuwenwen
Browse files

Add DCA (dual chunk attention) and sparse attention support on ROCm

parent a5dcaef9
......@@ -19,7 +19,12 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
from vllm.platforms import current_platform
if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
else:
from flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
if TYPE_CHECKING:
......
......@@ -1107,8 +1107,8 @@ class EngineArgs:
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), (
"DualChunkFlashAttention is only supported on CUDA platform.")
assert current_platform.is_cuda() or current_platform.is_rocm(), (
"DualChunkFlashAttention is only supported on CUDA/ROCM platform.")
assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
......
......@@ -297,6 +297,11 @@ class RocmPlatform(Platform):
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
if selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment