Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
18d23f64
Unverified
Commit
18d23f64
authored
Apr 27, 2024
by
Hongxia Yang
Committed by
GitHub
Apr 26, 2024
Browse files
[ROCm][Hardware][AMD] Enable group query attention for triton FA (#4406)
parent
87f545ba
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
41 deletions
+36
-41
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+25
-28
vllm/attention/ops/triton_flash_attention.py
vllm/attention/ops/triton_flash_attention.py
+11
-13
No files found.
vllm/attention/backends/rocm_flash_attn.py
View file @
18d23f64
...
@@ -253,22 +253,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -253,22 +253,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# triton attention
# triton attention
# When block_tables are not filled, it means q and k are the
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
or
self
.
use_naive_attn
:
if
self
.
use_triton_flash_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
if
self
.
use_naive_attn
:
out
=
self
.
attn_func
(
query
,
key
,
value
,
prefill_meta
.
prompt_lens
,
self
.
scale
,
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
out
,
_
=
self
.
attn_func
(
out
,
_
=
self
.
attn_func
(
query
,
query
,
key
,
key
,
...
@@ -281,8 +266,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -281,8 +266,18 @@ class ROCmFlashAttentionImpl(AttentionImpl):
True
,
True
,
self
.
scale
,
self
.
scale
,
)
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
elif
self
.
use_naive_attn
:
output
[:
num_prefill_tokens
]
=
out
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
key
=
self
.
repeat_kv
(
key
,
self
.
num_queries_per_kv
)
value
=
self
.
repeat_kv
(
value
,
self
.
num_queries_per_kv
)
out
=
self
.
attn_func
(
query
,
key
,
value
,
prefill_meta
.
prompt_lens
,
self
.
scale
,
)
else
:
else
:
out
=
self
.
attn_func
(
out
=
self
.
attn_func
(
q
=
query
,
q
=
query
,
...
@@ -295,6 +290,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -295,6 +290,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
)
)
# common code for prefill
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
output
[:
num_prefill_tokens
]
=
out
else
:
else
:
...
...
vllm/attention/ops/triton_flash_attention.py
View file @
18d23f64
...
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
...
@@ -293,7 +293,7 @@ def _attn_fwd_inner(
num_warps
=
4
,
num_warps
=
4
,
),
),
],
],
key
=
[
"hq"
,
"hk"
,
"
IS_CAUSAL
"
,
"
dropout_p
"
,
"
BLOCK_DMODEL
"
],
key
=
[
'
IS_CAUSAL
'
,
'
dropout_p
'
,
'
BLOCK_DMODEL
'
],
)
)
@
triton
.
jit
@
triton
.
jit
def
attn_fwd
(
def
attn_fwd
(
...
@@ -330,8 +330,8 @@ def attn_fwd(
...
@@ -330,8 +330,8 @@ def attn_fwd(
philox_seed
,
philox_seed
,
philox_offset_base
,
philox_offset_base
,
encoded_softmax
,
encoded_softmax
,
hq
,
HQ
:
tl
.
constexpr
,
hk
,
HK
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
ACTUAL_BLOCK_DMODEL
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_Q
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
MAX_SEQLENS_K
:
tl
.
constexpr
,
...
@@ -403,7 +403,7 @@ def attn_fwd(
...
@@ -403,7 +403,7 @@ def attn_fwd(
# We still need to write 0s to the result
# We still need to write 0s to the result
# tl.store(O_block_ptr,
# tl.store(O_block_ptr,
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# acc.to(Out.type.element_ty), boundary_check=(0,1))
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q
# + offs_m
# + offs_m
# We store inf to LSE, not -inf because in the bwd pass,
# We store inf to LSE, not -inf because in the bwd pass,
# we subtract this
# we subtract this
...
@@ -414,11 +414,9 @@ def attn_fwd(
...
@@ -414,11 +414,9 @@ def attn_fwd(
# TODO: Should dropout and return encoded softmax be handled here?
# TODO: Should dropout and return encoded softmax be handled here?
return
return
is_mqa
=
hq
!=
hk
# If MQA / GQA, set the K and V head offsets appropriately.
if
is_mqa
:
# noqa: SIM108
GROUP_SIZE
:
tl
.
constexpr
=
HQ
//
HK
off_h_k
=
off_h_q
%
hk
off_h_k
=
off_h_q
//
GROUP_SIZE
if
GROUP_SIZE
!=
1
else
off_h_q
else
:
off_h_k
=
off_h_q
n_extra_tokens
=
0
n_extra_tokens
=
0
if
seqlen_k
<
BLOCK_N
:
if
seqlen_k
<
BLOCK_N
:
...
@@ -471,7 +469,7 @@ def attn_fwd(
...
@@ -471,7 +469,7 @@ def attn_fwd(
bias_ptr
=
None
bias_ptr
=
None
if
ENABLE_DROPOUT
:
if
ENABLE_DROPOUT
:
batch_philox_offset
=
philox_offset_base
\
batch_philox_offset
=
philox_offset_base
\
+
(
off_z
*
hq
+
off_h_q
)
\
+
(
off_z
*
HQ
+
off_h_q
)
\
*
seqlen_q
*
seqlen_k
*
seqlen_q
*
seqlen_k
else
:
else
:
batch_philox_offset
=
0
batch_philox_offset
=
0
...
@@ -624,7 +622,7 @@ def attn_fwd(
...
@@ -624,7 +622,7 @@ def attn_fwd(
z
=
0.0
z
=
0.0
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
acc
=
tl
.
where
(
out_ptrs_mask
,
acc
,
z
.
to
(
acc
.
type
.
element_ty
))
# write back LSE
# write back LSE
# l_ptrs = L + off_z *
hq
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# l_ptrs = L + off_z *
HQ
* MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last
# few rows. This is only true for the last M block. For others,
# few rows. This is only true for the last M block. For others,
# overflow_size will be -ve
# overflow_size will be -ve
...
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
...
@@ -784,8 +782,8 @@ class _attention(torch.autograd.Function):
philox_seed
=
philox_seed
,
philox_seed
=
philox_seed
,
philox_offset_base
=
philox_offset
,
philox_offset_base
=
philox_offset
,
encoded_softmax
=
encoded_softmax
,
encoded_softmax
=
encoded_softmax
,
hq
=
nheads_q
,
HQ
=
nheads_q
,
hk
=
nheads_k
,
HK
=
nheads_k
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
ACTUAL_BLOCK_DMODEL
=
head_size
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_Q
=
max_seqlens_q
,
MAX_SEQLENS_K
=
max_seqlens_k
,
MAX_SEQLENS_K
=
max_seqlens_k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment