Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
73f9f23f
Unverified
Commit
73f9f23f
authored
Feb 05, 2024
by
Hongxin Liu
Committed by
GitHub
Feb 05, 2024
Browse files
[llama] update training script (#5360)
* [llama] update training script * [doc] polish docstr
parent
6c0fa7b9
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
107 additions
and
477 deletions
+107
-477
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
...ations/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
+6
-4
applications/Colossal-LLaMA-2/train.example.sh
applications/Colossal-LLaMA-2/train.example.sh
+1
-0
applications/Colossal-LLaMA-2/train.py
applications/Colossal-LLaMA-2/train.py
+98
-69
applications/Colossal-LLaMA-2/train_sft.example.sh
applications/Colossal-LLaMA-2/train_sft.example.sh
+2
-1
applications/Colossal-LLaMA-2/train_sft.py
applications/Colossal-LLaMA-2/train_sft.py
+0
-403
No files found.
applications/Colossal-LLaMA-2/colossal_llama2/dataset/loader.py
View file @
73f9f23f
...
...
@@ -58,6 +58,7 @@ class DataCollatorForSupervisedDataset(object):
tokenizer
:
PreTrainedTokenizer
max_length
:
int
=
4096
ignore_index
:
int
=
-
100
padding
:
str
=
"max_length"
def
__call__
(
self
,
instances
:
Sequence
[
Dict
[
str
,
List
[
int
]]])
->
Dict
[
str
,
torch
.
Tensor
]:
"""
...
...
@@ -102,10 +103,11 @@ class DataCollatorForSupervisedDataset(object):
batch_first
=
True
,
padding_value
=
self
.
ignore_index
,
)
# (bsz, max_len)
# pad to max
to_pad
=
self
.
max_length
-
input_ids
.
size
(
1
)
input_ids
=
F
.
pad
(
input_ids
,
(
0
,
to_pad
),
value
=
self
.
tokenizer
.
pad_token_id
)
labels
=
F
.
pad
(
labels
,
(
0
,
to_pad
),
value
=
self
.
ignore_index
)
if
self
.
padding
==
"max_length"
:
# pad to max
to_pad
=
self
.
max_length
-
input_ids
.
size
(
1
)
input_ids
=
F
.
pad
(
input_ids
,
(
0
,
to_pad
),
value
=
self
.
tokenizer
.
pad_token_id
)
labels
=
F
.
pad
(
labels
,
(
0
,
to_pad
),
value
=
self
.
ignore_index
)
elif
self
.
tokenizer
.
padding_side
==
"left"
:
reversed_input_ids
=
[
seq
.
flip
(
dims
=
(
0
,))
for
seq
in
batch_input_ids
]
reversed_input_ids
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
...
...
applications/Colossal-LLaMA-2/train.example.sh
View file @
73f9f23f
...
...
@@ -42,3 +42,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train.
--warmup_steps
100
\
--use_grad_checkpoint
\
--use_flash_attn
\
--pad_token
"unk"
applications/Colossal-LLaMA-2/train.py
View file @
73f9f23f
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Continual Pre-training
of
LLaMA-2 developed by Colossal-AI Team
Continual Pre-training
/Supervised fine-tuning of Colossal-
LLaMA-2 developed by Colossal-AI Team
"""
import
argparse
...
...
@@ -20,17 +20,20 @@ from colossal_llama2.dataset.loader import (
from
colossal_llama2.utils.ckpt_io
import
load_checkpoint
,
save_checkpoint
from
colossal_llama2.utils.flash_attention_patch
import
replace_with_flash_attention
from
colossal_llama2.utils.froze
import
freeze_non_embeds_parameters
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaTokenizer
from
transformers
import
LlamaForCausalLM
,
LlamaTokenizer
import
colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
GeminiPlugin
,
HybridParallelPlugin
,
LowLevelZeroPlugin
from
colossalai.cluster
import
DistCoordinator
from
colossalai.lazy
import
LazyInitContext
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.utils
import
get_current_device
def
get_model_numel
(
model
:
torch
.
nn
.
Module
)
->
int
:
...
...
@@ -82,6 +85,7 @@ def main() -> None:
parser
.
add_argument
(
"--tensorboard_dir"
,
type
=
str
,
default
=
"logs_dir"
,
help
=
"Tensorboard directory"
)
parser
.
add_argument
(
"--config_file"
,
type
=
str
,
default
=
"config_file"
,
help
=
"Config file"
)
parser
.
add_argument
(
"--num_epochs"
,
type
=
int
,
default
=
1
,
help
=
"Number of training epochs"
)
parser
.
add_argument
(
"--accumulation_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of accumulation steps"
)
parser
.
add_argument
(
"--micro_batch_size"
,
type
=
int
,
default
=
2
,
help
=
"Batch size of each process"
)
parser
.
add_argument
(
"--lr"
,
type
=
float
,
default
=
3e-4
,
help
=
"Learning rate"
)
parser
.
add_argument
(
"--max_length"
,
type
=
int
,
default
=
4096
,
help
=
"Model max length"
)
...
...
@@ -107,6 +111,12 @@ def main() -> None:
default
=
False
,
help
=
"Use flash-attention"
,
)
parser
.
add_argument
(
"--use_neft"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Use NEFTune"
,
)
parser
.
add_argument
(
"--freeze_non_embeds_params"
,
action
=
"store_true"
,
...
...
@@ -115,6 +125,8 @@ def main() -> None:
)
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--zero"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--pad_token"
,
choices
=
[
"eos"
,
"unk"
],
default
=
"eos"
)
parser
.
add_argument
(
"--padding_mode"
,
choices
=
[
"max_length"
,
"longest"
],
default
=
"max_length"
)
args
=
parser
.
parse_args
()
with
open
(
args
.
config_file
,
"w"
)
as
f
:
...
...
@@ -124,6 +136,7 @@ def main() -> None:
# Initialize Distributed Training
# ==============================
colossalai
.
launch_from_torch
({})
accelerator
=
get_accelerator
()
coordinator
=
DistCoordinator
()
# ==============================
...
...
@@ -181,7 +194,10 @@ def main() -> None:
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer
=
LlamaTokenizer
.
from_pretrained
(
args
.
pretrained
)
tokenizer
.
pad_token
=
tokenizer
.
unk_token
if
args
.
pad_token
==
"eos"
:
tokenizer
.
pad_token
=
tokenizer
.
eos_token
elif
args
.
pad_token
==
"unk"
:
tokenizer
.
pad_token
=
tokenizer
.
unk_token
tokenizer
.
add_bos_token
=
False
tokenizer
.
add_eos_token
=
False
...
...
@@ -192,7 +208,9 @@ def main() -> None:
coordinator
.
print_on_master
(
f
"Load dataset:
{
args
.
dataset
}
"
)
dataset
=
load_tokenized_dataset
(
dataset_paths
=
args
.
dataset
,
mode
=
"train"
)
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
,
max_length
=
args
.
max_length
)
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
,
max_length
=
args
.
max_length
,
padding
=
args
.
padding_mode
)
dataloader
=
plugin
.
prepare_dataloader
(
dataset
=
dataset
,
batch_size
=
args
.
micro_batch_size
,
...
...
@@ -202,26 +220,19 @@ def main() -> None:
distributed_sampler_cls
=
StatefulDistributedSampler
,
)
coordinator
.
print_on_master
(
f
"Max
CUDA
memory after data loader:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Max
device
memory after data loader:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# colossalai has changed api for get_current_device in 0.3.4 version or newer
try
:
from
colossalai.accelerator
import
get_accelerator
current_device
=
get_accelerator
().
get_current_device
()
except
:
from
colossalai.utils
import
get_current_device
current_device
=
get_current_device
()
init_ctx
=
LazyInitContext
(
default_device
=
current_device
)
if
isinstance
(
plugin
,
(
GeminiPlugin
,))
else
nullcontext
()
init_ctx
=
(
LazyInitContext
(
default_device
=
get_current_device
())
if
isinstance
(
plugin
,
(
GeminiPlugin
,
HybridParallelPlugin
))
else
nullcontext
()
)
with
init_ctx
:
model
=
LlamaForCausalLM
(
LlamaConfig
.
from_pretrained
(
args
.
pretrained
)
)
model
=
LlamaForCausalLM
.
from_pretrained
(
args
.
pretrained
)
# Freeze part of parameters.
if
args
.
freeze_non_embeds_params
:
freeze_non_embeds_parameters
(
model
=
model
)
...
...
@@ -246,12 +257,14 @@ def main() -> None:
adamw_mode
=
True
,
)
if
args
.
warmup_steps
is
None
:
args
.
warmup_steps
=
int
(
args
.
num_epochs
*
0.025
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
))
coordinator
.
print_on_master
(
f
"Warmup steps is set to
{
args
.
warmup_steps
}
"
)
lr_scheduler
=
CosineAnnealingWarmupLR
(
optimizer
=
optimizer
,
total_steps
=
args
.
num_epochs
*
len
(
dataloader
),
warmup_steps
=
args
.
warmup_steps
if
args
.
warmup_steps
is
not
None
else
int
(
args
.
num_epochs
*
len
(
dataloader
)
*
0.025
),
total_steps
=
args
.
num_epochs
*
(
len
(
dataloader
)
//
args
.
accumulation_steps
),
warmup_steps
=
args
.
warmup_steps
,
eta_min
=
0.1
*
args
.
lr
,
)
...
...
@@ -267,11 +280,9 @@ def main() -> None:
torch
.
set_default_dtype
(
torch
.
float
)
if
args
.
load_checkpoint
is
None
:
coordinator
.
print_on_master
(
f
"Load pretrained model checkpoint from
{
args
.
pretrained
}
"
)
booster
.
load_model
(
model
,
args
.
pretrained
,
strict
=
False
)
coordinator
.
print_on_master
(
f
"Booster init max CUDA memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Booster init max device memory:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Booster init max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
...
...
@@ -298,85 +309,103 @@ def main() -> None:
coordinator
.
print_on_master
(
f
"Loaded sample at index
{
sampler_start_idx
}
"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded max
CUDA
memory:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Checkpoint loaded max
device
memory:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded
CUDA
memory:
{
torch
.
cuda
.
memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
f
"Checkpoint loaded
device
memory:
{
accelerator
.
memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Checkpoint loaded max CPU memory:
{
resource
.
getrusage
(
resource
.
RUSAGE_SELF
).
ru_maxrss
/
1024
:.
2
f
}
MB"
)
num_steps_per_epoch
=
len
(
dataloader
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
num_steps_per_epoch
=
len
(
dataloader
)
//
args
.
accumulation_steps
# If resume training, set the sampler start index to the correct value
assert
isinstance
(
dataloader
.
sampler
,
StatefulDistributedSampler
)
dataloader
.
sampler
.
set_start_index
(
start_index
=
sampler_start_idx
)
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
with
tqdm
(
iterable
=
enumerate
(
dataloader
,
start
=
start_step
),
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
,
initial
=
start_step
,
)
as
pbar
:
for
step
,
batch
in
pbar
:
batch
=
{
k
:
v
.
to
(
current_device
)
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
pbar
=
tqdm
(
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
)
total_loss
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
())
for
step
,
batch
in
enumerate
(
dataloader
):
batch
=
{
k
:
v
.
to
(
get_current_device
())
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
batch_output
=
model
(
**
batch
)
batch_output
=
model
(
**
batch
)
loss
=
batch_output
.
loss
loss
=
batch_output
.
loss
/
args
.
accumulation_steps
total_loss
.
add_
(
loss
.
data
)
booster
.
backward
(
loss
=
loss
,
optimizer
=
optimizer
)
booster
.
backward
(
loss
=
loss
,
optimizer
=
optimizer
)
if
(
step
+
1
)
%
args
.
accumulation_steps
==
0
:
optimizer
.
step
()
lr_scheduler
.
step
()
optimizer
.
zero_grad
()
all_reduce_mean
(
tensor
=
loss
)
pbar
.
set_postfix
({
"Loss"
:
f
"
{
loss
.
item
():.
4
f
}
"
})
all_reduce_mean
(
tensor
=
total_
loss
)
pbar
.
set_postfix
({
"Loss"
:
f
"
{
total_
loss
.
item
():.
4
f
}
"
})
if
coordinator
.
is_master
():
global_step
=
epoch
*
num_steps_per_epoch
+
step
writer
.
add_scalar
(
tag
=
"Loss"
,
scalar_value
=
loss
.
item
(),
global_step
=
global_step
)
global_step
=
(
epoch
*
num_steps_per_epoch
)
+
(
step
+
1
)
//
args
.
accumulation_steps
writer
.
add_scalar
(
tag
=
"Loss"
,
scalar_value
=
total_
loss
.
item
(),
global_step
=
global_step
)
writer
.
add_scalar
(
tag
=
"Learning Rate"
,
scalar_value
=
lr_scheduler
.
get_last_lr
()[
0
],
global_step
=
global_step
,
)
# Save modeling.
if
(
args
.
save_interval
>
0
and
(
step
+
1
)
%
args
.
save_interval
==
0
)
or
(
step
+
1
)
==
len
(
dataloader
):
coordinator
.
print_on_master
(
"
\n
Start saving model checkpoint with running states"
)
save_checkpoint
(
save_dir
=
args
.
save_dir
,
booster
=
booster
,
model
=
model
,
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
epoch
=
epoch
,
step
=
step
+
1
,
batch_size
=
args
.
micro_batch_size
,
coordinator
=
coordinator
,
)
coordinator
.
print_on_master
(
f
"Saved checkpoint at epoch
{
epoch
}
step
{
step
+
1
}
at folder
{
args
.
save_dir
}
"
)
# Delete CUDA cache.
# del batch, batch_labels, batch_output, loss
torch
.
cuda
.
empty_cache
()
total_loss
.
fill_
(
0.0
)
pbar
.
update
()
# Save modeling.
if
(
args
.
save_interval
>
0
and
(
step
+
1
)
%
(
args
.
save_interval
*
args
.
accumulation_steps
)
==
0
)
or
(
step
+
1
)
==
len
(
dataloader
):
coordinator
.
print_on_master
(
"
\n
Start saving model checkpoint with running states"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Deactivate NEFTune before saving model."
)
deactivate_neftune
(
model
,
handle
)
save_checkpoint
(
save_dir
=
args
.
save_dir
,
booster
=
booster
,
model
=
model
,
optimizer
=
optimizer
,
lr_scheduler
=
lr_scheduler
,
epoch
=
epoch
,
step
=
step
+
1
,
batch_size
=
args
.
micro_batch_size
,
coordinator
=
coordinator
,
)
coordinator
.
print_on_master
(
f
"Saved checkpoint at epoch
{
epoch
}
step
{
step
+
1
}
at folder
{
args
.
save_dir
}
"
)
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Activate NEFTune."
)
model
,
handle
=
activate_neftune
(
model
)
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator
.
empty_cache
()
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader
.
sampler
.
set_start_index
(
start_index
=
0
)
start_step
=
0
if
args
.
use_neft
:
coordinator
.
print_on_master
(
"Deactivate NEFTune."
)
deactivate_neftune
(
model
,
handle
)
# Final save.
coordinator
.
print_on_master
(
"Start saving final model checkpoint"
)
booster
.
save_model
(
model
,
os
.
path
.
join
(
args
.
save_dir
,
"modeling"
),
shard
=
True
)
coordinator
.
print_on_master
(
f
"Saved final model checkpoint at epoch
{
epoch
}
at folder
{
args
.
save_dir
}
"
)
coordinator
.
print_on_master
(
f
"Max
CUDA
memory usage:
{
torch
.
cuda
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
coordinator
.
print_on_master
(
f
"Max
device
memory usage:
{
accelerator
.
max_memory_allocated
()
/
1024
**
2
:.
2
f
}
MB"
)
if
__name__
==
"__main__"
:
...
...
applications/Colossal-LLaMA-2/train_sft.example.sh
View file @
73f9f23f
...
...
@@ -25,7 +25,7 @@ SAVE_DIR="${PARENT_SAVE_DIR}${FULL_PROJECT_NAME}"
TENSORBOARD_DIR
=
"
${
PARENT_TENSORBOARD_DIR
}${
FULL_PROJECT_NAME
}
"
CONFIG_FILE
=
"
${
PARENT_CONFIG_FILE
}${
FULL_PROJECT_NAME
}
.json"
colossalai run
--nproc_per_node
8
--hostfile
hostfile
--master_port
30013 train
_sft
.py
\
colossalai run
--nproc_per_node
8
--hostfile
hostfile
--master_port
30013 train.py
\
--pretrained
$PRETRAINED_MODEL_PATH
\
--dataset
${
dataset
[@]
}
\
--plugin
"zero2"
\
...
...
@@ -44,3 +44,4 @@ colossalai run --nproc_per_node 8 --hostfile hostfile --master_port 30013 train_
--use_grad_checkpoint
\
--use_flash_attn
\
--use_neft
\
--pad_token
"eos"
applications/Colossal-LLaMA-2/train_sft.py
deleted
100644 → 0
View file @
6c0fa7b9
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Supervised fine-tuning of Colossal-LLaMA-2-base developed by Colossal-AI Team
"""
import
argparse
import
json
import
os
import
resource
from
contextlib
import
nullcontext
import
torch
import
torch.distributed
as
dist
from
colossal_llama2.dataset.loader
import
(
DataCollatorForSupervisedDataset
,
StatefulDistributedSampler
,
load_tokenized_dataset
,
)
from
colossal_llama2.utils.ckpt_io
import
load_checkpoint
,
save_checkpoint
from
colossal_llama2.utils.flash_attention_patch
import
replace_with_flash_attention
from
colossal_llama2.utils.froze
import
freeze_non_embeds_parameters
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaTokenizer
import
colossalai
from
colossalai.booster
import
Booster
from
colossalai.booster.plugin
import
GeminiPlugin
,
HybridParallelPlugin
,
LowLevelZeroPlugin
from
colossalai.cluster
import
DistCoordinator
from
colossalai.lazy
import
LazyInitContext
from
colossalai.nn.lr_scheduler
import
CosineAnnealingWarmupLR
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.utils
import
get_current_device
def get_model_numel(model: torch.nn.Module) -> int:
    """Count the scalar elements held by all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are walked.

    Returns:
        The sum of ``numel()`` over every parameter (0 when there are none).
    """
    element_count = 0
    for param in model.parameters():
        element_count += param.numel()
    return element_count
def format_numel_str(numel: int) -> str:
    """Render an element count as a short human-readable string.

    Counts are scaled by powers of 1024 and printed with two decimals and a
    ``B`` / ``M`` / ``K`` suffix; counts below 1024 are printed unchanged.

    Args:
        numel: Number of elements to format.

    Returns:
        The formatted count.
    """
    # Largest unit first so the first threshold that fits wins.
    scales = ((1024**3, "B"), (1024**2, "M"), (1024, "K"))
    for threshold, suffix in scales:
        if numel >= threshold:
            return f"{numel / threshold:.2f} {suffix}"
    return f"{numel}"
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average ``tensor`` across all ranks, in place.

    The tensor is sum-reduced over the default process group and then divided
    by the world size.

    Args:
        tensor: Tensor to reduce; it is modified in place.

    Returns:
        The same tensor object, now holding the cross-rank mean.
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    num_ranks = dist.get_world_size()
    tensor.div_(num_ranks)
    return tensor
def main() -> None:
    """Command-line entry point for supervised fine-tuning.

    In order, this function:

    1. parses the command-line arguments and dumps them as JSON to
       ``--config_file``;
    2. launches the distributed environment and, on the master rank, creates
       a TensorBoard writer;
    3. builds the booster plugin selected by ``--plugin``;
    4. creates the tokenizer, tokenized dataset, collator and dataloader;
    5. creates the model, optimizer and LR scheduler and boosts them;
    6. loads either the pretrained weights or a checkpoint;
    7. runs the training loop with gradient accumulation, periodic
       checkpointing and optional NEFTune;
    8. saves the final model under ``<save_dir>/modeling``.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained",
        type=str,
        default=None,
        help="Address of the pre-trained modeling",
    )
    parser.add_argument("--dataset", nargs="+", default=[])
    parser.add_argument(
        "--plugin",
        type=str,
        default="gemini",
        choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
        help="Choose which plugin to use",
    )
    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
    parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
    parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
    parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
    parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
    parser.add_argument("--accumulation_steps", type=int, default=8, help="Number of accumulation steps")
    parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
    parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
    parser.add_argument("--max_length", type=int, default=4096, help="Model max length")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="fp16",
        choices=["fp16", "bf16"],
        help="Mixed precision",
    )
    parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
    parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
    # When left as None, the warmup length is derived further down as 2.5% of
    # the total number of optimizer steps.
    parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
    parser.add_argument(
        "--use_grad_checkpoint",
        action="store_true",
        default=False,
        help="Use gradient checkpointing",
    )
    parser.add_argument(
        "--use_flash_attn",
        action="store_true",
        default=False,
        help="Use flash-attention",
    )
    parser.add_argument(
        "--use_neft",
        action="store_true",
        default=False,
        help="Use NEFTune",
    )
    parser.add_argument(
        "--freeze_non_embeds_params",
        action="store_true",
        default=False,
        help="Freeze non embeddings parameters",
    )
    # --tp and --zero are only read when --plugin is "3d".
    parser.add_argument("--tp", type=int, default=1)
    parser.add_argument("--zero", type=int, default=1)
    args = parser.parse_args()

    # NOTE(review): this write is not guarded by a master-rank check, so every
    # launched process writes the same file.
    with open(args.config_file, "w") as f:
        json.dump(args.__dict__, f, indent=4)

    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

    # ==============================
    # Initialize Tensorboard
    # ==============================
    # `writer` exists on the master rank only; every later use of it sits
    # behind the same `coordinator.is_master()` check.
    if coordinator.is_master():
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == "gemini":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "gemini_auto":
        plugin = GeminiPlugin(
            precision=args.mixed_precision,
            placement_policy="auto",
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "zero2_cpu":
        plugin = LowLevelZeroPlugin(
            stage=2,
            precision=args.mixed_precision,
            initial_scale=2**16,
            cpu_offload=True,
            max_norm=args.grad_clip,
        )
    elif args.plugin == "3d":
        plugin = HybridParallelPlugin(
            tp_size=args.tp,
            pp_size=1,
            zero_stage=args.zero,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
        )
    else:
        # Defensive only: argparse `choices` already rejects other values.
        raise ValueError(f"Unknown plugin {args.plugin}")

    booster = Booster(plugin=plugin)

    # ======================================================
    # Initialize Tokenizer, Dataset, Collator and Dataloader
    # ======================================================
    tokenizer = LlamaTokenizer.from_pretrained(args.pretrained)
    # The EOS token doubles as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    # presumably the tokenized dataset already carries BOS/EOS ids — TODO confirm
    tokenizer.add_bos_token = False
    tokenizer.add_eos_token = False

    coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
    coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
    coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")

    coordinator.print_on_master(f"Load dataset: {args.dataset}")

    dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
    dataloader = plugin.prepare_dataloader(
        dataset=dataset,
        batch_size=args.micro_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=data_collator,
        distributed_sampler_cls=StatefulDistributedSampler,
    )
    coordinator.print_on_master(
        f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
    )

    # ======================================================
    # Initialize Model, Objective, Optimizer and LR Scheduler
    # ======================================================
    # Lazy initialization is used for the Gemini plugin only; every other
    # plugin builds the model under a no-op context.
    init_ctx = (
        LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
    )
    with init_ctx:
        # The model is built from the config alone; weights are loaded later
        # through `booster.load_model`.
        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
        # Freeze part of parameters.
        if args.freeze_non_embeds_params:
            freeze_non_embeds_parameters(model=model)

    if args.use_grad_checkpoint:
        model.gradient_checkpointing_enable()
        coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
    if args.use_flash_attn:
        replace_with_flash_attention(model=model)
        coordinator.print_on_master(msg="Flash-attention enabled successfully")

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")

    optimizer = HybridAdam(
        # When parameters were frozen, hand only the trainable ones to the optimizer.
        model_params=filter(lambda p: p.requires_grad, model.parameters())
        if args.freeze_non_embeds_params
        else model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.95),
        weight_decay=args.weight_decay,
        adamw_mode=True,
    )

    # Scheduler steps are counted in optimizer updates, i.e. micro-batches
    # divided by the accumulation factor.
    if args.warmup_steps is None:
        args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
        coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")

    lr_scheduler = CosineAnnealingWarmupLR(
        optimizer=optimizer,
        total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
        warmup_steps=args.warmup_steps,
        eta_min=0.1 * args.lr,
    )

    # Flash attention will be disabled because it does NOT support fp32.
    # The default dtype is switched to the mixed-precision dtype only for the
    # duration of `booster.boost` and restored to float32 right after.
    default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
    torch.set_default_dtype(default_dtype)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        dataloader=dataloader,
    )

    torch.set_default_dtype(torch.float)

    # Without a checkpoint to resume from, start from the pretrained weights.
    if args.load_checkpoint is None:
        coordinator.print_on_master(f"Load pretrained model checkpoint from {args.pretrained}")
        booster.load_model(model, args.pretrained, strict=False)

    coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
    coordinator.print_on_master(
        f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
    )

    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load_checkpoint is not None:
        # A path containing "modeling" is loaded as model weights only; any
        # other path goes through `load_checkpoint`, which also receives the
        # optimizer and scheduler and returns the resume position.
        if "modeling" in args.load_checkpoint:
            coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
            booster.load_model(model, args.load_checkpoint)
        else:
            coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
            start_epoch, start_step, sampler_start_idx = load_checkpoint(
                load_dir=args.load_checkpoint,
                booster=booster,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
            coordinator.print_on_master(
                f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
            )
            coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")

        coordinator.print_on_master(
            f"Checkpoint loaded max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded CUDA memory: {torch.cuda.memory_allocated() / 1024 ** 2:.2f} MB"
        )
        coordinator.print_on_master(
            f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
        )

    # `handle` is only bound when --use_neft is set; every later use of it is
    # behind the same flag.
    if args.use_neft:
        coordinator.print_on_master("Activate NEFTune.")
        model, handle = activate_neftune(model)

    # Number of optimizer updates (not micro-batches) per epoch.
    num_steps_per_epoch = len(dataloader) // args.accumulation_steps
    # If resume training, set the sampler start index to the correct value
    assert isinstance(dataloader.sampler, StatefulDistributedSampler)
    dataloader.sampler.set_start_index(start_index=sampler_start_idx)

    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch=epoch)
        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch)
        # Running sum of the scaled micro-batch losses in the current
        # accumulation window.
        total_loss = torch.tensor(0.0).to(torch.cuda.current_device())
        # NOTE(review): `step` restarts at 0 here even when resuming;
        # `start_step` is logged and reset below but never offsets this loop.
        for step, batch in enumerate(dataloader):
            # Non-tensor entries of the batch are dropped.
            batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}

            batch_output = model(**batch)

            # Scale the loss by the accumulation factor.
            loss = batch_output.loss / args.accumulation_steps
            total_loss += loss.item()

            booster.backward(loss=loss, optimizer=optimizer)

            # Optimizer update once per accumulation window.
            if (step + 1) % args.accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                # Average the window loss across ranks before reporting it.
                all_reduce_mean(tensor=total_loss)
                pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
                if coordinator.is_master():
                    global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
                    writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
                    writer.add_scalar(
                        tag="Learning Rate",
                        scalar_value=lr_scheduler.get_last_lr()[0],
                        global_step=global_step,
                    )
                total_loss.fill_(0.0)
                pbar.update()
            # Save modeling.
            # --save_interval counts optimizer updates, hence the product with
            # accumulation_steps; a checkpoint is also written on the last
            # micro-batch of every epoch.
            if (args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0) or (
                step + 1
            ) == len(dataloader):
                coordinator.print_on_master("\nStart saving model checkpoint with running states")

                # NEFTune is switched off around the save and back on afterwards.
                if args.use_neft:
                    coordinator.print_on_master("Deactivate NEFTune before saving model.")
                    deactivate_neftune(model, handle)

                save_checkpoint(
                    save_dir=args.save_dir,
                    booster=booster,
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    epoch=epoch,
                    step=step + 1,
                    batch_size=args.micro_batch_size,
                    coordinator=coordinator,
                )
                coordinator.print_on_master(
                    f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
                )

                if args.use_neft:
                    coordinator.print_on_master("Activate NEFTune.")
                    model, handle = activate_neftune(model)

            # Delete CUDA cache.
            # del batch, batch_labels, batch_output, loss
            torch.cuda.empty_cache()

        # the continue epochs are not resumed, so we need to reset the sampler start index and start step
        dataloader.sampler.set_start_index(start_index=0)
        start_step = 0

    if args.use_neft:
        coordinator.print_on_master("Deactivate NEFTune.")
        deactivate_neftune(model, handle)

    # Final save.
    coordinator.print_on_master("Start saving final model checkpoint")
    booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
    # NOTE(review): `epoch` is the loop variable; it is unbound here if the
    # loop body never ran (start_epoch >= num_epochs).
    coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")

    coordinator.print_on_master(f"Max CUDA memory usage: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
# Run the training entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment