Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6fc61a0d
Commit
6fc61a0d
authored
Jan 29, 2026
by
zhuwenwen
Browse files
fix fa interface and kvcache
not supported FlashMLASchedMeta
parent
ae59e10f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
53 additions
and
17 deletions
+53
-17
vllm/model_executor/layers/attention/mla_attention.py
vllm/model_executor/layers/attention/mla_attention.py
+8
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+41
-11
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+2
-1
vllm/v1/attention/ops/flashmla.py
vllm/v1/attention/ops/flashmla.py
+2
-2
No files found.
vllm/model_executor/layers/attention/mla_attention.py
View file @
6fc61a0d
...
@@ -1414,9 +1414,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
...
@@ -1414,9 +1414,14 @@ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
flash_attn_varlen_func
=
flash_attn_varlen_func
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
self
.
vllm_flash_attn_version
is
not
None
:
if
self
.
vllm_flash_attn_version
is
not
None
:
self
.
flash_attn_varlen_func
=
functools
.
partial
(
if
current_platform
.
is_rocm
():
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
self
.
flash_attn_varlen_func
=
functools
.
partial
(
)
flash_attn_varlen_func
)
else
:
self
.
flash_attn_varlen_func
=
functools
.
partial
(
flash_attn_varlen_func
,
fa_version
=
self
.
vllm_flash_attn_version
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# v with 0s to match the qk head dim for attention backends that do
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
6fc61a0d
...
@@ -893,7 +893,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -893,7 +893,10 @@ class FlashAttentionImpl(AttentionImpl):
):
):
return
return
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
key_cache
,
value_cache
=
kv_cache
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# Skip this if sharing KV cache with an earlier attention layer.
...
@@ -902,16 +905,43 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -902,16 +905,43 @@ class FlashAttentionImpl(AttentionImpl):
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
# actual tokens.
reshape_and_cache_flash
(
if
not
current_platform
.
is_rocm
():
key
,
reshape_and_cache_flash
(
value
,
key
,
key_cache
,
value
,
value_cache
,
key_cache
,
slot_mapping
,
value_cache
,
self
.
kv_cache_dtype
,
slot_mapping
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
,
layer
.
_v_scale
,
layer
.
_k_scale
,
)
layer
.
_v_scale
,
)
else
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
==
torch
.
float16
:
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
else
:
from
vllm.v1.attention.backends.fa_utils
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
def
_forward_with_dcp
(
def
_forward_with_dcp
(
self
,
self
,
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
6fc61a0d
...
@@ -27,6 +27,7 @@ from vllm.v1.attention.backend import (
...
@@ -27,6 +27,7 @@ from vllm.v1.attention.backend import (
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
)
)
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
reshape_attn_output_for_spec_decode
,
reshape_attn_output_for_spec_decode
,
reshape_query_for_spec_decode
,
reshape_query_for_spec_decode
,
...
@@ -41,7 +42,6 @@ from vllm.v1.attention.ops.flashmla import (
...
@@ -41,7 +42,6 @@ from vllm.v1.attention.ops.flashmla import (
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm
import
envs
from
vllm
import
envs
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -320,6 +320,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -320,6 +320,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
cache_seqlens
=
attn_metadata
.
decode
.
seq_lens
,
head_dim_v
=
self
.
kv_lora_rank
,
head_dim_v
=
self
.
kv_lora_rank
,
tile_scheduler_metadata
=
scheduler_metadata
,
tile_scheduler_metadata
=
scheduler_metadata
,
num_splits
=
scheduler_metadata
.
num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
is_fp8_kvcache
=
False
,
is_fp8_kvcache
=
False
,
...
...
vllm/v1/attention/ops/flashmla.py
View file @
6fc61a0d
...
@@ -100,7 +100,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs):
...
@@ -100,7 +100,7 @@ def _raise_flashmla_unavailable(*_args, **_kwargs):
if
_is_flashmla_available
()[
0
]:
if
_is_flashmla_available
()[
0
]:
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
from
flash_mla.flash_mla_interface
import
(
# noqa: F401
from
flash_mla.flash_mla_interface
import
(
# noqa: F401
#
FlashMLASchedMeta,
FlashMLASchedMeta
,
# need new flashmla
# flash_attn_varlen_func,
# flash_attn_varlen_func,
# flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_kvpacked_func,
# flash_attn_varlen_qkvpacked_func,
# flash_attn_varlen_qkvpacked_func,
...
@@ -122,7 +122,7 @@ else:
...
@@ -122,7 +122,7 @@ else:
class
FlashMLASchedMeta
:
# type: ignore[no-redef]
class
FlashMLASchedMeta
:
# type: ignore[no-redef]
pass
pass
flash_attn_varlen_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_kvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_kvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_qkvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
flash_attn_varlen_qkvpacked_func
=
_raise_flashmla_unavailable
# type: ignore[assignment]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment