Commit 0bf5e500 authored by Tri Dao

Release training code

parent 9bc63d1e
@@ -2,6 +2,8 @@ This CUDA extension implements fused dropout + residual + LayerNorm, based on
Apex's [FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm).
We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
It has only been tested on A100s.
# Inspired by https://github.com/anibali/docker-pytorch/blob/master/dockerfiles/1.10.0-cuda11.3-ubuntu20.04/Dockerfile
# ARG COMPAT=0
ARG PERSONAL=0
# FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 as base-0
FROM nvcr.io/nvidia/pytorch:22.11-py3 as base
ENV HOST docker
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
# https://serverfault.com/questions/683605/docker-container-time-timezone-will-not-reflect-changes
ENV TZ America/Los_Angeles
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# git for installing dependencies
# tzdata to set time zone
# wget and unzip to download data
# [2021-09-09] TD: zsh, stow, subversion, fasd are for setting up my personal environment.
# [2021-12-07] TD: openmpi-bin for MPI (multi-node training)
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
cmake \
curl \
ca-certificates \
sudo \
less \
htop \
git \
tzdata \
wget \
tmux \
zip \
unzip \
zsh stow subversion fasd \
&& rm -rf /var/lib/apt/lists/*
# openmpi-bin \
# Allow running mpirun as root
# ENV OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1
# # Create a non-root user and switch to it
# RUN adduser --disabled-password --gecos '' --shell /bin/bash user \
# && echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user
# USER user
# All users can use /home/user as their home directory
ENV HOME=/home/user
RUN mkdir -p /home/user && chmod 777 /home/user
WORKDIR /home/user
# Set up personal environment
# FROM base-${COMPAT} as env-0
FROM base as env-0
FROM env-0 as env-1
# Use ONBUILD so that the dotfiles dir doesn't need to exist unless we're building a personal image
# https://stackoverflow.com/questions/31528384/conditional-copy-add-in-dockerfile
ONBUILD COPY dotfiles ./dotfiles
ONBUILD RUN cd ~/dotfiles && stow bash zsh tmux && sudo chsh -s /usr/bin/zsh $(whoami)
# nvcr pytorch image sets SHELL=/bin/bash
ONBUILD ENV SHELL=/bin/zsh
FROM env-${PERSONAL} as packages
# Disable pip cache: https://stackoverflow.com/questions/45594707/what-is-pips-no-cache-dir-good-for
ENV PIP_NO_CACHE_DIR=1
# apex and pytorch-fast-transformers take a while to compile so we install them first
# TD [2022-04-28] apex is already installed. In case we need a newer commit:
# RUN pip install --upgrade --force-reinstall --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--fast_multihead_attn" --global-option="--fmha" --global-option="--fast_layer_norm" --global-option="--xentropy" git+https://github.com/NVIDIA/apex.git#egg=apex
# TD [2021-10-28] pytorch-fast-transformers doesn't have a wheel compatible with CUDA 11.3 and Pytorch 1.10
# So we install from source, and change compiler flag -arch=compute_60 -> -arch=compute_70 for V100
# RUN pip install pytorch-fast-transformers==0.4.0
# RUN pip install git+git://github.com/idiap/fast-transformers.git@v0.4.0 # doesn't work on V100
RUN git clone https://github.com/idiap/fast-transformers \
&& sed -i 's/\["-arch=compute_60"\]/\["-arch=compute_70"\]/' fast-transformers/setup.py \
&& pip install fast-transformers/ \
&& rm -rf fast-transformers
# xgboost conflicts with deepspeed
RUN pip uninstall -y xgboost && DS_BUILD_UTILS=1 DS_BUILD_FUSED_LAMB=1 pip install deepspeed==0.7.5
# General packages for which we don't care about the version
# zstandard to extract the_pile dataset
# psutil to get the number of cpu physical cores
# twine to upload package to PyPI
# The preinstalled ninja is broken for some reason (it returns error code 245), so we reinstall it
RUN pip uninstall -y ninja && pip install ninja
RUN pip install pytest matplotlib jupyter ipython ipdb gpustat scikit-learn spacy munch einops opt_einsum fvcore gsutil cmake pykeops zstandard psutil h5py twine \
&& python -m spacy download en_core_web_sm
# hydra
RUN pip install hydra-core==1.2.0 hydra-colorlog==1.2.0 hydra-optuna-sweeper==1.2.0 pyrootutils rich
# Core packages
RUN pip install transformers==4.24.0 datasets==2.7.1 pytorch-lightning==1.7.7 triton==2.0.0.dev20221120 wandb==0.13.5 timm==0.6.12 torchmetrics==0.10.3
# For MLPerf
RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0
# Install FlashAttention
RUN pip install flash-attn==0.2.2
# Install CUDA extensions for cross-entropy, fused dense, layer norm
RUN git clone https://github.com/HazyResearch/flash-attention \
&& cd flash-attention && git checkout v0.2.2 \
&& cd csrc/fused_softmax && pip install . && cd ../../ \
&& cd csrc/rotary && pip install . && cd ../../ \
&& cd csrc/xentropy && pip install . && cd ../../ \
&& cd csrc/layer_norm && pip install . && cd ../../ \
&& cd csrc/fused_dense_lib && pip install . && cd ../../ \
&& cd .. && rm -rf flash-attention
Examples of how FlashAttention can be integrated into a model (e.g., GPT, ViT)
and trained end-to-end.
We also added optimized implementations of other layers (e.g., MLP, LayerNorm,
cross-entropy loss, rotary embedding).
Goals:
- Performance: we optimize for model speed and memory, especially on 1-node
(e.g., with 8 A100s).
- Flexibility: we provide optimized building blocks (MLP, attention, LayerNorm),
and the model code illustrates how these components can be put together.
The training code also aims to be model- & task-agnostic.
Non-goals (and other resources):
- Support as many models as possible: Huggingface's
[transformers](https://github.com/huggingface/transformers) and
[timm](https://github.com/rwightman/pytorch-image-models/) are great for this.
- Large-scale distributed training: our codebase has been used for multi-GPU and multi-node
training for models up to 2.7B parameters. However, if you're looking for large-scale distributed
training techniques (e.g., pipeline parallelism, tensor parallelism),
check out [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/) and
[DeepSpeed](https://github.com/microsoft/deepspeed).
- Inference: we currently focus on training (this might change in the future).
If you want fast inference, take a look at
[FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
- Production: this codebase was written during several research projects to validate ideas
on speeding up ML models.
## Model Components
The GPT model is implemented
[here](https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/models/gpt.py).
We provide the following optimized components (a combined usage sketch follows the list):
- FlashAttention: fast and memory-efficient exact attention. This makes
attention much faster and saves a lot of activation memory. As a result we don't need
to use any activation checkpointing.
```sh
pip install flash-attn
```
- Fused matmul + bias (forward and backward), and fused matmul + bias + gelu
(forward and backward), adapted from Apex's
[FusedDense](https://github.com/NVIDIA/apex/tree/master/apex/fused_dense). We
make it work for bfloat16. For best performance, you should use CUDA >= 11.8. cuBLAS versions before
this don't have the best matmul + bias + gelu performance for bfloat16.
```sh
cd ../csrc/fused_dense_lib && pip install .
```
- Optimized cross-entropy loss, adapted from Apex's
[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). We make it work for bfloat16 and support in-place backward to save memory.
```sh
cd ../csrc/xentropy && pip install .
```
- Fused rotary embedding:
```sh
cd ../csrc/rotary && pip install .
```
- Fused dropout + residual + LayerNorm, adapted from Apex's
[FastLayerNorm](https://github.com/NVIDIA/apex/tree/master/apex/contrib/layer_norm). We add dropout and residual, and make it work for both pre-norm and post-norm architectures.
This only supports a limited set of dimensions, see `csrc/layer_norm/ln_fwd_cuda_kernel.cu`.
```sh
cd ../csrc/layer_norm && pip install .
```
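Putting these together: the sketch below shows roughly how a GPT model with the optimized components enabled might be instantiated, together with the optimized cross-entropy loss. This is a minimal sketch, not code from the repo: the config flags (`use_flash_attn`, `fused_bias_fc`, `fused_dense_gelu_dense`, `fused_dropout_add_ln`) and the shape of the forward-pass output are assumptions inferred from `flash_attn/models/gpt.py`, and `CrossEntropyLossApex(inplace_backward=True)` mirrors the training config included later in this commit; exact names may differ between releases.
```python
# Minimal sketch (not verbatim from the repo). Config flag names below are assumptions
# inferred from flash_attn/models/gpt.py and may differ across versions.
import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel
from src.losses.cross_entropy_apex import CrossEntropyLossApex  # class path from the training config

config = GPT2Config(n_embd=768, n_head=12, n_layer=12, vocab_size=50257)
config.use_flash_attn = True          # FlashAttention kernel for attention (assumed flag name)
config.fused_bias_fc = True           # fused matmul + bias (assumed flag name)
config.fused_dense_gelu_dense = True  # fused matmul + bias + gelu in the MLP (assumed flag name)
config.fused_dropout_add_ln = True    # fused dropout + residual + LayerNorm (assumed flag name)

model = GPTLMHeadModel(config).to("cuda", dtype=torch.float16)
loss_fn = CrossEntropyLossApex(inplace_backward=True)  # in-place backward to save memory

input_ids = torch.randint(0, config.vocab_size, (2, 512), device="cuda")
logits = model(input_ids).logits
# Standard causal LM shift: predict token t+1 from tokens up to t.
loss = loss_fn(logits[:, :-1].reshape(-1, logits.size(-1)), input_ids[:, 1:].reshape(-1))
loss.backward()
```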
## Training
Feel free to use the model in your training setup. As examples, we also provide
training scripts to train GPT2 on Openwebtext and GPT3 on The Pile.
We use [Hydra](https://hydra.cc/) for configuration,
[Pytorch-Lightning](https://github.com/Lightning-AI/lightning) for training, and
[Wandb](https://wandb.ai/) for logging.
We use the template from `https://github.com/ashleve/lightning-hydra-template`.
Please read the instructions there to understand the repo structure.
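For orientation, the entry point follows the usual Hydra pattern from that template: `configs/config.yaml` (included later in this commit) is composed with the chosen experiment and any command-line overrides, and nodes carrying a `_target_` key are instantiated. The snippet below is only a rough sketch of that pattern, not the repo's actual `run.py`:
```python
# Rough sketch of the lightning-hydra-template pattern used here; this is NOT the
# repo's run.py, just an illustration of how the configs in this commit are consumed.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig) -> None:
    # Custom resolvers such as ${eval:...} and ${div_up:...} used in these configs
    # must be registered (the real entry point does this) before values are resolved.
    print(OmegaConf.to_yaml(config, resolve=False))  # roughly what print_config: True does
    # Nodes like config.task, config.datamodule, config.callbacks, and config.logger
    # carry a _target_ key and are instantiated with hydra.utils.instantiate(...).


if __name__ == "__main__":
    main()
```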
### Dataset preparation
Running the training command will automatically download the datasets
(Openwebtext, Pile), tokenize them with the GPT2 tokenizer, concatenate all the
tokens, and save this cache to disk. Alternatively, you can prepare the
datasets as a separate step (see the commands and the Python sketch below).
The cached datasets are saved to `${DATA_DIR}/openwebtext` and
`${DATA_DIR}/the_pile`. If `${DATA_DIR}` is not set, they will be saved to
`./data/{openwebtext,the_pile}`.
- Openwebtext:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "openwebtext"
```
This takes around 1h on a 64-core CPU. The processed dataset is about 17GB.
- The Pile:
```sh
export PYTHONPATH=$PWD:$PYTHONPATH
pytest -q -s tests/datamodules/test_language_modeling_hf.py -k "pile"
```
This takes around 20h on a 96-core CPU. The processed dataset is about 699GB.
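The same caches can also be built by driving the datamodule directly from Python. The sketch below is hedged: it assumes `LMDataModule` accepts the same keyword arguments as the Hydra datamodule configs included later in this commit, and that it follows the standard `LightningDataModule` interface; actual argument defaults may differ.
```python
# Hedged sketch: build the Openwebtext cache by calling the datamodule directly.
# Keyword arguments mirror the Hydra datamodule config; actual defaults may differ.
from src.datamodules.language_modeling_hf import LMDataModule

datamodule = LMDataModule(
    dataset_name="openwebtext",
    dataset_config_name=None,
    tokenizer_name="gpt2",
    cache_dir="./data/openwebtext/cache",  # or ${DATA_DIR}/openwebtext/cache if DATA_DIR is set
    max_length=1024,
    add_eos=True,
    val_ratio=0.0005,
    val_split_seed=2357,
    batch_size=8,
    num_workers=32,  # used for preprocessing only
    shuffle=True,
    pin_memory=True,
)
datamodule.prepare_data()      # download, tokenize, concatenate, and cache to disk
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
```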
### GPT2 training on Openwebtext
To train GPT2 on Openwebtext with 8 GPUs:
```sh
python run.py experiment=owt/gpt2s-flash trainer.devices=8
python run.py experiment=owt/gpt2m-flash trainer.devices=8
python run.py experiment=owt/gpt2l-flash trainer.devices=8
python run.py experiment=owt/gpt2xl-flash trainer.devices=8
```
The default parameters are set for 8 x A100 80GB.
To train with bf16 instead of fp16, add `trainer.precision=bf16`.
To adjust the per-device batch size to fit GPU memory (the global batch size stays the
same and the number of gradient accumulation steps is calculated automatically), set `datamodule.batch_size=blah`.
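For example, with the default `train.global_batch_size=512` on 8 GPUs, `datamodule.batch_size=16` gives `accumulate_grad_batches = 512 / (8 * 16) = 4`; halving the device batch size to 8 doubles the accumulation to 8 steps.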
### GPT3 training on The Pile
To train GPT3 on The Pile with 8 GPUs:
```sh
python run.py experiment=pile/gpt3s-flash trainer.devices=8
python run.py experiment=pile/gpt3m-flash trainer.devices=8
python run.py experiment=pile/gpt3l-flash trainer.devices=8
python run.py experiment=pile/gpt3xl-flash trainer.devices=8
```
The default parameters are set for 8 x A100 80GB.
## Requirements
Python 3.8+, Pytorch 1.12+, torchvision, einops, timm, hydra-core,
hydra-colorlog, python-dotenv, rich, pytorch-lightning, triton, flash-attn.
We recommend CUDA 11.8 (e.g., using Nvidia's Pytorch Docker image from https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch).
We provide a Dockerfile that lists all the required packages.
causality-monitor:
_target_: src.callbacks.causality_monitor.CausalityMonitor
# rich_progress_bar:
# _target_: pytorch_lightning.callbacks.RichProgressBar
rich_model_summary:
_target_: pytorch_lightning.callbacks.RichModelSummary
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
monitor: "val/acc" # name of the logged metric which determines when model is improving
mode: "max" # can be "max" or "min"
save_top_k: 1 # save k best models (determined by above metric)
save_last: True # additionally always save model from last epoch
verbose: False
dirpath: ${oc.env:CHECKPOINT_DIR,checkpoints}/${oc.select:name,''}
filename: "epoch_{epoch:03d}"
auto_insert_metric_name: False
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
monitor: "val/acc" # name of the logged metric which determines when model is improving
mode: "max" # can be "max" or "min"
patience: 100 # how many epochs of not improving until training stops
min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
learning_rate_monitor:
_target_: pytorch_lightning.callbacks.LearningRateMonitor
logging_interval: step
speed_monitor:
_target_: src.callbacks.speed_monitor.SpeedMonitor
intra_step_time: True
inter_step_time: True
epoch_time: True
loss_scale_monitor:
_target_: src.callbacks.loss_scale_monitor.LossScaleMonitor
params_log:
_target_: src.callbacks.params_log.ParamsLog
total_params_log: True
trainable_params_log: True
non_trainable_params_log: True
gpu_affinity:
_target_: src.callbacks.gpu_affinity.GpuAffinity
ema:
_target_: src.callbacks.ema.EMACallback
decay: ???
use_num_updates: False
flop_count:
_target_: src.callbacks.flop_count.FlopCount
profilers: ['fvcore']
input_size: [3, 224, 224]
device: null
defaults:
- default.yaml
gpu_stats_monitor:
_target_: pytorch_lightning.callbacks.GPUStatsMonitor
# [2021-08-13] TD: I just want the intra_step_size but it'll error if I
# don't have memory_utilization and gpu_utilization.
# Maybe I should write a callback with just the intra_step_size.
memory_utilization: True
gpu_utilization: True
intra_step_time: True
model_summary:
_target_: pytorch_lightning.callbacks.RichModelSummary
norm_monitor:
_target_: src.callbacks.norm_monitor.NormMonitor
params_log:
_target_: src.callbacks.params_log.ParamsLog
total_params_log: True
trainable_params_log: True
non_trainable_params_log: True
defaults:
- default.yaml
watch_model:
_target_: src.callbacks.wandb_callbacks.WatchModel
log: "all"
log_freq: 100
upload_code_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCodeAsArtifact
code_dir: ${work_dir}/src
upload_ckpts_as_artifact:
_target_: src.callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
ckpt_dir: "checkpoints/"
upload_best_only: True
log_f1_precision_recall_heatmap:
_target_: src.callbacks.wandb_callbacks.LogF1PrecRecHeatmap
log_confusion_matrix:
_target_: src.callbacks.wandb_callbacks.LogConfusionMatrix
log_image_predictions:
_target_: src.callbacks.wandb_callbacks.LogImagePredictions
num_samples: 8
# @package _global_
# specify here default training configuration
defaults:
- _self_
- trainer: default
- optimizer: adamw
- scheduler: null
- task: sequence-model
- model: null
- datamodule: null
- callbacks: default # set this to null if you don't want to use callbacks
- metrics: null
- logger: null # set logger here or use command line (e.g. `python run.py logger=wandb`)
- mode: default
- experiment: null
- hparams_search: null
# enable color logging
- override hydra/hydra_logging: colorlog
- override hydra/job_logging: colorlog
# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
# https://hydra.cc/docs/next/tutorials/basic/running_your_app/working_directory
work_dir: ${hydra:runtime.cwd}
# path to folder with data
data_dir: ${work_dir}/data/
# pretty print config at the start of the run using Rich library
print_config: True
# disable python warnings if they annoy you
ignore_warnings: True
# check performance on test set, using the best model achieved during training
# lightning chooses best model based on metric specified in checkpoint callback
test_after_training: True
resume: False
# seed for random number generators in pytorch, numpy and python.random
seed: null
# name of the run, accessed by loggers
name: null
_target_: src.datamodules.language_modeling_hf.LMDataModule
dataset_name: openwebtext
dataset_config_name: null
tokenizer_name: gpt2
cache_dir: ${oc.env:DATA_DIR,${data_dir}}/openwebtext/cache
max_length: 1024
val_ratio: 0.0005
val_split_seed: 2357
add_eos: True
batch_size: 8 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 32 # For preprocessing only
shuffle: True
pin_memory: True
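# __train_len below is presumably the number of training sequences: the dataset's total
# GPT2 token count divided (rounding up) by max_length.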
__train_len: ${div_up:9035582198, ${.max_length}}
_target_: src.datamodules.language_modeling_hf.LMDataModule
dataset_name: the_pile
dataset_config_name: null
tokenizer_name: gpt2
cache_dir: ${oc.env:DATA_DIR,${data_dir}}/the_pile/cache
max_length: 2048
add_eos: True
batch_size: 4 # per GPU
batch_size_eval: ${eval:${.batch_size} * 2}
num_workers: 64 # For preprocessing only
use_shmem: False
shuffle: True
pin_memory: True
__train_len: ${div_up:374337375694, ${.max_length}}
# @package _global_
defaults:
- override /trainer: default # choose trainer from 'configs/trainer/'
- override /model: null
- override /datamodule: openwebtext
# FusedAdam from apex speeds up the optimizer step a bit: for GPT2-small, time
# per global step (i.e. batch size 512) on 8 A100s goes from 376ms to 368ms.
# For GPT2-medium, time per global step goes from 997ms to 972ms.
- override /optimizer: adamw-apex
- override /scheduler: linear-warmup
- override /callbacks: [default, norm-monitor]
- override /metrics: [perplexity, num-tokens]
- override /logger: wandb
# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters
task:
_target_: src.tasks.seq.SequenceLMModel
seed: 1111
trainer:
accelerator: gpu
devices: 8
num_nodes: 1
accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
max_steps: 400000
val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}
check_val_every_n_epoch: null # We don't care about epoch boundary
precision: 16
gradient_clip_val: 1.0
strategy: null
datamodule:
batch_size: 16 # Per GPU
batch_size_eval: ${.batch_size} # Fused dense only supports batch sizes of at most 64k
max_length: 1024
fault_tolerant: True
ddp: ${eval:"${trainer.devices} > 1"}
train:
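# gpu_mem: total memory of GPU 0 in GB (rounded), read via nvidia-smi; the gpt2l
# experiment config below uses it to pick a per-GPU batch size that fits.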
gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
global_batch_size: 512
optimizer:
lr: 6e-4
weight_decay: 0.1
optimizer_param_grouping:
bias_weight_decay: False
normalization_weight_decay: False
scheduler:
num_warmup_steps: ${eval:0.01 * ${trainer.max_steps}}
num_training_steps: ${trainer.max_steps}
loss_fn:
# This is faster and uses less memory than torch.nn.CrossEntropyLoss.
# It's also more numerically stable if we're using DeepSpeed with 16-bit precision.
_target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
inplace_backward: True # to save memory
eval:
log_on_step: True # 1 training epoch takes too long, we want to see metrics per train step
callbacks:
model_checkpoint:
monitor: val/loss
mode: min
save_top_k: 3
save_last: True
every_n_train_steps: 1000
dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
filename: step_{step}
auto_insert_metric_name: False
model_checkpoint_progress:
_target_: src.callbacks.model_checkpoint.ModelCheckpointMine
fault_tolerant: True
every_n_train_steps: 50000
save_last: False
save_top_k: -1 # Save all the checkpoints
dirpath: ${..model_checkpoint.dirpath}
filename: progress_step_{step}
auto_insert_metric_name: False
early_stopping: null
# @package _global_
defaults:
- /experiment/owt/gpt2m-flash.yaml
- override /model/gpt2model: gpt2-large
# TD [2022-08-03] Surprisingly it's faster to use the ZeRO optimizer than just AdamW.
# Still, fairscale is even faster and uses less memory.
# I think it's because Pytorch is using ZeRO stage 1 and fairscale is using ZeRO stage 2?
# However, fairscale has issues with saving checkpoints (either OOM or very
# slow, since it goes through the CPU?). Fairscale says Pytorch ZeRO is the
# upstream version of OSS:
# https://github.com/facebookresearch/fairscale/issues/937
# Pytorch ZeRO was also very slow for saving checkpoints due to
# consolidate_state_dict(), but I've fixed it to save a separate checkpoint per GPU.
- override /optimizer: adamw-zero
# FusedAdam doesn't seem to speed things up here, time per global step
# (i.e. batch size 512) on 8 A100s is around 2056ms for both AdamW and FusedAdam.
# This could be because each GPU is only doing the optimizer step for 1 /
# world_size of the parameters.
# Maybe the bottleneck here is the NCCL call to exchange parameters (ZeRO).
# - override /optimizer: adamw-apex-zero
# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
# model:
# config:
# # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
# mlp_checkpoint_lvl: 1
datamodule:
# batch_size: 16
batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
trainer:
# strategy: null
# strategy: ${eval:"None if ${trainer.devices} == 1 else 'ddp_sharded'"}
strategy:
_target_: src.utils.ddp_zero1.DDPStrategyZero1
find_unused_parameters: False
gradient_as_bucket_view: True
# TD [2022-08-03] Deepspeed makes the ppl curve go wild
# strategy: deepspeed_stage_1
# @package _global_
defaults:
- /experiment/owt/gpt2m-hf.yaml
- override /model/gpt2model: gpt2-large
- override /optimizer: adamw-zero
datamodule:
batch_size: 2
trainer:
strategy:
_target_: src.utils.ddp_zero1.DDPStrategyZero1
find_unused_parameters: False
gradient_as_bucket_view: True
# @package _global_
defaults:
- /experiment/owt/gpt2m.yaml
- override /model/gpt2model: gpt2-large
- override /optimizer: adamw-zero
datamodule:
batch_size: 4 # Per GPU
trainer:
strategy:
_target_: src.utils.ddp_zero1.DDPStrategyZero1
find_unused_parameters: False
gradient_as_bucket_view: True