Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d9f83d62
Unverified
Commit
d9f83d62
authored
Mar 12, 2025
by
Sage Moore
Committed by
GitHub
Mar 12, 2025
Browse files
[ROCm] Enable chunked prefill/paged attention in MLA on ROCm (#14316)
Signed-off-by:
Sage Moore
<
sage@neuralmagic.com
>
parent
4a754fcf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
18 deletions
+4
-18
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+2
-16
vllm/config.py
vllm/config.py
+2
-2
No files found.
vllm/attention/backends/mla/common.py
View file @
d9f83d62
...
@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1327,21 +1327,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
if
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
triton_fa_func
(
q
,
k
,
v_padded
,
None
,
prefill_metadata
.
query_start_loc
,
prefill_metadata
.
context_chunk_cu_seq_lens
[
i
],
prefill_metadata
.
max_query_len
,
prefill_metadata
.
context_chunk_max_seq_lens
[
i
],
False
,
# causal
self
.
scale
,
None
,
# attn_mask is None unless applying ALiBi mask
)
elif
is_vllm_fa
:
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
attn_output
,
attn_softmax_lse
=
self
.
flash_attn_varlen_func
(
q
=
q
,
q
=
q
,
k
=
k
,
k
=
k
,
...
@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1416,7 +1402,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
value
=
0
)
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
:
if
is_hip
and
envs
.
VLLM_USE_TRITON_FLASH_ATTN
and
not
has_context
:
output
=
self
.
triton_fa_func
(
output
=
self
.
triton_fa_func
(
q
,
q
,
k
,
k
,
...
...
vllm/config.py
View file @
d9f83d62
...
@@ -3450,9 +3450,9 @@ class VllmConfig:
...
@@ -3450,9 +3450,9 @@ class VllmConfig:
self
.
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
self
.
compilation_config
.
level
=
CompilationLevel
.
NO_COMPILATION
if
self
.
model_config
and
self
.
model_config
.
use_mla
and
\
if
self
.
model_config
and
self
.
model_config
.
use_mla
and
\
not
current_platform
.
is_cuda
():
not
(
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
())
:
logger
.
info
(
logger
.
info
(
"MLA is enabled on a non-
cuda
platform; forcing chunked "
"MLA is enabled on a non-
GPU
platform; forcing chunked "
"prefill and prefix caching to be disabled."
)
"prefill and prefix caching to be disabled."
)
self
.
scheduler_config
.
enable_chunked_prefill
=
False
self
.
scheduler_config
.
enable_chunked_prefill
=
False
self
.
scheduler_config
.
chunked_prefill_enabled
=
False
self
.
scheduler_config
.
chunked_prefill_enabled
=
False
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment