Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5255270
Unverified
Commit
a5255270
authored
Jan 26, 2025
by
Roger Wang
Committed by
GitHub
Jan 26, 2025
Browse files
[Misc] Revert FA on ViT #12355 and #12435 (#12445)
parent
0ee349b5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
37 deletions
+4
-37
vllm/attention/layer.py
vllm/attention/layer.py
+4
-37
No files found.
vllm/attention/layer.py
View file @
a5255270
...
@@ -210,9 +210,6 @@ class MultiHeadAttention(nn.Module):
...
@@ -210,9 +210,6 @@ class MultiHeadAttention(nn.Module):
self
.
scale
=
scale
self
.
scale
=
scale
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
dtype
=
torch
.
get_default_dtype
()
dtype
=
torch
.
get_default_dtype
()
attn_backend
=
get_attn_backend
(
head_size
,
attn_backend
=
get_attn_backend
(
head_size
,
dtype
,
dtype
,
...
@@ -220,12 +217,12 @@ class MultiHeadAttention(nn.Module):
...
@@ -220,12 +217,12 @@ class MultiHeadAttention(nn.Module):
block_size
=
16
,
block_size
=
16
,
is_attention_free
=
False
)
is_attention_free
=
False
)
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
backend
=
backend_name_to_enum
(
attn_backend
.
get_name
())
if
backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
}:
backend
=
_Backend
.
XFORMERS
self
.
attn_backend
=
backend
if
backend
in
{
self
.
attn_backend
=
backend
if
backend
in
{
_Backend
.
TORCH_SDPA
,
_Backend
.
TORCH_SDPA
,
_Backend
.
XFORMERS
,
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}
else
_Backend
.
TORCH_SDPA
}
else
_Backend
.
TORCH_SDPA
def
forward
(
def
forward
(
...
@@ -235,6 +232,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -235,6 +232,7 @@ class MultiHeadAttention(nn.Module):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Input shape: batch_size x seq_len x hidden_size"""
"""Input shape: batch_size x seq_len x hidden_size"""
# TODO(Isotr0py): Use existing backend implementations and support FA3
bsz
,
q_len
,
_
=
query
.
size
()
bsz
,
q_len
,
_
=
query
.
size
()
kv_len
=
key
.
size
(
1
)
kv_len
=
key
.
size
(
1
)
...
@@ -242,38 +240,7 @@ class MultiHeadAttention(nn.Module):
...
@@ -242,38 +240,7 @@ class MultiHeadAttention(nn.Module):
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
key
=
key
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
bsz
,
kv_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
(
num_repeat
:
=
self
.
num_queries_per_kv
)
>
1
:
if
self
.
attn_backend
==
_Backend
.
XFORMERS
:
# Handle MQA and GQA
key
=
torch
.
repeat_interleave
(
key
,
num_repeat
,
dim
=
2
)
value
=
torch
.
repeat_interleave
(
value
,
num_repeat
,
dim
=
2
)
if
self
.
attn_backend
in
{
_Backend
.
FLASH_ATTN
,
_Backend
.
FLASH_ATTN_VLLM_V1
,
}:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
cu_seqlens_q
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
q_len
,
step
=
q_len
,
dtype
=
torch
.
int32
,
device
=
query
.
device
)
cu_seqlens_k
=
torch
.
arange
(
0
,
(
bsz
+
1
)
*
kv_len
,
step
=
kv_len
,
dtype
=
torch
.
int32
,
device
=
key
.
device
)
out
=
flash_attn_varlen_func
(
query
.
flatten
(
0
,
1
),
key
.
flatten
(
0
,
1
),
value
.
flatten
(
0
,
1
),
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_q
=
q_len
,
max_seqlen_k
=
kv_len
,
softmax_scale
=
self
.
scale
,
)
out
=
out
.
reshape
(
bsz
,
q_len
,
-
1
)
elif
self
.
attn_backend
==
_Backend
.
XFORMERS
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
out
=
xops
.
memory_efficient_attention_forward
(
query
,
out
=
xops
.
memory_efficient_attention_forward
(
query
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment