Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1faa2c78
Commit
1faa2c78
authored
Sep 01, 2025
by
zhuwenwen
Browse files
add dca and sparse attention support on rocm
parent
a5dcaef9
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
5 deletions
+15
-5
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+7
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+6
-1
No files found.
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
1faa2c78
...
...
@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.utils
import
async_tensor_h2d
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
else
:
from
flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
...
vllm/engine/arg_utils.py
View file @
1faa2c78
...
...
@@ -1107,8 +1107,8 @@ class EngineArgs:
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI."
)
assert
current_platform
.
is_cuda
(),
(
"DualChunkFlashAttention is only supported on CUDA platform."
)
assert
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
()
,
(
"DualChunkFlashAttention is only supported on CUDA
/ROCM
platform."
)
assert
not
use_v1
,
(
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'"
)
...
...
vllm/platforms/rocm.py
View file @
1faa2c78
...
...
@@ -296,7 +296,12 @@ class RocmPlatform(Platform):
else
:
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
TRITON_ATTN_VLLM_V1
if
selected_backend
==
_Backend
.
DUAL_CHUNK_FLASH_ATTN
:
logger
.
info
(
"Using DualChunkFlashAttention backend."
)
return
(
"vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend"
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
cls
.
has_device_capability
(
90
):
# not Instinct series GPUs.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment