Commit 0bf5e500 authored by Tri Dao

Release training code

parent 9bc63d1e

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /model/gpt2model: gpt2-medium

# Can enable mlp_checkpoint_lvl to fit batch_size 16 on A100 40GB
# model:
#   config:
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}

train:
  optimizer:
    lr: 3.0e-4
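
The ${eval:"..."} interpolation used for batch_size above is not a built-in OmegaConf resolver, so it only works if the training script registers one before the config is composed. Below is a minimal sketch of what such a registration could look like, assuming a resolver named "eval" that simply calls Python's eval on the resolved string; this is an assumption based on the syntax, not a description of this repo's train script.

# Minimal sketch (assumption): register an "eval" resolver so interpolations like
# ${eval:"4 if ${train.gpu_mem} < 24 else ..."} are evaluated as Python expressions
# once the inner ${train.gpu_mem} interpolation has been resolved.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({
    "train": {"gpu_mem": 40},
    "datamodule": {
        "batch_size": '${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}',
    },
})
print(cfg.datamodule.batch_size)  # 16, since 40 is not < 40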

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}

train:
  global_batch_size: 64
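
global_batch_size here is the effective batch size the learning-rate schedule assumes, while datamodule.batch_size is the per-GPU micro-batch; the gap is presumably closed with gradient accumulation (see the accumulate_grad_batches interpolation in the larger config further down). A hypothetical helper, not taken from the repo, showing the arithmetic these numbers imply:

# Hypothetical helper: how many micro-batches each GPU must accumulate so that
# per_gpu_batch_size * num_gpus * accumulation == global_batch_size.
def accumulate_grad_batches(global_batch_size: int, per_gpu_batch_size: int, num_gpus: int) -> int:
    assert global_batch_size % (per_gpu_batch_size * num_gpus) == 0
    return global_batch_size // (per_gpu_batch_size * num_gpus)

# For this 8k-context config on 8 GPUs: 64 / (8 * 8) = 1 with the >=40GB branch
# (batch_size 8), and 64 / (4 * 8) = 2 with the 24-40GB branch (batch_size 4).
print(accumulate_grad_batches(64, 8, 8), accumulate_grad_batches(64, 4, 8))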

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
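
The rotary variants above drop the learned absolute position embedding (max_position_embeddings: 0) and apply rotary position embeddings to half of each attention head's channels (rotary_emb_fraction: 0.5), leaving the remaining channels unrotated. The following is a minimal sketch of partial rotary applied to a query/key tensor, using the interleaved-pair convention; it illustrates the idea only and is not the flash_attn rotary implementation.

import torch

def apply_partial_rotary(x: torch.Tensor, rotary_fraction: float = 0.5, base: float = 10000.0):
    # x: (batch, seqlen, nheads, headdim). Rotate the first rotary_dim channels,
    # pass the rest through unchanged (rotary_emb_fraction < 1.0).
    headdim = x.shape[-1]
    rotary_dim = int(headdim * rotary_fraction)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    seqlen = x.shape[1]
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
    freqs = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)  # (seqlen, rotary_dim // 2)
    cos = freqs.cos()[None, :, None, :]  # broadcast over batch and heads
    sin = freqs.sin()[None, :, None, :]
    x1, x2 = x_rot[..., ::2], x_rot[..., 1::2]  # interleaved even/odd channels
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1).flatten(-2)
    return torch.cat((rotated, x_pass), dim=-1)

q = torch.randn(2, 16, 12, 64)   # batch 2, seqlen 16, 12 heads, headdim 64
out = apply_partial_rotary(q)    # rotary on the first 32 channels of each head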

# @package _global_
defaults:
  - /experiment/pile/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    use_flash_attn: True
    fused_dropout_add_ln: True
    fused_dense_gelu_dense: True
    fused_bias_fc: True
    pad_vocab_size_multiple: 8

datamodule:
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}
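
pad_vocab_size_multiple: 8 rounds the vocabulary size up to a multiple of 8 so the embedding and output-projection matmul dimensions stay tensor-core friendly. A quick illustration follows; GPT-2's 50257-token vocabulary is used as an example value and is not something set in this config.

import math

def pad_vocab_size(vocab_size: int, multiple: int = 8) -> int:
    # Round up to the nearest multiple of 8 so GEMM dimensions divide evenly.
    return math.ceil(vocab_size / multiple) * multiple

print(pad_vocab_size(50257))  # 50264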

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  global_batch_size: 128

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /optimizer: adamw-zero

model:
  config:
    n_embd: 2048
    n_head: 16
    n_layer: 24

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"}

train:
  global_batch_size: 512
  optimizer:
    lr: 2.0e-4
  scheduler:
    t_initial: 300000

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}

callbacks:
  model_checkpoint:
    every_n_train_steps: 1000
  model_checkpoint_progress:
    every_n_train_steps: 12500
    fault_tolerant: False  # Saving takes too long
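
The override /optimizer: adamw-zero default together with the DDPStrategyZero1 strategy points at ZeRO stage-1 style sharding of the AdamW optimizer state across data-parallel ranks. src.utils.ddp_zero1.DDPStrategyZero1 is part of this repo and is not shown here; below is a minimal sketch of the general mechanism using PyTorch's built-in ZeroRedundancyOptimizer, offered as an illustration of the idea rather than this repo's implementation.

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def build_sharded_adamw(model: torch.nn.Module, lr: float = 2.0e-4):
    # Requires torch.distributed to already be initialized (e.g. by the DDP strategy).
    # Each rank keeps only its shard of the AdamW state (ZeRO stage 1), cutting
    # optimizer memory roughly by the data-parallel world size.
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,
        lr=lr,
    )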

# https://www.comet.ml
comet:
  _target_: pytorch_lightning.loggers.comet.CometLogger
  api_key: ${oc.env:COMET_API_TOKEN}  # api key is loaded from environment variable
  project_name: "template-tests"
  experiment_name: ${name}

# CSV logger built into Lightning
csv:
  _target_: pytorch_lightning.loggers.csv_logs.CSVLogger
  save_dir: "."
  name: "csv/"
  version: ${name}
  prefix: ""

# train with many loggers at once
defaults:
  # - comet.yaml
  - csv.yaml
  # - mlflow.yaml
  # - neptune.yaml
  # - tensorboard.yaml
  - wandb.yaml

# https://mlflow.org
mlflow:
  _target_: pytorch_lightning.loggers.mlflow.MLFlowLogger
  experiment_name: ${name}
  tracking_uri: null
  tags: null
  save_dir: ./mlruns
  prefix: ""
  artifact_location: null

# https://neptune.ai
neptune:
  _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
  api_key: ${oc.env:NEPTUNE_API_TOKEN}  # api key is loaded from environment variable
  project_name: your_name/template-tests
  close_after_fit: True
  offline_mode: False
  experiment_name: ${name}
  experiment_id: null
  prefix: ""

# https://www.tensorflow.org/tensorboard/
tensorboard:
  _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
  save_dir: "tensorboard/"
  name: "default"
  version: ${name}
  log_graph: False
  default_hp_metric: True
  prefix: ""
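
Each logger config above is a _target_ plus constructor arguments, so whichever loggers the defaults list enables can be built generically before being handed to the Lightning Trainer. A minimal sketch using hydra.utils.instantiate follows; the cfg.logger grouping is an assumption about how these files are composed, not something stated in the configs.

from hydra.utils import instantiate

def build_loggers(cfg):
    # cfg.logger is assumed to contain one sub-config per enabled logger
    # (e.g. csv and wandb from the defaults list above); each sub-config's
    # _target_ names the Lightning logger class to construct.
    return [instantiate(logger_cfg) for logger_cfg in cfg.logger.values()]

# loggers = build_loggers(cfg)
# trainer = pytorch_lightning.Trainer(logger=loggers, ...)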