gaoqiong / flash-attention

Commit 0bf5e500: Release training code
Authored Nov 28, 2022 by Tri Dao
Parent commit: 9bc63d1e
This commit changes 139 files in total; this page shows 20 changed files with 366 additions and 0 deletions (+366 −0).
Changed files on this page (additions −deletions):

training/configs/experiment/owt/gpt2m-flash.yaml  (+17 −0)
training/configs/experiment/owt/gpt2m-hf.yaml  (+11 −0)
training/configs/experiment/owt/gpt2m.yaml  (+11 −0)
training/configs/experiment/owt/gpt2s-flash.yaml  (+18 −0)
training/configs/experiment/owt/gpt2s-hf.yaml  (+23 −0)
training/configs/experiment/owt/gpt2s.yaml  (+8 −0)
training/configs/experiment/owt/gpt2xl-flash.yaml  (+21 −0)
training/configs/experiment/owt/gpt2xl.yaml  (+14 −0)
training/configs/experiment/pile/base.yaml  (+83 −0)
training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml  (+18 −0)
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml  (+18 −0)
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml  (+18 −0)
training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml  (+18 −0)
training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml  (+18 −0)
training/configs/experiment/pile/gpt3l-flash-8k.yaml  (+10 −0)
training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml  (+10 −0)
training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml  (+8 −0)
training/configs/experiment/pile/gpt3l-flash-rotary.yaml  (+8 −0)
training/configs/experiment/pile/gpt3l-flash.yaml  (+24 −0)
training/configs/experiment/pile/gpt3m-flash-8k.yaml  (+10 −0)
training/configs/experiment/owt/gpt2m-flash.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/gpt2s-flash.yaml
  - override /model/gpt2model: gpt2-medium

# Can enable mlp_checkpoint_lvl to fit batch_size 32 to A100 40GB
model:
  config:
    mlp_checkpoint_lvl: 1

datamodule:
  # batch_size: 32
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}

train:
  optimizer:
    lr: 1.5e-4
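The batch_size interpolation above picks a per-GPU batch size from the detected GPU memory (train.gpu_mem, set in the experiment base configs; see pile/base.yaml later in this diff) through a custom "eval" resolver. As a minimal sketch, assuming the training scripts register the resolver roughly like this (the registration itself is not part of this diff):

# Hypothetical sketch: an "eval" resolver lets an interpolation run a small
# Python expression after ${train.gpu_mem} has been substituted in.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({
    "train": {"gpu_mem": 40},  # pretend we detected a 40 GB GPU
    "batch_size": '${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}',
})
print(cfg.batch_size)  # -> 32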
training/configs/experiment/owt/gpt2m-hf.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/gpt2s-hf.yaml
  - override /model/gpt2model: gpt2-medium

datamodule:
  batch_size: 4

train:
  optimizer:
    lr: 1.5e-4
training/configs/experiment/owt/gpt2m.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/gpt2s.yaml
  - override /model/gpt2model: gpt2-medium

datamodule:
  batch_size: 8  # Per GPU

train:
  optimizer:
    lr: 1.5e-4
training/configs/experiment/owt/gpt2s-flash.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    use_flash_attn: True
    fused_bias_fc: True
    fused_dense_gelu_dense: True
    fused_dropout_add_ln: True
    pad_vocab_size_multiple: 8

datamodule:
  # batch_size: 64
  batch_size: ${eval:"16 if ${train.gpu_mem} < 24 else (32 if ${train.gpu_mem} < 40 else 64)"}
training/configs/experiment/owt/gpt2s-hf.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2-hf
  - override /model/gpt2model: gpt2-small
  - override /callbacks: [default, norm-monitor, flop-count]

datamodule:
  batch_size: 8

train:
  # Use the standard torch.nn.CrossEntropyLoss
  loss_fn: null

callbacks:
  flop_count:
    input_size:
      - ${datamodule.max_length}
    input_dtype:
      # It's surprisingly hard to get hydra to return torch.long since it's not a callable
      _target_: torch.__getattribute__
      _args_:
        - long
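The input_dtype node above works around the fact that torch.long is a dtype object, not a callable Hydra can target directly. A sketch of what the node resolves to when Hydra instantiates it:

# Sketch of the torch.long trick used above: Hydra resolves the target
# "torch.__getattribute__" to the module's attribute lookup and calls it
# with "long", returning the torch.long dtype.
import torch
from hydra.utils import instantiate

node = {"_target_": "torch.__getattribute__", "_args_": ["long"]}
assert instantiate(node) is torch.long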
training/configs/experiment/owt/gpt2s.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}
training/configs/experiment/owt/gpt2xl-flash.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/gpt2l-flash.yaml
  - override /model/gpt2model: gpt2-xlarge

# Can enable mlp_checkpoint_lvl to fit to A100 40GB
# model:
#   config:
#     # mlp_checkpoint_lvl: ${eval:"[1] * 18 + [2] * 18"}
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}

# With adamw-zero optimizer:
# checkpoint_lvl=1, batch size = 4: mem 37GB, 4650ms / batch of 512 (285ms * 15 + 375ms * 1)
# checkpoint_lvl=1, batch size = 8: mem 46GB, 4330ms / batch of 512 (530ms * 7 + 620ms * 1)
# checkpoint_lvl=2, batch size = 8: mem 41GB, 4570ms / batch of 512 (560ms * 7 + 650ms * 1)
# With adamw-apex-distributed optimizer:
# checkpoint_lvl=1, batch size = 8: mem 41.5GB, 4500ms / batch of 512 (550ms * 7 + 650ms * 1)
# checkpoint_lvl=1 for 24 layers and checkpoint_lvl=2 for 24 layers,
#   batch size = 8: mem 39GB, 4640ms / batch of 512 (565ms * 7 + 675ms * 1)
training/configs/experiment/owt/gpt2xl.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/owt/gpt2m.yaml
  - override /model/gpt2model: gpt2-xlarge
  - override /optimizer: adamw-zero

datamodule:
  batch_size: 2  # Per GPU

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
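The gpt2-xlarge config switches the optimizer to adamw-zero and the Lightning strategy to src.utils.ddp_zero1.DDPStrategyZero1; neither implementation appears in this diff. A plausible minimal sketch of the underlying idea (ZeRO stage 1, i.e. sharding optimizer state across data-parallel ranks) using PyTorch's built-in ZeroRedundancyOptimizer, under the assumption that this is roughly what the custom strategy wraps:

# Hypothetical sketch, assuming DDPStrategyZero1 pairs DDP with a ZeRO-1 style
# sharded optimizer; the real implementation lives in src/utils/ddp_zero1.py
# and may differ.
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def build_sharded_adamw(model: torch.nn.Module, lr: float = 1.5e-4, weight_decay: float = 0.1):
    # Each rank keeps only its shard of the AdamW state (exp_avg, exp_avg_sq),
    # cutting optimizer memory roughly by the data-parallel world size.
    # Requires torch.distributed to be initialized before construction.
    return ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,
        lr=lr,
        weight_decay=weight_decay,
    )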
training/configs/experiment/pile/base.yaml (new file, mode 100644)

# @package _global_
defaults:
  - override /trainer: default  # choose trainer from 'configs/trainer/'
  - override /model: null
  - override /datamodule: thepile
  - override /optimizer: adamw-apex  # slight speedup (1-2%) over Pytorch AdamW
  - override /scheduler: cosine-warmup-timm
  - override /callbacks: [default, norm-monitor]
  - override /metrics: [perplexity, num-tokens]
  - override /logger: wandb

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

task:
  _target_: src.tasks.seq.SequenceLMModel

seed: 1111

trainer:
  accelerator: gpu
  devices: 8
  num_nodes: 1
  accumulate_grad_batches: ${div_up:${train.global_batch_size}, ${eval:${trainer.devices} * ${datamodule.batch_size} * ${trainer.num_nodes}}}
  max_steps: 800000
  val_check_interval: ${eval:2000 * ${.accumulate_grad_batches}}
  check_val_every_n_epoch: null  # We don't care about epoch boundary
  precision: bf16
  gradient_clip_val: 1.0
  strategy: null

datamodule:
  batch_size: 16  # Per GPU
  batch_size_eval: ${.batch_size}
  # Fused dense only support batch size at most 64k
  max_length: 2048
  fault_tolerant: True
  ddp: ${eval:"${trainer.devices} > 1"}

train:
  gpu_mem: ${eval:"round(float(__import__('subprocess').check_output('nvidia-smi -i 0 --query-gpu=memory.total --format=csv,noheader,nounits', shell=True).strip().decode()) / 1000)"}
  global_batch_size: 256
  optimizer:
    lr: 6e-4
    weight_decay: 0.1
  optimizer_param_grouping:
    bias_weight_decay: False
    normalization_weight_decay: False
  scheduler:
    t_in_epochs: False
    t_initial: 600000
    warmup_lr_init: 1e-6
    warmup_t: ${eval:0.01 * ${trainer.max_steps}}
    lr_min: ${eval:0.1 * ${train.optimizer.lr}}
  loss_fn:
    # This is faster and uses less memory than torch.nn.CrossEntropyLoss.
    # It's also more numerically stable if we're using DeepSpeed 16 bits.
    _target_: src.losses.cross_entropy_apex.CrossEntropyLossApex
    inplace_backward: True  # to save memory

eval:
  log_on_step: True  # 1 training epoch takes too long, we want to see metrics per train step

callbacks:
  model_checkpoint:
    monitor: val/loss
    mode: min
    save_top_k: 3
    save_last: True
    every_n_train_steps: 1000
    dirpath: ${work_dir}/checkpoints/${oc.select:name,''}
    filename: step_{step}
    auto_insert_metric_name: False
  model_checkpoint_progress:
    _target_: src.callbacks.model_checkpoint.ModelCheckpointMine
    # fault_tolerant: True  # The .pl_auto_save.ckpt doesn't get saved by all workers
    every_n_train_steps: 50000
    save_last: False
    save_top_k: -1  # Save all the checkpoints
    dirpath: ${..model_checkpoint.dirpath}
    filename: progress_step_{step}
    auto_insert_metric_name: False
  early_stopping: null
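Two custom resolvers do the heavy lifting in this base config: "eval" (used for gpu_mem detection, ddp, val_check_interval, warmup_t and lr_min) and "div_up" (used for accumulate_grad_batches). Their registration belongs to the training scripts rather than this diff; a sketch of what they presumably look like, plus the arithmetic for the defaults above:

# Hypothetical sketch of the custom OmegaConf resolvers this config relies on;
# the actual registration lives in the training scripts, not in this diff.
from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)
OmegaConf.register_new_resolver("div_up", lambda x, y: (x + y - 1) // y)  # ceiling division

# With the defaults above, accumulate_grad_batches is
#   div_up(global_batch_size, devices * batch_size * num_nodes)
#   = div_up(256, 8 * 16 * 1) = 2,
# so val_check_interval = 2000 * 2 = 4000 training batches.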
training/configs/experiment/pile/gpt3-2.7B-flash-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-8k.yaml

model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  optimizer:
    lr: 1.6e-4
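As a quick check of the initializer_range interpolation above: with n_embd = 2560 it evaluates to (2 / (2560 * 5)) ** 0.5 = (1.5625e-4) ** 0.5 = 0.0125, so the initialization scale shrinks as the model width grows.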
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary-8k.yaml

model:
  config:
    n_embd: 2560
    n_head: 20
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"}

train:
  optimizer:
    lr: 1.6e-4
training/configs/experiment/pile/gpt3-2.7B-flash-hdim128-rotary.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml

model:
  config:
    n_embd: 2560
    n_head: 20
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}

train:
  optimizer:
    lr: 1.6e-4
training/configs/experiment/pile/gpt3-2.7B-flash-rotary-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary-8k.yaml

model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"}

train:
  optimizer:
    lr: 1.6e-4
training/configs/experiment/pile/gpt3-2.7B-flash-rotary.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt2xl-flash-rotary.yaml

model:
  config:
    n_embd: 2560
    n_head: 32
    n_layer: 32
    initializer_range: ${eval:"(2 / (${.n_embd} * 5)) ** 0.5"}
    mlp_checkpoint_lvl: 0

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else (16 if ${train.gpu_mem} < 80 else 32))"}

train:
  optimizer:
    lr: 1.6e-4
training/configs/experiment/pile/gpt3l-flash-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3l-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  global_batch_size: 64
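Note that the 8k-context variants shrink global_batch_size from 256 to 64 while quadrupling max_length, which keeps the token count per optimizer step unchanged: 64 * 8192 = 524,288 tokens, the same as 256 * 2048 in the 2k-context base config.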
training/configs/experiment/pile/gpt3l-flash-rotary-30B.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3l-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}
training/configs/experiment/pile/gpt3l-flash-rotary-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3l-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3l-flash-rotary.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3l-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
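The rotary variants drop the learned absolute position embedding (max_position_embeddings: 0) and apply rotary embeddings to a fraction of each attention head's dimensions (rotary_emb_fraction: 0.5). A sketch of what that fraction works out to for the gpt3l geometry defined in gpt3l-flash.yaml just below; the names here are illustrative, not the repo's API:

# Illustrative only: how rotary_emb_fraction: 0.5 maps to a rotary dimension
# for the gpt3l config (n_embd=1536, n_head=16). The model code that consumes
# these fields is not part of this diff.
n_embd, n_head = 1536, 16
head_dim = n_embd // n_head         # 96
rotary_dim = int(head_dim * 0.5)    # 48: only these dimensions get rotated
print(head_dim, rotary_dim)         # 96 48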
training/configs/experiment/pile/gpt3l-flash.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /optimizer: adamw-zero

model:
  config:
    n_embd: 1536
    n_head: 16
    n_layer: 24
    # mlp_checkpoint_lvl: 1  # To fit batch_size 8

datamodule:
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else (8 if ${train.gpu_mem} < 80 else 16))"}

train:
  optimizer:
    lr: 2.5e-4

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
training/configs/experiment/pile/gpt3m-flash-8k.yaml (new file, mode 100644)

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}

train:
  global_batch_size: 64