Commit 4a6eaa9f authored by Tri Dao

Update configs, add results

parent 0bf5e500
@@ -14,6 +14,20 @@ We've been very happy to see FlashAttention being widely adopted in such a short
time after its release. This [page](https://github.com/HazyResearch/flash-attention/blob/main/usage.md)
contains a partial list of places where FlashAttention is being used.
## Full model code and training script
We have released the full GPT model
[implementation](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
We also provide optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x
compared to the baseline implementation from Huggingface, reaching up to 189
TFLOPs/sec per A100, equivalent to 60.6% model FLOPs utilization (we don't need
any activation checkpointing).
We also include a training
[script](https://github.com/HazyResearch/flash-attention/tree/main/training) to
train GPT2 on Openwebtext and GPT3 on The Pile.
## Triton implementation of FlashAttention
Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton:
...
@@ -65,14 +65,6 @@ ENV PIP_NO_CACHE_DIR=1
# # apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10
# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100
# RUN pip install pytorch-fast-transformers==0.4.0
# RUN pip install git+git://github.com/idiap/fast-transformers.git@v0.4.0 # doesn't work on V100
RUN git clone https://github.com/idiap/fast-transformers \
&& sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \
&& pip install fast-transformers/ \
&& rm -rf fast-transformers
# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5
...
# Optimized Transformer implementation
This repo contains examples of how FlashAttention can be integrated into a model
(e.g., GPT, ViT) and trained end-to-end. We also provide optimized
implementations of other layers (e.g., MLP, LayerNorm, cross-entropy loss,
rotary embedding). Overall this speeds up training by 3-5x compared to the
baseline implementation from Huggingface, reaching up to 189 TFLOPs/sec per A100,
equivalent to 60.6% model FLOPs utilization (we don't need any activation
checkpointing). All without changing the model architecture (i.e., no
approximation).
Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
@@ -29,17 +34,36 @@ Non-goals (and other resources):
The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
And here's an example to construct the GPT3-1.3B model with rotary embedding:
```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel
seqlen = 2048
hidden_dim = 2048
nheads = 16
n_layer = 24
rotary_emb_fraction = 0.5
config = GPT2Config(vocab_size=50257, n_positions=seqlen, n_embd=hidden_dim,
                    n_layer=n_layer, n_head=nheads,
                    scale_attn_by_inverse_layer_idx=True,
                    rotary_emb_fraction=rotary_emb_fraction,
                    use_flash_attn=True, fused_dense_gelu_dense=True,
                    fused_bias_fc=True, fused_dropout_add_ln=True,
                    pad_vocab_size_multiple=8)
model = GPTLMHeadModel(config)
```
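As a quick sanity check you can run a forward and backward pass on random tokens. The snippet below is a minimal sketch, not from the repo: it assumes the model exposes a Huggingface-style output with a `.logits` field (see `models/gpt.py` for the exact interface), and it moves the model to GPU in bf16, since FlashAttention and the fused kernels only run on CUDA with fp16/bf16 tensors and require the CUDA extensions listed below.
```python
import torch
import torch.nn.functional as F

# Hypothetical smoke test; see the assumptions noted above.
device = "cuda"
model = model.to(device=device, dtype=torch.bfloat16)
input_ids = torch.randint(0, config.vocab_size, (2, 512), dtype=torch.long, device=device)
logits = model(input_ids).logits  # assumed Huggingface-style output with a .logits field
# Next-token prediction loss: shift logits left and labels right.
loss = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)).float(),
                       input_ids[:, 1:].reshape(-1))
loss.backward()
print(loss.item())
```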
We provide the following optimized components:
1. FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```
2. Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. CuBLAS versions before
@@ -47,16 +71,16 @@ this doesn't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```
3. Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```
4. Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```
5. Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
```sh
cd ../csrc/layer_norm && pip install .
```
## Training
We also provide here training scripts to train GPT2 on Openwebtext and GPT3 on
The Pile as examples. Feel free to use the model in your own training setup as
well.
We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
@@ -75,12 +100,20 @@ We use [Hydra](https://hydra.cc/) for configuration,
We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.
### Requirements
Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
We provide a Dockerfile that lists all the required packages.
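For example, starting from the NGC Pytorch container mentioned above (which already ships Pytorch, torchvision, and CUDA), one way to install the remaining packages is the sketch below; the Dockerfile remains the authoritative recipe.
```sh
pip install einops timm hydra-core hydra-colorlog python-dotenv rich pytorch-lightning triton flash-attn
```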
### Dataset preparation
Running the training command would automatically download the datasets
(Openwebtext, Pile), tokenize with the GPT2 tokenizer, concatenate all the
tokens, then save this cache to disk. Alternatively, you can also prepare the
datasets as a separate step.
The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
@@ -98,36 +131,101 @@ This takes around 1h on a 64-core CPU. The processed dataset has size 17GB.
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 64-core CPU. The processed dataset has size 699GB.
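For example, to put the tokenized caches on a specific disk before running the preparation step above (the path below is only a placeholder):
```sh
export DATA_DIR=/path/to/large/disk  # caches go to ${DATA_DIR}/openwebtext and ${DATA_DIR}/the_pile
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```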
### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8  # 125M
python run.py experiment=owt/gpt2m-flash trainer.devices=8  # 355M
python run.py experiment=owt/gpt2l-flash trainer.devices=8  # 760M
python run.py experiment=owt/gpt2xl-flash trainer.devices=8  # 1.6B
```
The default parameters are set for 8 x A100 80GB.
To train with bf16 instead of fp16, add `trainer.precision=bf16`.
To adjust device batch size to fit GPU memory (the global batch size stays the
same, and gradient accumulation is calculated automatically), set `datamodule.batch_size=blah`.
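For example, to train the 355M model in bf16 with a per-device batch size of 16 (a combination of the overrides above; pick a batch size that fits your GPUs):
```sh
python run.py experiment=owt/gpt2m-flash trainer.devices=8 trainer.precision=bf16 datamodule.batch_size=16
```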
### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8  # 125M
python run.py experiment=pile/gpt3m-flash trainer.devices=8  # 355M
python run.py experiment=pile/gpt3l-flash trainer.devices=8  # 760M
python run.py experiment=pile/gpt3xl-flash trainer.devices=8  # 1.3B
python run.py experiment=pile/gpt3-2.7B-flash-hdim128 trainer.devices=8  # 2.7B
```
The default parameters are set for 8 x A100 80GB. We train with bf16 by default.
To train with rotary embedding, run the experiments `pile/gpt3{s,m,l,xl}-flash-rotary`.
### Training options
**Gradient accumulation**: to adjust device batch size to fit into GPU memory
(the global batch size stays the same, and gradient accumulation is calculated
automatically), set `datamodule.batch_size=blah`.
**Multi-node**: to train on multiple nodes, add `trainer.num_nodes=blah`.
**Speed benchmarking**: to print out iteration time, add `+callbacks.speed_monitor.verbose=True`.
**Resumable training**: give the run a name, and then set `resume=True` when
you resume. Training will restart at exactly the same batch.
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8 name=pile-gpt3s-flash resume=True
```
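Putting several of these options together, a hypothetical two-node run with speed logging and a name set for later resumption might look like:
```sh
python run.py experiment=pile/gpt3l-flash trainer.devices=8 trainer.num_nodes=2 \
    +callbacks.speed_monitor.verbose=True name=pile-gpt3l-flash
```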
## Training speed
We measure the wall-clock training speed on one node with 8 x A100 SXM4 80GB GPUs (400W) with NVLink.
FLOPs are calculated using the formula from the [Megatron-LM
paper](https://arxiv.org/abs/2104.04473) (Section 5.1), except we scale by 3/4
to get the model FLOPs (instead of hardware FLOPs with activation
checkpointing).
### GPT2 (sequence length 1024)
![GPT2 speedup](../assets/gpt2_training_efficiency.jpg)
The implementation in this repo (FlashAttention) is 3-4x faster than the
baseline implementation from Huggingface.
### GPT3 (sequence length 2048)
![GPT3 speedup](../assets/gpt3_training_efficiency.jpg)
The implementation in this repo (FlashAttention) is 3-5x faster than the
baseline implementation from Huggingface.
For the GPT3-2.7B model, we set head dimension to 128 (instead of 80) for better efficiency.
We include here more details on the training speed with FlashAttention on 8 x
A100 80GB.
| Model | Batch size (tokens) | Throughput (tokens/sec) | Hours / 1B tokens |
| --------- | ------------------- | ------------------------ | ----------------- |
| GPT3-125M | 0.5M | 1310k | 0.21 |
| GPT3-355M | 0.5M | 503k | 0.55 |
| GPT3-760M | 0.5M | 245k | 1.13 |
| GPT3-1.3B | 1M | 169k | 1.64 |
| GPT3-2.7B | 1M | 85k | 3.27 |
As an example, this means that one can train a GPT3-1.3B model on 26B tokens
(compute-optimal according to Chinchilla scaling) in about 43 hours on 8 x A100.
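As a sanity check on these numbers, here is a small sketch (not from the repo) that plugs the GPT3-1.3B configuration into the Megatron-LM formula above, scaled by 3/4 for model FLOPs, together with the throughput from the table; 312 TFLOPs/sec is the A100's peak BF16 throughput.
```python
# GPT3-1.3B: 24 layers, hidden size 2048, sequence length 2048, GPT2 vocab.
l, h, s, V = 24, 2048, 2048, 50257
tokens_per_sec = 169e3  # node-level throughput (8 GPUs) from the table above

# Megatron-LM (Section 5.1) hardware FLOPs per token, then scale by 3/4 to
# drop the activation-recomputation factor and get model FLOPs.
hw_flops_per_token = 96 * l * h**2 * (1 + s / (6 * h) + V / (16 * l * h))
model_flops_per_token = 0.75 * hw_flops_per_token

tflops_per_gpu = model_flops_per_token * tokens_per_sec / 8 / 1e12
mfu = tflops_per_gpu / 312  # A100 peak BF16 TFLOPs/sec
print(f"{tflops_per_gpu:.0f} TFLOPs/sec per GPU, {mfu:.1%} MFU")  # ~188 TFLOPs/sec, ~60% MFU

# Chinchilla-style 20 tokens per parameter for a 1.3B model -> ~26B tokens.
print(f"{26e9 / tokens_per_sec / 3600:.0f} hours for 26B tokens")  # ~43 hours
```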
## Training quality
We include here the loss curve for GPT2 on Openwebtext, trained for 200B tokens.
For GPT2, the runs with FlashAttention yield the same loss curve as the runs
with the baseline implementation from Huggingface for 125M and 355M models. For
larger models the baseline implementation just takes too long.
![GPT2 training curve](../assets/gpt2_training_curve.jpg)
We include here the loss curve for GPT3 on The Pile, trained for 400B tokens.
The 125M, 355M, 760M models have batch size 512k tokens so this translates to
800k training steps, while the 1.3B and 2.7B models have batch size 1M tokens,
which translates to 400k training steps.
![GPT3 training curve](../assets/gpt3_training_curve.jpg)
@@ -28,7 +28,7 @@ defaults:
datamodule:
  # batch_size: 16
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}
trainer:
  # strategy: null
...
@@ -4,13 +4,13 @@ defaults:
  - override /model/gpt2model: gpt2-medium
# Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB
# model:
#   config:
#     mlp_checkpoint_lvl: 1
datamodule:
  # batch_size: 32
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else (32 if ${train.gpu_mem} < 80 else 64))"}
train:
  optimizer:
...
@@ -10,8 +10,8 @@ defaults:
# mlp_checkpoint_lvl: 1
datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}
# With adamw-zero optimizer, on A100 40GB:
# checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)
# checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)
# checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)
...
# @package _global_
defaults:
  - /experiment/owt/gpt2l-hf.yaml
  - override /model/gpt2model: gpt2-xlarge
datamodule:
  batch_size: 1
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-8k.yaml
model:
  config:
...
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml
model:
  config:
...
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-rotary.yaml
model:
  config:
...
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml
model:
  config:
    n_embd: 2560
    n_head: 20 # Headdim 128 is faster than headdim 80
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0
datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
train:
  optimizer:
    lr: 1.6e-4
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-rotary-8k.yaml
model:
  config:
...
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-rotary.yaml
model:
  config:
...
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml
model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0
datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}
train:
  optimizer:
    lr: 1.6e-4
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-hf.yaml
model:
  config:
    n_embd: 2560
    n_head: 128
    n_layer: 32
# OOM on A100 80GB even with batch_size = 1
datamodule:
  batch_size: 1
train:
  optimizer:
    lr: 1.6e-4
# @package _global_
defaults:
  - /experiment/pile/gpt3xl-hf.yaml
model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
datamodule:
  batch_size: 1
train:
  optimizer:
    lr: 1.6e-4