# FlashAttention adoption

We've been very happy to see FlashAttention adopted by many organizations
and research labs to speed up their training and inference (within six months
of FlashAttention's release, at the time of writing).
This page contains a partial list of places where FlashAttention is being used.
If you'd like to add links to your organization / product / codebase, please open a
PR or email us. We'd very much like to hear from you!
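
All of the projects below compute standard exact attention; FlashAttention
speeds it up without changing the result. As a point of reference, here is a
minimal (unoptimized) PyTorch sketch of the computation being accelerated:

```python
import math
import torch

def reference_attention(q, k, v):
    """Exact attention: softmax(q @ k^T / sqrt(d)) @ v.

    q, k, v: (batch, heads, seqlen, head_dim). FlashAttention computes
    exactly this, but without ever materializing the (seqlen x seqlen)
    attention matrix in GPU memory.
    """
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v
```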

## Integrated into machine learning frameworks

- PyTorch: [integrated](https://github.com/pytorch/pytorch/pull/81434) into core PyTorch in `nn.Transformer` (a usage sketch follows this list).

- Hugging Face's [transformers](https://github.com/huggingface/transformers) library.
  Integration is [on-going](https://github.com/huggingface/transformers/pull/18439), with a blogpost
  coming soon.

- Microsoft's [DeepSpeed](https://github.com/microsoft/DeepSpeed):
  FlashAttention is [integrated](https://github.com/microsoft/DeepSpeed/blob/ec13da6ba7cabc44bb4745a64a208b8580792954/deepspeed/ops/transformer/inference/triton_ops.py) into DeepSpeed's inference engine.

- Nvidia's [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/pull/267). This
  library is a popular framework for training large transformer language models at scale.

- MosaicML's [Composer](https://github.com/mosaicml/composer)
  [library](https://www.mosaicml.com/blog/gpt-3-quality-for-500k), a library
  for efficient neural network training.
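
The PyTorch integration requires no model changes: standard `nn.Transformer`
modules can hit a fused fastpath that dispatches to FlashAttention-style
kernels. A minimal sketch; whether the fused kernel is actually used depends
on your PyTorch version, GPU, dtype, and masking:

```python
import torch

# Standard nn.Transformer usage; on recent PyTorch builds the encoder's
# inference fastpath can dispatch to fused attention kernels.
layer = torch.nn.TransformerEncoderLayer(
    d_model=512, nhead=8, batch_first=True
).cuda().half().eval()

x = torch.randn(4, 1024, 512, device="cuda", dtype=torch.float16)
with torch.inference_mode():
    out = layer(x)  # (4, 1024, 512)
```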

## MLPerf benchmarks

[MLPerf](https://mlcommons.org/en/) is a competitive machine learning performance benchmark. FlashAttention
yields the fastest BERT training on cloud instances in MLPerf Training 2.0 (June
2022) and MLPerf Training 2.1 (November 2022).

- MLPerf 2.0: IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention.

- MLPerf 2.1 - collaboration between
  [Azure and Hazy Research](https://techcommunity.microsoft.com/t5/azure-high-performance-computing/azure-collaborates-with-hazy-research-and-nvidia-to-achieve/ba-p/3667511): for the first time, we can train MLPerf BERT
  in under 2 minutes on 16 nodes.

- MLPerf 2.1 -
  [Nvidia](https://developer.nvidia.com/blog/leading-mlperf-training-2-1-with-full-stack-optimizations-for-ai/):
  Nvidia uses techniques from FlashAttention to make their (already extremely optimized) BERT
  implementation go even faster.

- MLPerf 2.1 - [MosaicML](https://www.mosaicml.com/blog/mlperf-nlp-nov2022): FlashAttention
  helps train BERT 2.7x faster in the open division.

## Language model training & inference

- [PubMedGPT 2.7B](https://crfm.stanford.edu/2022/12/15/pubmedgpt.html), a
  domain-specific LLM for biomedicine, by Stanford CRFM, trained on
  [MosaicML](https://www.mosaicml.com/blog/introducing-pubmed-gpt) Cloud. Just
  using FlashAttention nearly halves the total training time.

- Meta's
  [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/)
  uses FlashAttention as part of their approach to speed up Transformer
  inference (up to 5.3x on BERT).

- Nvidia's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer) is a
  state-of-the-art Transformer inference library. As of version
  [5.2](https://github.com/NVIDIA/FasterTransformer/commit/b672f49e256ba7a2d4fc9691d270b60b7fc1a2ff),
  FlashAttention is used as a component of FasterTransformer to speed up GPT inference.

- [Kernl](https://github.com/ELS-RD/kernl) is a library for fast Transformer
  inference. They use FlashAttention as part of their
  [approach](https://twitter.com/pommedeterre33/status/1585284221014245377) to
  speed up Transformers by up to 12x.

## Diffusion model training and inference

- Hugging Face's [diffusers](https://github.com/huggingface/diffusers) library
  for diffusion models. FlashAttention is integrated into [diffusers
  v0.7.0](https://github.com/huggingface/diffusers/releases/tag/v0.7.0),
  giving up to 2x faster inference and lower memory usage (a usage sketch
  follows this list).

- Colossal-AI's
  [implementation](https://github.com/hpcaitech/ColossalAI/tree/main/examples/images/diffusion)
  of Stable Diffusion: with FlashAttention as one of its components, it speeds up
  pretraining by up to 6.5x, and reduces the hardware cost of fine-tuning by 7x.

- Meta's
  [AITemplate](https://ai.facebook.com/blog/gpu-inference-engine-nvidia-amd-open-source/),
  with FlashAttention as one of its components, is currently the [fastest](https://twitter.com/bing_xu_/status/1590447334055632897) Stable
  Diffusion inference engine that we know of.

- Stable Diffusion inference from
  [Labml.ai](https://twitter.com/labmlai/status/1573634095732490240): 50% speedup.

- Our own Stable Diffusion [fork](https://twitter.com/realDanFu/status/1580641495991754752) uses FlashAttention to get a 3-4x speedup compared
  to the original version.
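
As a usage illustration for the diffusers integration above, here is a minimal
sketch. The `enable_xformers_memory_efficient_attention()` call is based on
recent `diffusers` releases and requires `xformers` to be installed; the model
id is just an example:

```python
import torch
from diffusers import StableDiffusionPipeline

# Example model id; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Route attention through xformers' memory-efficient / FlashAttention
# kernels (helper available in recent diffusers releases).
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse").images[0]
```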

## Other models

- [Uni-Fold](https://github.com/dptech-corp/Uni-Fold): Uni-Fold is an
  open-source platform for developing protein models beyond AlphaFold. With
  FlashAttention, Uni-Fold is 2.6x
  [faster](https://twitter.com/guolin_ke/status/1580532071901995008) than AlphaFold.

- [OpenFold](https://github.com/aqlaboratory/openfold): a trainable,
  memory-efficient, and GPU-friendly PyTorch reproduction of AlphaFold 2. With
  FlashAttention as one of its
  [components](https://twitter.com/gahdritz/status/1595420944880779266), it is
  up to 3x faster than AlphaFold2 when running inference on short sequences, and can
  predict structures 2x longer.

## Different implementations

- [Triton](https://github.com/openai/triton): an [implementation](https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py) of
  FlashAttention in Triton by Phil Tillet from OpenAI. Triton is a Python-based
  language and compiler for parallel programming.

- [xformers](https://github.com/facebookresearch/xformers): The xformers team
  has implemented [memory-efficient
  attention](https://twitter.com/fvsmassa/status/1580229170629849089) in a
  similar spirit to FlashAttention.
  xformers dynamically dispatches to whichever implementation is available / faster
  (a usage sketch follows this list).

- [JAX](https://github.com/google/jax): an [implementation](https://github.com/lucidrains/flash-attention-jax)
  in JAX by [lucidrains](https://github.com/lucidrains/).
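
As a usage illustration for the xformers implementation above, a minimal
sketch of `xformers.ops.memory_efficient_attention` (interface as of recent
xformers releases; tensors are batch, sequence, heads, head dimension):

```python
import torch
import xformers.ops as xops

# Random q, k, v in half precision on GPU, shape (batch, seq, heads, dim).
q, k, v = (
    torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
    for _ in range(3)
)

# Dispatches to whichever backend (FlashAttention, CUTLASS, ...) is
# available and fastest for these inputs.
out = xops.memory_efficient_attention(q, k, v)
```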