Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c96fc067
Unverified
Commit
c96fc067
authored
Jun 07, 2024
by
Hongxia Yang
Committed by
GitHub
Jun 07, 2024
Browse files
[ROCm][AMD] Use pytorch sdpa math backend to do naive attention (#4965)
parent
b3376e5c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
33 deletions
+29
-33
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+29
-33
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
c96fc067
...
@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -247,7 +247,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
if
self
.
use_naive_attn
:
if
self
.
use_naive_attn
:
self
.
attn_func
=
_
naive
_attention
self
.
attn_func
=
_
sdpa
_attention
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
logger
.
debug
(
"Using naive attention in ROCmBackend"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
...
@@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -342,11 +342,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# Interleave for MQA workaround.
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
query
=
query
.
movedim
(
0
,
query
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
# sdpa math backend attention
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
prefill_meta
.
seq_lens
,
prefill_meta
.
seq_lens
,
num_tokens
,
self
.
num_heads
,
self
.
head_size
,
self
.
scale
,
self
.
scale
,
)
)
else
:
else
:
...
@@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -402,45 +409,34 @@ class ROCmFlashAttentionImpl(AttentionImpl):
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
def
_
naive
_attention
(
def
_
sdpa
_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
num_tokens
:
int
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
scale
:
float
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
start
=
0
start
=
0
for
_
,
seq_len
in
enumerate
(
seq_lens
):
output
=
torch
.
empty
((
num_tokens
,
num_heads
,
head_size
),
dtype
=
query
.
dtype
,
device
=
query
.
device
)
for
seq_len
in
seq_lens
:
end
=
start
+
seq_len
end
=
start
+
seq_len
out
=
_naive_masked_attention
(
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
query
[
start
:
end
],
enable_flash
=
False
,
key
[
start
:
end
],
enable_mem_efficient
=
False
):
value
[
start
:
end
],
sub_out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
scale
,
query
[:,
start
:
end
,
:],
)
key
[:,
start
:
end
,
:],
# TODO(woosuk): Unnecessary copy. Optimize.
value
[:,
start
:
end
,
:],
output
[
start
:
end
].
copy_
(
out
)
dropout_p
=
0.0
,
start
+=
seq_len
is_causal
=
True
,
scale
=
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
return
output
return
output
def
_naive_masked_attention
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
scale
:
float
,
)
->
torch
.
Tensor
:
seq_len
,
head_size
,
head_dim
=
query
.
shape
attn_mask
=
torch
.
triu
(
torch
.
ones
(
seq_len
,
seq_len
,
dtype
=
query
.
dtype
,
device
=
query
.
device
),
diagonal
=
1
)
attn_mask
=
attn_mask
*
torch
.
finfo
(
query
.
dtype
).
min
attn_weights
=
scale
*
torch
.
einsum
(
"qhd,khd->hqk"
,
query
,
key
).
float
()
attn_weights
=
attn_weights
+
attn_mask
.
float
()
attn_weights
=
torch
.
softmax
(
attn_weights
,
dim
=-
1
).
to
(
value
.
dtype
)
out
=
torch
.
einsum
(
"hqk,khd->qhd"
,
attn_weights
,
value
)
return
out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment