Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
eb4f2d90
Unverified
Commit
eb4f2d90
authored
Feb 06, 2024
by
Hongxin Liu
Committed by
GitHub
Feb 06, 2024
Browse files
[llama] polish training script and fix optim ckpt (#5368)
parent
a5756a87
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing 2 changed files with 14 additions and 5 deletions
+14
-5
applications/Colossal-LLaMA-2/train.py
applications/Colossal-LLaMA-2/train.py
+11
-3
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+3
-2
No files found.
applications/Colossal-LLaMA-2/train.py
View file @
eb4f2d90
...
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
...
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
colossal_llama2.utils.neftune_patch
import
activate_neftune
,
deactivate_neftune
from
torch.utils.tensorboard
import
SummaryWriter
from
torch.utils.tensorboard
import
SummaryWriter
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
transformers
import
LlamaForCausalLM
,
LlamaTokenizer
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaTokenizer
import
colossalai
import
colossalai
from
colossalai.accelerator
import
get_accelerator
from
colossalai.accelerator
import
get_accelerator
...
@@ -232,7 +232,7 @@ def main() -> None:
...
@@ -232,7 +232,7 @@ def main() -> None:
else
nullcontext
()
else
nullcontext
()
)
)
with
init_ctx
:
with
init_ctx
:
model
=
LlamaForCausalLM
.
from_pretrained
(
args
.
pretrained
)
model
=
LlamaForCausalLM
(
LlamaConfig
.
from_pretrained
(
args
.
pretrained
)
)
# Freeze part of parameters.
# Freeze part of parameters.
if
args
.
freeze_non_embeds_params
:
if
args
.
freeze_non_embeds_params
:
freeze_non_embeds_parameters
(
model
=
model
)
freeze_non_embeds_parameters
(
model
=
model
)
...
@@ -277,6 +277,8 @@ def main() -> None:
...
@@ -277,6 +277,8 @@ def main() -> None:
lr_scheduler
=
lr_scheduler
,
lr_scheduler
=
lr_scheduler
,
dataloader
=
dataloader
,
dataloader
=
dataloader
,
)
)
if
args
.
load_checkpoint
is
None
:
booster
.
load_model
(
model
,
args
.
pretrained
)
torch
.
set_default_dtype
(
torch
.
float
)
torch
.
set_default_dtype
(
torch
.
float
)
...
@@ -329,7 +331,12 @@ def main() -> None:
...
@@ -329,7 +331,12 @@ def main() -> None:
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
for
epoch
in
range
(
start_epoch
,
args
.
num_epochs
):
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
dataloader
.
sampler
.
set_epoch
(
epoch
=
epoch
)
pbar
=
tqdm
(
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
,
initial
=
start_step
//
args
.
accumulation_steps
)
pbar
=
tqdm
(
desc
=
f
"Epoch
{
epoch
}
"
,
disable
=
not
coordinator
.
is_master
(),
total
=
num_steps_per_epoch
,
initial
=
start_step
//
args
.
accumulation_steps
,
)
total_loss
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
())
total_loss
=
torch
.
tensor
(
0.0
,
device
=
get_current_device
())
for
step
,
batch
in
enumerate
(
dataloader
,
start
=
start_step
):
for
step
,
batch
in
enumerate
(
dataloader
,
start
=
start_step
):
batch
=
{
k
:
v
.
to
(
get_current_device
())
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
batch
=
{
k
:
v
.
to
(
get_current_device
())
for
k
,
v
in
batch
.
items
()
if
isinstance
(
v
,
torch
.
Tensor
)}
...
@@ -369,6 +376,7 @@ def main() -> None:
...
@@ -369,6 +376,7 @@ def main() -> None:
coordinator
.
print_on_master
(
"Deactivate NEFTune before saving model."
)
coordinator
.
print_on_master
(
"Deactivate NEFTune before saving model."
)
deactivate_neftune
(
model
,
handle
)
deactivate_neftune
(
model
,
handle
)
accelerator
.
empty_cache
()
save_checkpoint
(
save_checkpoint
(
save_dir
=
args
.
save_dir
,
save_dir
=
args
.
save_dir
,
booster
=
booster
,
booster
=
booster
,
...
...
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
View file @
eb4f2d90
...
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
...
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from
colossalai.cluster
import
DistCoordinator
from
colossalai.cluster
import
DistCoordinator
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
from
colossalai.utils
import
get_current_device
from
.general_checkpoint_io
import
GeneralCheckpointIO
from
.general_checkpoint_io
import
GeneralCheckpointIO
from
.index_file
import
CheckpointIndexFile
from
.index_file
import
CheckpointIndexFile
...
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
...
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group
=
self
.
tp_group
,
tp_group
=
self
.
tp_group
,
use_zero
=
self
.
use_zero
,
use_zero
=
self
.
use_zero
,
inplace
=
False
,
inplace
=
False
,
device
=
torch
.
device
(
"cuda"
),
device
=
get_current_device
(),
)
)
if
self
.
pp_size
==
1
:
if
self
.
pp_size
==
1
:
...
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
...
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if
isinstance
(
v
,
torch
.
Tensor
)
and
k
!=
"step"
:
if
isinstance
(
v
,
torch
.
Tensor
)
and
k
!=
"step"
:
# First gather Zero shards.
# First gather Zero shards.
if
use_zero
:
if
use_zero
:
v
=
v
.
cuda
()
v
=
v
.
to
(
get_current_device
()
)
gather_tensor
=
[
torch
.
zeros_like
(
v
)
for
_
in
range
(
dp_size
)]
gather_tensor
=
[
torch
.
zeros_like
(
v
)
for
_
in
range
(
dp_size
)]
dist
.
all_gather
(
gather_tensor
,
v
,
group
=
dp_group
)
dist
.
all_gather
(
gather_tensor
,
v
,
group
=
dp_group
)
v
=
torch
.
stack
(
gather_tensor
).
view
(
-
1
)[:
param
.
numel
()].
reshape_as
(
param
)
v
=
torch
.
stack
(
gather_tensor
).
view
(
-
1
)[:
param
.
numel
()].
reshape_as
(
param
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please register or sign in to comment