Commit 840f7925 authored by Tri Dao

[Docs] Fix mention of MQA/GQA in qkvpacked functions

parent 60499abc
@@ -285,10 +285,8 @@ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
- Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
- than Q. Note that the number of heads in KV must be divisible by the number of heads in Q.
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+ For multi-query and grouped-query attention (MQA/GQA), please see
+ flash_attn_kvpacked_func and flash_attn_func.
Arguments:
qkv: (batch_size, seqlen, 3, nheads, headdim)
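A minimal usage sketch of the distinction documented above: the qkvpacked interface shares a single nheads across Q, K, V, so MQA/GQA (fewer KV heads) goes through flash_attn_func or flash_attn_kvpacked_func instead. This assumes flash-attn 2.x on a CUDA GPU with fp16 tensors; the concrete sizes (6 query heads, 2 KV heads) are made up for illustration.

```python
import torch
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func

batch_size, seqlen, headdim = 2, 128, 64

# Packed self-attention: Q, K, V share the same nheads by construction,
# so fewer KV heads (MQA/GQA) cannot be expressed through this interface.
qkv = torch.randn(batch_size, seqlen, 3, 6, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (batch_size, seqlen, 6, headdim)

# MQA/GQA instead uses the unpacked call: K and V may have fewer heads than Q.
# With 6 query heads and 2 KV heads, Q heads 0-2 attend to KV head 0 and
# Q heads 3-5 attend to KV head 1.
q = torch.randn(batch_size, seqlen, 6, headdim, device="cuda", dtype=torch.float16)
k = torch.randn(batch_size, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
v = torch.randn(batch_size, seqlen, 2, headdim, device="cuda", dtype=torch.float16)
out_gqa = flash_attn_func(q, k, v, causal=True)  # (batch_size, seqlen, 6, headdim)
```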
@@ -381,8 +379,8 @@ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
- 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
+ For multi-query and grouped-query attention (MQA/GQA), please see
+ flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
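A companion sketch for the varlen variant, under the same assumptions and with made-up sequence lengths: the sequences of a batch are concatenated along a single token dimension and delimited by cu_seqlens.

```python
import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

nheads, headdim = 6, 64
seqlens = [100, 28]            # two sequences of different lengths in one batch
total = sum(seqlens)           # all tokens concatenated: total = 128
cu_seqlens = torch.tensor([0, 100, 128], device="cuda", dtype=torch.int32)
max_seqlen = max(seqlens)

qkv = torch.randn(total, 3, nheads, headdim, device="cuda", dtype=torch.float16)
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, causal=True)
# out: (total, nheads, headdim). As in the fixed-length case, MQA/GQA requires
# the unpacked flash_attn_varlen_func or flash_attn_varlen_kvpacked_func.
```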