gaoqiong / flash-attention · Commits

Commit 0bf5e500, authored Nov 28, 2022 by Tri Dao
Parent: 9bc63d1e

Release training code

139 files changed in this commit; this page shows 20 changed files with 221 additions and 0 deletions (+221 -0).
training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml   +10 -0
training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml     +8 -0
training/configs/experiment/pile/gpt3m-flash-rotary.yaml        +8 -0
training/configs/experiment/pile/gpt3m-flash.yaml              +16 -0
training/configs/experiment/pile/gpt3s-flash-8k.yaml           +10 -0
training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml   +10 -0
training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml     +8 -0
training/configs/experiment/pile/gpt3s-flash-rotary.yaml        +8 -0
training/configs/experiment/pile/gpt3s-flash.yaml              +17 -0
training/configs/experiment/pile/gpt3xl-flash-8k.yaml          +10 -0
training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml  +10 -0
training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml    +8 -0
training/configs/experiment/pile/gpt3xl-flash-rotary.yaml       +8 -0
training/configs/experiment/pile/gpt3xl-flash.yaml             +35 -0
training/configs/logger/comet.yaml                              +7 -0
training/configs/logger/csv.yaml                                +8 -0
training/configs/logger/many_loggers.yaml                       +9 -0
training/configs/logger/mlflow.yaml                            +10 -0
training/configs/logger/neptune.yaml                           +11 -0
training/configs/logger/tensorboard.yaml                       +10 -0
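The experiment configs below follow Hydra's conventions: each file starts with a "# @package _global_" header and pulls in a parent config through a defaults list. As a hedged sketch (the config root training/configs and the root config name "config" are assumptions, not shown on this page), one of these experiments could be composed programmatically like so:

from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf

# Compose the gpt3s-flash experiment the same way the training entry point
# presumably does; config_dir and config_name here are assumptions.
with initialize_config_dir(config_dir="/abs/path/to/training/configs", version_base=None):
    cfg = compose(config_name="config", overrides=["experiment=pile/gpt3s-flash"])
    # resolve=False: custom resolvers like ${eval:...} may not be registered here
    print(OmegaConf.to_yaml(cfg, resolve=False))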
training/configs/experiment/pile/gpt3m-flash-rotary-30B.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}
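The "-30B" suffix tracks the token budget implied by max_steps. A back-of-envelope check, assuming the base config trains on 2048-token sequences with a global batch of 256 (neither value is visible in this diff):

# Assumed values: global_batch_size=256, max_length=2048 (from the base
# config, which is not part of this page).
tokens = 60_000 * 256 * 2048          # max_steps * global_batch_size * max_length
print(f"{tokens / 1e9:.1f}B tokens")  # ~31.5B, matching the "30B" suffix

The gpt3xl "-60B" variant later in this diff follows the same pattern with its global batch of 512.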
training/configs/experiment/pile/gpt3m-flash-rotary-8k.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
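rotary_emb_fraction: 0.5 applies rotary position embeddings to only half of each attention head's dimensions, while max_position_embeddings: 0 drops the learned absolute embeddings. A minimal pure-PyTorch sketch of the partial-rotary idea (illustration only; the actual implementation lives in the flash-attn package, not in this diff):

import torch

def apply_partial_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                         rotary_frac: float = 0.5) -> torch.Tensor:
    # q: (..., headdim); cos/sin: position tables broadcastable to (..., rot_dim // 2).
    rot_dim = int(q.shape[-1] * rotary_frac)
    q_rot, q_pass = q[..., :rot_dim], q[..., rot_dim:]  # only q_rot gets rotated
    q1, q2 = q_rot.chunk(2, dim=-1)
    q_rot = torch.cat([q1 * cos - q2 * sin, q2 * cos + q1 * sin], dim=-1)
    return torch.cat([q_rot, q_pass], dim=-1)           # pass-through dims untouched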
training/configs/experiment/pile/gpt3m-flash-rotary.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3m-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3m-flash.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /model/gpt2model: gpt2-medium

# Can enable mlp_checkpoint_lvl to fit batch_size 16 to A100 40GB
# model:
#   config:
#     mlp_checkpoint_lvl: 1

datamodule:
  batch_size: ${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}

train:
  optimizer:
    lr: 3.0e-4
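The ${eval:"..."} interpolation that picks batch_size from GPU memory is not built into OmegaConf; it needs a custom resolver. A sketch of the kind of registration the training code presumably performs (the exact registration site is an assumption):

from omegaconf import OmegaConf

# Register an "eval" resolver so ${eval:"<python expr>"} works. Bare eval()
# is the usual wiring for such configs, but it executes arbitrary code --
# only load configs you trust.
OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({
    "train": {"gpu_mem": 80},
    "batch_size": '${eval:"4 if ${train.gpu_mem} < 24 else (8 if ${train.gpu_mem} < 40 else 16)"}',
})
print(cfg.batch_size)  # 16 on an 80GB GPU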
training/configs/experiment/pile/gpt3s-flash-8k.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"2 if ${train.gpu_mem} < 24 else (4 if ${train.gpu_mem} < 40 else 8)"}

train:
  global_batch_size: 64
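train.global_batch_size and the per-GPU datamodule.batch_size are reconciled through gradient accumulation (the gpt3xl config below references ${.accumulate_grad_batches} explicitly). The arithmetic, with the variable names here as assumptions:

# Example for this 8k-context config on 8 GPUs with >=40GB each
# (so batch_size resolves to 8 per the ${eval:...} rule above):
per_gpu_batch, num_gpus, global_batch = 8, 8, 64
accumulate_grad_batches = global_batch // (per_gpu_batch * num_gpus)
print(accumulate_grad_batches)  # 1; with 24-40GB GPUs (batch_size 4) it would be 2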
training/configs/experiment/pile/gpt3s-flash-rotary-30B.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}
training/configs/experiment/pile/gpt3s-flash-rotary-8k.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3s-flash-rotary.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3s-flash.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/base.yaml
  - override /model: gpt2
  - override /model/gpt2model: gpt2-small

model:
  config:
    # n_positions is already set to ${datamodule.max_length}
    use_flash_attn: True
    fused_dropout_add_ln: True
    fused_dense_gelu_dense: True
    fused_bias_fc: True
    pad_vocab_size_multiple: 8

datamodule:
  batch_size: ${eval:"8 if ${train.gpu_mem} < 24 else (16 if ${train.gpu_mem} < 40 else 32)"}
training/configs/experiment/pile/gpt3xl-flash-8k.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml

datamodule:
  max_length: 8192
  batch_size: ${eval:"1 if ${train.gpu_mem} < 40 else (2 if ${train.gpu_mem} < 80 else 4)"}

train:
  global_batch_size: 128
training/configs/experiment/pile/gpt3xl-flash-rotary-60B.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-rotary.yaml

trainer:
  max_steps: 60000

train:
  scheduler:
    t_initial: ${trainer.max_steps}
training/configs/experiment/pile/gpt3xl-flash-rotary-8k.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash-8k.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3xl-flash-rotary.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3xl-flash.yaml

model:
  config:
    max_position_embeddings: 0  # Disable absolute position embedding
    rotary_emb_fraction: 0.5
training/configs/experiment/pile/gpt3xl-flash.yaml (new file)

# @package _global_
defaults:
  - /experiment/pile/gpt3s-flash.yaml
  - override /optimizer: adamw-zero

model:
  config:
    n_embd: 2048
    n_head: 16
    n_layer: 24

datamodule:
  batch_size: ${eval:"1 if ${train.gpu_mem} < 24 else (2 if ${train.gpu_mem} < 40 else (4 if ${train.gpu_mem} < 80 else 8))"}

train:
  global_batch_size: 512
  optimizer:
    lr: 2.0e-4
  scheduler:
    t_initial: 300000

trainer:
  strategy:
    _target_: src.utils.ddp_zero1.DDPStrategyZero1
    find_unused_parameters: False
    gradient_as_bucket_view: True
  max_steps: 400000
  val_check_interval: ${eval:1000 * ${.accumulate_grad_batches}}

callbacks:
  model_checkpoint:
    every_n_train_steps: 1000
  model_checkpoint_progress:
    every_n_train_steps: 12500
    fault_tolerant: False  # Saving takes too long
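With n_embd 2048, n_head 16, and n_layer 24, this is the roughly 1.3B-parameter GPT-3 XL layout, which is why the config swaps in adamw-zero and a DDPStrategyZero1 strategy: ZeRO stage 1 shards optimizer state across ranks. A sketch of the underlying PyTorch primitive (the src.utils.ddp_zero1.DDPStrategyZero1 wrapper itself is not part of this page, so this is an illustration, not the repo's code):

import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

def configure_sharded_adamw(model: torch.nn.Module, lr: float = 2.0e-4):
    # Each rank keeps only its 1/world_size shard of AdamW's moment buffers,
    # cutting optimizer-state memory roughly by the number of GPUs.
    return ZeroRedundancyOptimizer(
        model.parameters(), optimizer_class=torch.optim.AdamW, lr=lr
    )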
training/configs/logger/comet.yaml
0 → 100644
View file @
0bf5e500
# https://www.comet.ml
comet
:
_target_
:
pytorch_lightning.loggers.comet.CometLogger
api_key
:
${oc.env:COMET_API_TOKEN}
# api key is loaded from environment variable
project_name
:
"
template-tests"
experiment_name
:
${name}
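${oc.env:COMET_API_TOKEN} uses OmegaConf's built-in environment-variable resolver, so the secret never has to be committed to the YAML:

import os
from omegaconf import OmegaConf

os.environ["COMET_API_TOKEN"] = "demo-token"  # normally set in your shell, not in code
cfg = OmegaConf.create({"api_key": "${oc.env:COMET_API_TOKEN}"})
print(cfg.api_key)  # demo-token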
training/configs/logger/csv.yaml
0 → 100644
View file @
0bf5e500
# csv logger built in lightning
csv
:
_target_
:
pytorch_lightning.loggers.csv_logs.CSVLogger
save_dir
:
"
."
name
:
"
csv/"
version
:
${name}
prefix
:
"
"
training/configs/logger/many_loggers.yaml
0 → 100644
View file @
0bf5e500
# train with many loggers at once
defaults
:
# - comet.yaml
-
csv.yaml
# - mlflow.yaml
# - neptune.yaml
# - tensorboard.yaml
-
wandb.yaml
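Each logger config keys a _target_ class path, which Hydra can instantiate directly. A self-contained sketch (with ${name} replaced by a literal run name, since the root config that defines it is not on this page):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "csv": {
        "_target_": "pytorch_lightning.loggers.csv_logs.CSVLogger",
        "save_dir": ".",
        "name": "csv/",
        "version": "my-run",  # stands in for ${name}
    },
})
loggers = [instantiate(sub_cfg) for sub_cfg in cfg.values()]  # -> [CSVLogger(...)]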
training/configs/logger/mlflow.yaml
0 → 100644
View file @
0bf5e500
# https://mlflow.org
mlflow
:
_target_
:
pytorch_lightning.loggers.mlflow.MLFlowLogger
experiment_name
:
${name}
tracking_uri
:
null
tags
:
null
save_dir
:
./mlruns
prefix
:
"
"
artifact_location
:
null
training/configs/logger/neptune.yaml
0 → 100644
View file @
0bf5e500
# https://neptune.ai
neptune
:
_target_
:
pytorch_lightning.loggers.neptune.NeptuneLogger
api_key
:
${oc.env:NEPTUNE_API_TOKEN}
# api key is loaded from environment variable
project_name
:
your_name/template-tests
close_after_fit
:
True
offline_mode
:
False
experiment_name
:
${name}
experiment_id
:
null
prefix
:
"
"
training/configs/logger/tensorboard.yaml
0 → 100644
View file @
0bf5e500
# https://www.tensorflow.org/tensorboard/
tensorboard
:
_target_
:
pytorch_lightning.loggers.tensorboard.TensorBoardLogger
save_dir
:
"
tensorboard/"
name
:
"
default"
version
:
${name}
log_graph
:
False
default_hp_metric
:
True
prefix
:
"
"