"vscode:/vscode.git/clone" did not exist on "b0e1854485b1a0fe455fdedc4a72bb79980f4fbe"
Commit 50d144c9 authored by Tri Dao

Mention Alibi in README

parent 8448c028
@@ -82,7 +82,8 @@ from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
 ```
 
 ```python
-flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
+                          window_size=(-1, -1), alibi_slopes=None):
 """dropout_p should be set to 0.0 during evaluation
 If Q, K, V are already stacked into 1 tensor, this function will be faster than
 calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -96,13 +97,16 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to
+        the attention score of query i and key j.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
 ```
 
 ```python
-flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1)):
+flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
+                window_size=(-1, -1), alibi_slopes=None):
 """dropout_p should be set to 0.0 during evaluation
 Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
 than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
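
For reference, a minimal usage sketch of flash_attn_qkvpacked_func with ALiBi (illustration only, not part of the diffed README; the shapes and the geometric slope schedule from Press et al., 2021, are assumptions):

```python
import torch
from flash_attn import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 2, 512, 8, 64
# Q, K, V stacked into one tensor; FlashAttention expects fp16/bf16 CUDA tensors.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
# One slope per head, fp32, on the same device: the geometric schedule 2^(-8*(h+1)/nheads).
alibi_slopes = torch.tensor(
    [2 ** (-8 * (h + 1) / nheads) for h in range(nheads)],
    dtype=torch.float32, device="cuda",
)
out = flash_attn_qkvpacked_func(qkv, causal=True, alibi_slopes=alibi_slopes)
# out: (batch, seqlen, nheads, headdim)
```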
@@ -121,6 +125,9 @@ Arguments:
         Default to 1 / sqrt(headdim).
     causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
     window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+    alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+        (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+        is added to the attention score of query i and key j.
 Return:
     out: (batch_size, seqlen, nheads, headdim).
 """
@@ -141,6 +148,7 @@ def flash_attn_with_kvcache(
     causal=False,
     window_size=(-1, -1), # -1 means infinite context window
     rotary_interleaved=True,
+    alibi_slopes=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -183,10 +191,9 @@ def flash_attn_with_kvcache(
             If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
             rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
             (i.e. GPT-NeoX style).
-        num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-            If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-            to automatically determine the number of splits.
-            Don't change this unless you know what you are doing.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
     Return:
         out: (batch_size, seqlen, nheads, headdim).
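
For the KV-cache path, a sketch of a single decoding step with ALiBi (illustrative only, not part of the diff; the cache size, the cache_seqlens value, and the slope schedule are assumptions, while k, v, and cache_seqlens are existing keyword arguments of flash_attn_with_kvcache):

```python
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 4096
device, dtype = "cuda", torch.float16

# Pre-allocated KV cache; assume 128 tokens per sequence are already cached.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=dtype, device=device)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=dtype, device=device)
cache_seqlens = torch.full((batch,), 128, dtype=torch.int32, device=device)

# One new token per sequence (incremental decoding).
q = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
k = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)
v = torch.randn(batch, 1, nheads, headdim, dtype=dtype, device=device)

alibi_slopes = torch.tensor(
    [2 ** (-8 * (h + 1) / nheads) for h in range(nheads)],
    dtype=torch.float32, device=device,
)

# The new k/v are written into the cache in place; the bias uses
# (-alibi_slope * |i + seqlen_k - seqlen_q - j|), so the single query is aligned
# to the end of the cached sequence.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
    causal=True, alibi_slopes=alibi_slopes,
)
# out: (batch, 1, nheads, headdim)
```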
@@ -262,6 +269,10 @@ Implement sliding window attention (i.e., local attention). Thanks to [Mistral
 AI](https://mistral.ai/) and in particular Timothée Lacroix for this
 contribution. Sliding window was used in the [Mistral 7B](https://mistral.ai/news/announcing-mistral-7b/) model.
 
+### 2.4: ALiBi (attention with linear bias)
+Implement ALiBi (Press et al., 2021). Thanks to Sanghun Cho from Kakao Brain for this contribution.
+
 ## Performance
 We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory).
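
The kernels take the slopes as an input rather than computing them, so the caller has to supply them. A small helper in the spirit of Press et al. (2021) can generate the usual geometric schedule for the ALiBi feature added in section 2.4 above (a sketch only, not part of flash-attn):

```python
import math

def get_alibi_slopes(nheads):
    """Standard ALiBi slopes from Press et al. (2021), one per head (hypothetical helper)."""
    def slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # equals 2^(-8/n)
        return [start ** (i + 1) for i in range(n)]
    if math.log2(nheads).is_integer():
        return slopes_power_of_2(nheads)
    # For head counts that are not powers of two, interleave slopes from the
    # next power of two, as in the original ALiBi implementation.
    closest = 2 ** math.floor(math.log2(nheads))
    return (slopes_power_of_2(closest)
            + get_alibi_slopes(2 * closest)[0::2][: nheads - closest])

# Convert to a CUDA fp32 tensor before passing as alibi_slopes, e.g.
# alibi_slopes = torch.tensor(get_alibi_slopes(12), dtype=torch.float32, device="cuda")
print(get_alibi_slopes(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]
```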