Commit 67c37795 authored by Tri Dao

Reorganize directories, add banner figure

parent 7025a092
-## FlashAttention - Alpha release (0.1).
+# FlashAttention
+This repository provides the official implementation of FlashAttention from the
+following paper.
+**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness**
+Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré
+![FlashAttention](assets/flashattn_banner.pdf)
+## Alpha release (0.1).
To compile (requiring NVCC and an A100 GPU):
```
...
@@ -40,14 +48,14 @@ Our graphs show sequence lengths between 128 and 4096 (when standard attention r
#### Speedup
-![FlashAttention speedup](images/flashattn_speedup.jpg)
+![FlashAttention speedup](assets/flashattn_speedup.jpg)
We generally see 2-4X speedup at sequence lengths between 128 and 4K, and we see more speedup when using dropout and masking, since we fuse the kernels.
At sequence lengths that are popular with language models like 512 and 1K, we see speedups up to 4X when using dropout and masking.
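To make the fusion claim concrete, here is a minimal sketch of the unfused PyTorch baseline, where masking, softmax, dropout, and the value matmul each launch separate kernels. This is illustrative only; the repository's own reference implementation (`attention_ref` in the benchmarks) may differ in details such as the qkv packing and mask conventions.

```python
import math
import torch
import torch.nn.functional as F

def unfused_attention(q, k, v, key_padding_mask=None, dropout_p=0.1, causal=False):
    """Standard attention: every step below is its own kernel(s), and the
    (batch, heads, seqlen, seqlen) scores tensor is fully materialized."""
    batch, heads, seqlen, d = q.shape
    scores = torch.einsum('bhid,bhjd->bhij', q, k) / math.sqrt(d)   # matmul kernel
    if key_padding_mask is not None:                                # padding-mask kernel
        scores = scores.masked_fill(~key_padding_mask[:, None, None, :], float('-inf'))
    if causal:                                                      # causal-mask kernel
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=q.device), 1)
        scores = scores.masked_fill(causal_mask, float('-inf'))
    attn = torch.softmax(scores, dim=-1)                            # softmax kernel
    attn = F.dropout(attn, p=dropout_p)                             # dropout kernel
    return torch.einsum('bhij,bhjd->bhid', attn, v)                 # matmul kernel

# Illustrative shapes: batch 2, 16 heads, seqlen 1024, head dim 64.
q, k, v = (torch.randn(2, 16, 1024, 64) for _ in range(3))
out = unfused_attention(q, k, v, causal=True)
```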
#### Memory
-![FlashAttention memory](images/flashattn_memory.jpg)
+![FlashAttention memory](assets/flashattn_memory.jpg)
We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking).
Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length.
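A rough, illustrative back-of-the-envelope check (the constants below are assumptions, not the benchmark configuration): standard attention materializes the seqlen x seqlen attention matrix per head, while FlashAttention only keeps O(seqlen) per-row softmax statistics and recomputes the matrix in the backward pass.

```python
# Approximate memory of the attention matrix alone in fp16 (2 bytes per element),
# versus the per-row softmax statistics FlashAttention keeps (assumed fp32 here).
batch, heads, bytes_fp16, bytes_fp32 = 16, 16, 2, 4
for seqlen in [128, 512, 1024, 4096]:
    standard = batch * heads * seqlen * seqlen * bytes_fp16   # quadratic in seqlen
    flash = batch * heads * seqlen * 2 * bytes_fp32           # linear in seqlen
    print(f"seqlen={seqlen:5d}  standard ~{standard / 2**20:8.1f} MiB  "
          f"flash softmax stats ~{flash / 2**20:6.3f} MiB")
```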
......
...
@@ -7,8 +7,8 @@ import torch.nn.functional as F
from einops import rearrange, repeat
from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from bert_padding import unpad_input, pad_input
-from flash_attn_interface import flash_attn_func
+from src.bert_padding import unpad_input, pad_input
+from src.flash_attn_interface import flash_attn_func
def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
......
...
@@ -99,27 +99,7 @@ def pytorch_profiler(fn, *inputs, repeats=10):
    ) as p:
        # benchmark_forward(repeats, fn, *inputs)
        fn(*inputs)
-    print(p.key_averages().table(
-        sort_by="self_cuda_time_total", row_limit=-1))
+    print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
-def convert_data(*tensors, device='cuda'):
-    tensors = tuple(t.to(device) for t in tensors)
-    for t in tensors:
-        if t.is_leaf: t.requires_grad = True
-        t.retain_grad()
-    return tensors
-def log_backward(output, *inputs):
-    """ Perform backward pass of output and print gradients of input tensors. """
-    #print(f"{output=}")
-    output.sum().backward(retain_graph=True)
-    print("Gradients:")
-    for t in inputs:
-        print(t.grad)
-        t.grad.zero_()
def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
......
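For context, `pytorch_profiler` above follows the standard `torch.profiler` pattern. A self-contained sketch of that pattern, with a stand-in matmul workload in place of the attention functions the benchmarks actually pass in, looks roughly like this:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def fn(x, w):
    # Stand-in workload; the benchmark scripts pass the attention implementations here.
    return x @ w

x = torch.randn(1024, 1024)
w = torch.randn(1024, 1024)

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    x, w = x.cuda(), w.cuda()

with profile(activities=activities) as p:
    fn(x, w)

sort_key = "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
print(p.key_averages().table(sort_by=sort_key, row_limit=-1))
```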
...
@@ -4,9 +4,9 @@ import torch.nn as nn
from einops import rearrange
-from rotary import RotaryEmbedding, RotaryEmbedding2D
-from flash_attn_interface import flash_attn_func
-from bert_padding import unpad_input, pad_input, index_first_axis
+from src.rotary import RotaryEmbedding, RotaryEmbedding2D
+from src.flash_attn_interface import flash_attn_func
+from src.bert_padding import unpad_input, pad_input, index_first_axis
class FlashAttention(nn.Module):
......
...
@@ -6,9 +6,10 @@ from einops import rearrange
import hydra
-from flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from flash_blocksparse_attn_interface import convert_blockmask
-from bert_padding import unpad_input, pad_input, index_first_axis
+from src.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from src.flash_blocksparse_attn_interface import convert_blockmask
+from src.bert_padding import unpad_input, pad_input, index_first_axis
class FlashBlocksparseAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
......
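For reference, the scaled dot-product attention that the docstring refers to is the standard formulation below; the block-sparse variant restricts which key/value blocks each query block attends to, according to the block mask.

```math
\mathrm{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
```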