Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6832707e
Unverified
Commit
6832707e
authored
Mar 06, 2025
by
Michael Goin
Committed by
GitHub
Mar 06, 2025
Browse files
[V1][Bugfix] Standardize quantized kv cache rejection for attention backends (#14221)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
6b2ef5cd
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
59 additions
and
20 deletions
+59
-20
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+4
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+8
-1
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+6
-3
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+6
-1
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+3
-2
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+3
-2
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+6
-2
vllm/attention/backends/triton_mla.py
vllm/attention/backends/triton_mla.py
+6
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+5
-1
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+6
-4
vllm/v1/attention/backends/mla/triton_mla.py
vllm/v1/attention/backends/mla/triton_mla.py
+6
-1
No files found.
vllm/attention/backends/abstract.py
View file @
6832707e
...
...
@@ -294,3 +294,7 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
return
kv_cache_dtype
!=
"auto"
vllm/attention/backends/flash_attn.py
View file @
6832707e
...
...
@@ -8,11 +8,15 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import
torch
from
vllm
import
_custom_ops
as
ops
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
,
is_quantized_kv_cache
)
# yapf: enable
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
CommonAttentionState
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
get_flash_attn_version
,
...
...
@@ -626,6 +630,9 @@ class FlashAttentionImpl(AttentionImpl):
self
.
sliding_window
=
((
sliding_window
-
1
,
0
)
if
sliding_window
is
not
None
else
(
-
1
,
-
1
))
self
.
kv_cache_dtype
=
kv_cache_dtype
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashAttention with FP8 KV cache not yet supported"
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
...
...
vllm/attention/backends/flashmla.py
View file @
6832707e
...
...
@@ -6,7 +6,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import
torch
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
,
...
...
@@ -207,6 +208,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
...
...
@@ -215,8 +220,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
attn_metadata
:
FlashMLAMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 FlashMLA not yet supported"
)
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
...
...
vllm/attention/backends/hpu_attn.py
View file @
6832707e
...
...
@@ -15,7 +15,8 @@ from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.hpu_paged_attn
import
(
HPUPagedAttention
,
HPUPagedAttentionMetadata
)
...
...
@@ -158,6 +159,10 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
"are not implemented for "
"HPUAttentionImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"HPUAttention with FP8 KV cache not yet supported"
)
def
forward
(
self
,
layer
:
AttentionLayer
,
...
...
vllm/attention/backends/ipex_attn.py
View file @
6832707e
...
...
@@ -9,7 +9,8 @@ import torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
...
...
@@ -145,7 +146,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
if
is_quantized_kv_cache
(
kv_cache_dtype
)
:
raise
NotImplementedError
(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
...
...
vllm/attention/backends/pallas.py
View file @
6832707e
...
...
@@ -8,7 +8,8 @@ import torch_xla.experimental.custom_kernel # Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
CommonAttentionState
...
...
@@ -119,7 +120,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
if
is_quantized_kv_cache
(
kv_cache_dtype
)
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
...
...
vllm/attention/backends/torch_sdpa.py
View file @
6832707e
...
...
@@ -7,11 +7,15 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
from
torch.nn.functional
import
scaled_dot_product_attention
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
AttentionType
,
is_quantized_kv_cache
)
# yapf: enable
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.ipex_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
...
...
@@ -427,7 +431,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
raise
ValueError
(
f
"Head size
{
head_size
}
is not supported by PagedAttention. "
f
"Supported head sizes are:
{
supported_head_sizes
}
."
)
if
kv_cache_dtype
!=
"auto"
:
if
is_quantized_kv_cache
(
kv_cache_dtype
)
:
raise
NotImplementedError
(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead."
)
...
...
vllm/attention/backends/triton_mla.py
View file @
6832707e
...
...
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Type
import
torch
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
(
MLACommonBackend
,
MLACommonImpl
,
MLACommonMetadata
)
...
...
@@ -58,6 +59,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"TritonMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"TritonMLA with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
...
...
@@ -66,8 +71,6 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
attn_metadata
:
MLACommonMetadata
,
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 Triton MLA not yet supported"
)
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
6832707e
...
...
@@ -7,7 +7,8 @@ import numpy as np
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.utils
import
get_flash_attn_version
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.logger
import
init_logger
...
...
@@ -180,6 +181,9 @@ class FlashAttentionImpl(AttentionImpl):
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashAttention V1 with FP8 KV cache not yet supported"
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
...
...
vllm/v1/attention/backends/mla/flashmla.py
View file @
6832707e
...
...
@@ -5,7 +5,8 @@ from typing import Any, Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.flashmla
import
(
flash_mla_with_kvcache
,
get_mla_metadata
,
is_flashmla_supported
)
...
...
@@ -115,6 +116,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"are not implemented for "
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashMLA V1 with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
...
...
@@ -125,9 +130,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
if
self
.
kv_cache_dtype
.
startswith
(
"fp8"
):
raise
NotImplementedError
(
"FP8 FlashMLA not yet supported"
)
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
\
.
unsqueeze
(
1
)
# Add seqlen dim of 1 (decode)
...
...
vllm/v1/attention/backends/mla/triton_mla.py
View file @
6832707e
...
...
@@ -4,7 +4,8 @@ from typing import Any, Optional
import
torch
from
vllm.attention.backends.abstract
import
AttentionType
from
vllm.attention.backends.abstract
import
(
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.triton_decode_attention
import
decode_attention_fwd
from
vllm.logger
import
init_logger
from
vllm.v1.attention.backends.mla.common
import
(
MLACommonBackend
,
...
...
@@ -61,6 +62,10 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
"are not implemented for "
"TritonMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"TritonMLA V1 with FP8 KV cache not yet supported"
)
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment