Commit 1faa2c78 authored by zhuwenwen's avatar zhuwenwen
Browse files

add dca and sparse attention support on rocm

parent a5dcaef9
...@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend, ...@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.platforms import current_platform
flash_attn_with_kvcache, sparse_attn_func) if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
else:
from flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import ModelInputForGPUBuilder
......
...@@ -1107,8 +1107,8 @@ class EngineArgs: ...@@ -1107,8 +1107,8 @@ class EngineArgs:
"Cuda graph is not supported with DualChunkFlashAttention. " "Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' " "To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.") "or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), ( assert current_platform.is_cuda() or current_platform.is_rocm(), (
"DualChunkFlashAttention is only supported on CUDA platform.") "DualChunkFlashAttention is only supported on CUDA/ROCM platform.")
assert not use_v1, ( assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. " "DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'") "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
......
...@@ -296,7 +296,12 @@ class RocmPlatform(Platform): ...@@ -296,7 +296,12 @@ class RocmPlatform(Platform):
else: else:
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN_VLLM_V1
if selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90): if not cls.has_device_capability(90):
# not Instinct series GPUs. # not Instinct series GPUs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment