# Optimized Transformer implementation
This repo contains examples of how FlashAttention can be integrated into a model
(e.g., GPT, ViT) and trained end-to-end. We also provide optimized
implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
rotary embedding). Overall this speeds up training by 3-5x compared to the
baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
equivalent to 60.6% model FLOPs utilization (we don't need any activation
checkpointing). All without changing the model architecture (i.e., no
approximation).

Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
  (e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
  and the model code illustrates how these components can be put together.
  The training code also aims to be model- & task-agnostic.

Non-goals (and other resources):
- Support as many models as possible: Huggingface's
  [transformers](https://github.com/huggingface/transformers) and
  [timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
  training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
  training techniques (e.g., pipeline parallelism, tensor parallelism),
  check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
  [DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future).
  If you want fast inference, take a look at
  [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas
  on speeding up ML models.

## Model Components

The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
Here's an example of how to construct the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads, 
                    scale_attn_by_inverse_layer_idx=True, 
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_dense_gelu_dense=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True, 
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```
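
As a quick sanity check, a minimal forward pass could look like the sketch below. This is an
illustrative snippet rather than code from the repo: it assumes a CUDA GPU, casts the model to
fp16 (the fused kernels target fp16/bf16), and assumes the output exposes a Huggingface-style
`.logits` field.

```python
import torch

# Hypothetical usage sketch: run the freshly constructed model on random token ids.
model = model.to(device="cuda", dtype=torch.float16)  # fused kernels expect fp16/bf16 on GPU
input_ids = torch.randint(0, config.vocab_size, (2, seqlen), dtype=torch.long, device="cuda")
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # expected: (batch, seqlen, padded vocab size)
```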

We provide the following optimized components:

1. FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```

2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. CuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```
3. Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory (see the usage sketch after this list).
```sh
cd ../csrc/xentropy && pip install .
```
4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```
5. Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This only supports a limited set of dimensions; see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
```sh
cd ../csrc/layer_norm && pip install .
```
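
As an example of using these building blocks as drop-in replacements (the usage sketch referenced
in item 3 above), here is a hedged sketch of the optimized cross-entropy loss. The import path and
the `inplace_backward` flag are assumptions based on the repo layout; verify them against your
installed `flash_attn` version.

```python
import torch
# Assumed import path (flash_attn/losses/cross_entropy.py); verify against your install.
from flash_attn.losses.cross_entropy import CrossEntropyLoss

# Drop-in replacement for torch.nn.CrossEntropyLoss; inplace_backward writes the gradient
# into the logits buffer during backward to save activation memory.
loss_fn = CrossEntropyLoss(inplace_backward=True)

logits = torch.randn(4 * 2048, 50257, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 50257, (4 * 2048,), device="cuda")
loss = loss_fn(logits, labels)  # supports bf16 logits
loss.backward()
```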

## Training

We also provide training scripts here to train GPT2 on Openwebtext and GPT3 on
The Pile as examples. Feel free to use the model in your own training setup as
well.

We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
[Wandb](https://wandb.ai/) for logging.

We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.

### Requirements

Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).

We provide a Dockerfile that lists all the required packages.

### Dataset preparation

Running the training command will automatically download the datasets
(Openwebtext, The Pile), tokenize them with the GPT2 tokenizer, concatenate all the
tokens, and then save this cache to disk. Alternatively, you can prepare the
datasets as a separate step.

The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
`./data/{openwebtext,the_pile}`. 

- Openwebtext:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
```
This takes around 1h on a 64-core CPU. The processed dataset is 17GB.

- The Pile:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 64-core CPU. The processed dataset is 699GB.

### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.

To train with bf16 instead of fp16, add `trainer.precision=bf16`.

### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.

To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.

### Training options

**Gradient accumulation**: to adjust device batch size to fit into GPU memory
(the global batch size stays the same, and gradient accumulation is calculated
automatically), set `datamodule.batch_size=blah`.

**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.

**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.

**Resumable training**: set a name for the run, then set `resume=True` when
you resume. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```

## Training speed

We measure the wallclock training speed on one node with 8 x A100 SXM4 80GB (400W) with NVLink.

FLOPs are calculated using the formula from the [Megatron-LM
paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
to get the model FLOPs (instead of hardware FLOPs with activation
checkpointing).
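
To make this concrete, here is a small sketch (not from the repo) of that calculation, plugging in
the GPT3-1.3B numbers from the table below; 312 TFLOPs/sec is the A100 peak for fp16/bf16 tensor
cores.

```python
def model_flops_per_token(n_layer, hidden_dim, seqlen, vocab_size):
    # Megatron-LM Sec. 5.1: hardware FLOPs per token = 96 * l * h^2 * (1 + s/(6h) + V/(16*l*h));
    # scaling by 3/4 removes the extra forward pass that activation checkpointing would add.
    hw_flops = 96 * n_layer * hidden_dim**2 * (
        1 + seqlen / (6 * hidden_dim) + vocab_size / (16 * n_layer * hidden_dim)
    )
    return 0.75 * hw_flops

# GPT3-1.3B: 24 layers, hidden dim 2048, seqlen 2048, vocab 50257, ~169k tokens/sec on 8 GPUs.
flops_per_token = model_flops_per_token(n_layer=24, hidden_dim=2048, seqlen=2048, vocab_size=50257)
per_gpu = flops_per_token * 169e3 / 8      # achieved model FLOPs/sec per GPU
print(f"{per_gpu / 1e12:.0f} TFLOPs/sec per GPU, {per_gpu / 312e12:.1%} MFU")  # ~188, ~60%
```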


### GPT2 (sequence length 1024)

![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-4x faster than the
baseline implementation from Huggingface.

### GPT3 (sequence length 2048)

![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)

The implementation in this repo (FlashAttention) is 3-5x faster than the
baseline implementation from Huggingface.

For the GPT3-2.7B model, we set the head dimension to 128 (instead of 80) for better efficiency.

We include here more details on the training speed with FlashAttention on 8 x
A100 80GB.

| Model     | Batch size (tokens) | Throughput (tokens/sec)  | Hours / 1B tokens |
| --------- | ------------------- | ------------------------ | ----------------- |
| GPT3-125M | 0.5M                | 1310k                    |              0.21 |
| GPT3-355M | 0.5M                | 503k                     |              0.55 |
| GPT3-760M | 0.5M                | 245k                     |              1.13 |
| GPT3-1.3B | 1M                  | 169k                     |              1.64 |
| GPT3-2.7B | 1M                  | 85k                      |              3.27 |

As an example, this means that one can train a GPT3-1.3B model on 26B tokens
(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.
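
The arithmetic behind that estimate, as a quick back-of-the-envelope check (not code from the repo):

```python
tokens_per_sec = 169e3                                    # GPT3-1.3B row in the table above
hours_per_billion = 1e9 / tokens_per_sec / 3600
print(f"{hours_per_billion:.2f} h / 1B tokens")           # ~1.64, matching the table
print(f"{26 * hours_per_billion:.0f} h for 26B tokens")   # ~43 hours on 8 x A100
```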

## Training quality

We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
For GPT2, the runs with FlashAttention yield the same loss curve as the runs
with the baseline implementation from Huggingface for the 125M and 355M models. For
larger models, the baseline implementation simply takes too long to run.

![GPT2 training curve](../assets/gpt2_training_curve.jpg)

We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
The 125M, 355M, 760M models have batch size 512k tokens so this translates to
800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
which translates to 400k training steps.

![GPT3 training curve](../assets/gpt3_training_curve.jpg)