gaoqiong / flash-attention / Commits

Commit 840f7925
Authored Jul 28, 2023 by Tri Dao
Parent: 60499abc

[Docs] Fix mention of MQA/GQA in qkvpacked functions

Showing 1 changed file with 4 additions and 6 deletions (+4, -6):
flash_attn/flash_attn_interface.py
@@ -285,10 +285,8 @@ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=Fal
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of Q, K, V.
-    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-    than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
-    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+    For multi-query and grouped-query attention (MQA/GQA), please see
+    flash_attn_kvpacked_func and flash_attn_func.

     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, headdim)
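The lines removed above described an MQA/GQA head mapping that the qkvpacked entry point cannot actually provide, since packing Q, K, V into one tensor forces them to share nheads; the new text redirects MQA/GQA users to flash_attn_kvpacked_func or flash_attn_func. A minimal usage sketch of that distinction (not part of the commit; it assumes the flash_attn 2.x top-level imports, a CUDA device, and fp16 inputs, with shapes taken from the docstrings above):

import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func

batch_size, seqlen, headdim = 2, 1024, 64
nheads, nheads_k = 6, 2  # GQA: 6 query heads share 2 KV heads

# qkvpacked: Q, K, V are stacked into one tensor, so all three must have `nheads` heads.
qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=False)

# MQA/GQA: K and V have fewer heads than Q, so only K and V can be packed together;
# use flash_attn_kvpacked_func (or flash_attn_func with separate q, k, v) instead.
q = torch.randn(batch_size, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(batch_size, seqlen, 2, nheads_k, headdim, dtype=torch.float16, device="cuda")
out_gqa = flash_attn_kvpacked_func(q, kv, dropout_p=0.0, causal=False)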
@@ -381,8 +379,8 @@ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
     of the gradients of Q, K, V.
-    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-    0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+    For multi-query and grouped-query attention (MQA/GQA), please see
+    flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.

     Arguments:
         qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
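The second hunk makes the same redirection for the variable-length interface: MQA/GQA goes through flash_attn_varlen_kvpacked_func or flash_attn_varlen_func rather than the packed-qkv function. A sketch under the same assumptions (flash_attn 2.x API, CUDA, fp16), with made-up per-sequence lengths and the cumulative-length layout described in the docstring:

import torch
from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func

nheads, nheads_k, headdim = 6, 2, 64
seqlens = [512, 1024, 384]                     # hypothetical unpadded sequence lengths
total = sum(seqlens)                           # total number of tokens in the batch
cu_seqlens = torch.tensor([0, 512, 1536, 1920], dtype=torch.int32, device="cuda")  # cumulative
max_seqlen = max(seqlens)

# Packed QKV over all tokens in the batch: (total, 3, nheads, headdim).
qkv = torch.randn(total, 3, nheads, headdim, dtype=torch.float16, device="cuda")
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=False)

# MQA/GQA with variable-length sequences: pack only K and V (with nheads_k heads) and pass
# separate Q/KV cumulative lengths (identical here, since Q and KV cover the same tokens).
q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total, 2, nheads_k, headdim, dtype=torch.float16, device="cuda")
out_gqa = flash_attn_varlen_kvpacked_func(
    q, kv, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0, causal=False
)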