Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
86bfb6db
Unverified
Commit
86bfb6db
authored
Jan 20, 2025
by
wangxiyuan
Committed by
GitHub
Jan 20, 2025
Browse files
[Misc] Pass `attention` to impl backend (#12218)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
5f0ec393
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
86 additions
and
78 deletions
+86
-78
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+19
-4
vllm/attention/backends/blocksparse_attn.py
vllm/attention/backends/blocksparse_attn.py
+6
-6
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+5
-5
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+8
-8
vllm/attention/backends/hpu_attn.py
vllm/attention/backends/hpu_attn.py
+2
-2
vllm/attention/backends/ipex_attn.py
vllm/attention/backends/ipex_attn.py
+9
-9
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+3
-3
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+10
-10
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+8
-10
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+9
-11
vllm/attention/layer.py
vllm/attention/layer.py
+3
-5
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+4
-5
No files found.
vllm/attention/backends/abstract.py
View file @
86bfb6db
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
fields
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Protocol
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
...
...
@@ -223,6 +223,22 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
raise
NotImplementedError
class
AttentionLayer
(
Protocol
):
_k_scale
:
float
_v_scale
:
float
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
...
class
AttentionImpl
(
ABC
,
Generic
[
T
]):
@
abstractmethod
...
...
@@ -244,13 +260,12 @@ class AttentionImpl(ABC, Generic[T]):
@
abstractmethod
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
T
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
vllm/attention/backends/blocksparse_attn.py
View file @
86bfb6db
...
...
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type
import
torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
...
...
@@ -358,13 +359,12 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
BlocksparseFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -401,8 +401,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
if
prefill_meta
:
=
attn_metadata
.
prefill_metadata
:
...
...
@@ -439,8 +439,8 @@ class BlocksparseFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
tp_rank
=
self
.
tp_rank
,
blocksparse_local_blocks
=
self
.
local_blocks
,
blocksparse_vert_stride
=
self
.
vert_stride
,
...
...
vllm/attention/backends/flash_attn.py
View file @
86bfb6db
...
...
@@ -8,6 +8,7 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
...
...
@@ -634,13 +635,12 @@ class FlashAttentionImpl(AttentionImpl):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
...
...
@@ -657,7 +657,7 @@ class FlashAttentionImpl(AttentionImpl):
NOTE: It in-place updates the output tensor.
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
assert
layer
.
_
k_scale
==
1.0
and
layer
.
_
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
assert
output
is
not
None
,
"Output tensor must be provided."
...
...
@@ -709,8 +709,8 @@ class FlashAttentionImpl(AttentionImpl):
kv_cache
[
1
],
updated_slot_mapping
.
flatten
(),
# type: ignore[union-attr]
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
...
...
vllm/attention/backends/flashinfer.py
View file @
86bfb6db
...
...
@@ -23,6 +23,7 @@ import torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionState
,
AttentionType
)
...
...
@@ -792,13 +793,12 @@ class FlashInferImpl(AttentionImpl):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashInferMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
...
...
@@ -826,8 +826,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache
[:,
1
],
attn_metadata
.
slot_mapping
.
flatten
(),
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
...
...
@@ -886,8 +886,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
k_scale
=
layer
.
_
k_scale
,
v_scale
=
layer
.
_
v_scale
,
window_left
=
window_left
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
decode_meta
is
not
None
...
...
@@ -897,8 +897,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache
,
sm_scale
=
softmax_scale
,
logits_soft_cap
=
logits_soft_cap
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
k_scale
=
layer
.
_
k_scale
,
v_scale
=
layer
.
_
v_scale
,
window_left
=
window_left
)
if
prefill_output
is
None
and
decode_output
is
not
None
:
...
...
vllm/attention/backends/hpu_attn.py
View file @
86bfb6db
...
...
@@ -11,6 +11,7 @@ import vllm_hpu_extension.ops as ops
from
vllm_hpu_extension.utils
import
Matmul
,
Softmax
,
VLLMKVCache
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.hpu_paged_attn
import
(
HPUPagedAttention
,
...
...
@@ -152,13 +153,12 @@ class HPUAttentionImpl(AttentionImpl, torch.nn.Module):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
HPUAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
vllm/attention/backends/ipex_attn.py
View file @
86bfb6db
...
...
@@ -7,6 +7,7 @@ import torch
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
...
...
@@ -171,13 +172,12 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
IpexAttnMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with IPEX varlen_attention and PagedAttention.
...
...
@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
layer
.
_
k_scale
==
1.0
and
layer
.
_
v_scale
==
1.0
num_tokens
,
hidden_size
=
query
.
shape
# Reshape the query, key, and value tensors.
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
...
...
@@ -210,8 +210,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
value_cache
,
attn_metadata
.
slot_mapping
.
flatten
(),
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
if
attn_metadata
.
is_prompt
:
...
...
@@ -296,8 +296,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
else
:
# Run PagedAttention V2.
...
...
@@ -329,8 +329,8 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/backends/pallas.py
View file @
86bfb6db
...
...
@@ -5,6 +5,7 @@ import torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
CommonAttentionState
...
...
@@ -150,13 +151,12 @@ class PallasAttentionBackendImpl(AttentionImpl):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
attn_metadata
:
PallasMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
...
...
@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
layer
.
_
k_scale
==
1.0
and
layer
.
_
v_scale
==
1.0
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
86bfb6db
...
...
@@ -7,6 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
)
...
...
@@ -414,13 +415,12 @@ class ROCmFlashAttentionImpl(AttentionImpl):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
...
...
@@ -458,8 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
...
@@ -567,8 +567,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
[
0
],
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
...
...
@@ -613,8 +613,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
max_seq_len
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
else
:
output
[
num_prefill_tokens
:]
=
PagedAttention
.
forward_decode
(
...
...
@@ -628,8 +628,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
86bfb6db
...
...
@@ -7,6 +7,7 @@ import torch
from
torch.nn.functional
import
scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionMetadataBuilder
,
AttentionType
)
...
...
@@ -429,13 +430,12 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
...
...
@@ -451,7 +451,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
shape = [num_tokens, num_heads * head_size]
"""
assert
k_scale
==
1.0
and
v_scale
==
1.0
assert
layer
.
_
k_scale
==
1.0
and
layer
.
_
v_scale
==
1.0
attn_type
=
self
.
attn_type
if
(
attn_type
==
AttentionType
.
ENCODER
and
(
not
attn_metadata
.
is_all_encoder_attn_metadata_set
)):
...
...
@@ -493,11 +493,9 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
# Update self-attention KV cache (prefill/decode)
updated_slot_mapping
=
attn_metadata
.
slot_mapping
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
...
...
@@ -571,8 +569,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/backends/xformers.py
View file @
86bfb6db
...
...
@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias
)
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionLayer
,
AttentionMetadata
,
AttentionType
)
from
vllm.attention.backends.utils
import
(
CommonAttentionState
,
CommonMetadataBuilder
,
...
...
@@ -412,13 +413,12 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
def
forward
(
self
,
layer
:
AttentionLayer
,
query
:
torch
.
Tensor
,
key
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
"XFormersMetadata"
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
...
...
@@ -524,11 +524,9 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory
# profiling run.
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
)
(
num_prefill_query_tokens
,
num_prefill_kv_tokens
,
num_decode_query_tokens
)
=
\
get_num_prefill_decode_query_kv_tokens
(
attn_metadata
,
attn_type
)
...
...
@@ -580,8 +578,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
prefill_meta
.
max_query_len
,
self
.
alibi_slopes
,
self
.
sliding_window
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
assert
output
[:
num_prefill_query_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_query_tokens
]
=
out
...
...
@@ -607,8 +605,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
num_kv_heads
,
self
.
scale
,
self
.
alibi_slopes
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# Reshape the output tensor.
...
...
vllm/attention/layer.py
View file @
86bfb6db
...
...
@@ -243,8 +243,7 @@ def unified_attention(
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
return
self
.
impl
.
forward
(
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
)
return
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
)
def
unified_attention_fake
(
...
...
@@ -276,13 +275,12 @@ def unified_attention_with_output(
attn_metadata
=
forward_context
.
attn_metadata
self
=
forward_context
.
attn_layers
[
layer_name
]
kv_cache
=
self
.
kv_cache
[
forward_context
.
virtual_engine
]
self
.
impl
.
forward
(
query
,
self
.
impl
.
forward
(
self
,
query
,
key
,
value
,
kv_cache
,
attn_metadata
,
self
.
_k_scale
,
self
.
_v_scale
,
output
=
output
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
86bfb6db
...
...
@@ -130,13 +130,12 @@ class FlashAttentionImpl(AttentionImpl):
def
forward
(
self
,
layer
:
torch
.
nn
.
Module
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashAttentionMetadata
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention.
...
...
@@ -151,7 +150,7 @@ class FlashAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size]
"""
# NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
assert
layer
.
_
k_scale
==
1.0
and
layer
.
_
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashAttention."
)
assert
output
is
not
None
,
"Output tensor must be provided."
...
...
@@ -183,8 +182,8 @@ class FlashAttentionImpl(AttentionImpl):
value_cache
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
k_scale
,
v_scale
,
layer
.
_
k_scale
,
layer
.
_
v_scale
,
)
# Compute attention and update output up to `num_actual_tokens`.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment