Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aaec845f
Unverified
Commit
aaec845f
authored
Apr 18, 2025
by
Luka Govedič
Committed by
GitHub
Apr 18, 2025
Browse files
[ROCm] [Attention] Cleanup ROCm output passing (#16431)
Signed-off-by:
Luka Govedič
<
lgovedic@redhat.com
>
parent
7bdfd29a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
23 deletions
+18
-23
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+18
-23
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
aaec845f
...
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
...
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
accept_output_buffer
:
bool
=
True
@
staticmethod
@
staticmethod
def
get_name
()
->
str
:
def
get_name
()
->
str
:
...
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
from
vllm.attention.ops.triton_flash_attention
import
(
# noqa: F401
triton_attention
)
triton_attention
)
self
.
attn_func
=
triton_attention
self
.
triton_
attn_func
=
triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
if
self
.
sliding_window
!=
(
-
1
,
-
1
):
logger
.
warning
(
"ROCm Triton FA does not currently support "
logger
.
warning
(
"ROCm Triton FA does not currently support "
...
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else
:
else
:
try
:
try
:
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func
=
flash_attn_varlen_func
self
.
fa_
attn_func
=
flash_attn_varlen_func
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
logger
.
debug
(
"Using CK FA in ROCmBackend"
)
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
self
.
use_naive_attn
=
True
self
.
use_naive_attn
=
True
...
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
"ROCm Naive FlashAttention does not support "
"ROCm Naive FlashAttention does not support "
"attention logits soft capping."
)
"attention logits soft capping."
)
self
.
attn_func
=
_sdpa_attention
self
.
sdpa_
attn_func
=
_sdpa_attention
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
logger
.
debug
(
"Using naive (SDPA) attention in ROCmBackend"
)
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
def
repeat_kv
(
self
,
x
:
torch
.
Tensor
,
n_rep
:
int
)
->
torch
.
Tensor
:
...
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
output
is
not
None
,
"Output tensor must be provided."
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
query
=
query
.
view
(
-
1
,
self
.
num_heads
,
self
.
head_size
)
if
key
is
not
None
:
if
key
is
not
None
:
assert
value
is
not
None
assert
value
is
not
None
...
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
assert
attn_metadata
.
num_encoder_tokens
is
not
None
assert
attn_metadata
.
num_encoder_tokens
is
not
None
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
num_prefill_tokens
=
attn_metadata
.
num_encoder_tokens
output
=
torch
.
empty_like
(
query
)
# Query for decode. KV is not needed because it is already cached.
# Query for decode. KV is not needed because it is already cached.
decode_query
=
query
[
num_prefill_tokens
:]
decode_query
=
query
[
num_prefill_tokens
:]
# QKV for prefill.
# QKV for prefill.
...
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
query
.
dtype
,
query
.
dtype
,
seq_lens
,
seq_lens
,
make_attn_mask
=
causal_mask
)
# type: ignore
make_attn_mask
=
causal_mask
)
# type: ignore
out
,
_
=
self
.
attn_func
(
self
.
triton_
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
None
,
output
[:
num_prefill_tokens
]
,
query_seq_start_loc
,
query_seq_start_loc
,
key_seq_start_loc
,
key_seq_start_loc
,
query_max_seq_len
,
query_max_seq_len
,
...
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
key
=
key
.
movedim
(
0
,
key
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
value
=
value
.
movedim
(
0
,
value
.
dim
()
-
2
)
# sdpa math backend attention
# sdpa math backend attention
out
=
self
.
attn_func
(
self
.
sdpa_
attn_func
(
query
,
query
,
key
,
key
,
value
,
value
,
output
[:
num_prefill_tokens
],
query_seq_start_loc
,
query_seq_start_loc
,
num_prefill_tokens
,
num_prefill_tokens
,
self
.
num_heads
,
self
.
num_heads
,
...
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks
,
attn_masks
,
)
)
else
:
else
:
out
=
self
.
attn_func
(
# upstream FA does not support an output arg, copy
output
[:
num_prefill_tokens
]
=
self
.
fa_attn_func
(
q
=
query
,
q
=
query
,
k
=
key
,
k
=
key
,
v
=
value
,
v
=
value
,
...
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softcap
=
self
.
logits_soft_cap
,
softcap
=
self
.
logits_soft_cap
,
)
)
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
if
output
.
shape
[
0
]
>
num_prefill_tokens
:
output
[:
num_prefill_tokens
]
=
out
else
:
output
=
out
else
:
else
:
# prefix-enabled attention -
# prefix-enabled attention -
# not applicable for encoder-only models
# not applicable for encoder-only models
...
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -818,14 +816,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
device
=
output
.
device
,
device
=
output
.
device
,
)
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
if
num_prefill_tokens
>
0
:
out
=
output
[
num_prefill_tokens
:]
else
:
out
=
output
query_start_loc
=
None
query_start_loc
=
None
ops
.
paged_attention_rocm
(
ops
.
paged_attention_rocm
(
out
,
out
put
[
num_prefill_tokens
:]
,
exp_sums
,
exp_sums
,
max_logits
,
max_logits
,
tmp_output
,
tmp_output
,
...
@@ -878,7 +872,8 @@ def _sdpa_attention(
...
@@ -878,7 +872,8 @@ def _sdpa_attention(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
seq_lens
:
List
[
int
],
output
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
num_tokens
:
int
,
num_tokens
:
int
,
num_heads
:
int
,
num_heads
:
int
,
head_size
:
int
,
head_size
:
int
,
...
@@ -886,9 +881,9 @@ def _sdpa_attention(
...
@@ -886,9 +881,9 @@ def _sdpa_attention(
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
attn_masks
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
start
=
0
start
=
0
output
=
torch
.
empty
(
(
num_tokens
,
num_heads
,
head_size
)
,
assert
output
.
shape
==
(
num_tokens
,
num_heads
,
head_size
)
dtype
=
query
.
dtype
,
assert
output
.
dtype
==
query
.
dtype
device
=
query
.
device
)
assert
output
.
device
==
query
.
device
for
i
,
seq_len
in
enumerate
(
seq_lens
):
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
end
=
start
+
seq_len
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment