OpenDAS / Megatron-LM · Commit 58cca6e9

Authored Mar 07, 2022 by Jared Casper

Transfer consumed train and valid samples to converted checkpoint.

parent 942c402d
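For context, a hedged note: Megatron-LM tracks how many training and validation samples a run has consumed (consumed_train_samples, consumed_valid_samples) so that a resumed job can continue from the correct position in its data ordering. Before this change, the checkpoint conversion tool dropped those counters, so a job resumed from a converted checkpoint would start counting from zero. A toy sketch of why the counter matters for resumption; the names below are illustrative, not Megatron's API:

# Toy illustration (not Megatron code): resuming a deterministic
# sample stream by skipping the samples a previous run already used.

def resume_stream(dataset, consumed_samples):
    # Fast-forward past what was already consumed, then continue in order.
    for idx in range(consumed_samples, len(dataset)):
        yield dataset[idx]

data = list(range(10))                       # stand-in dataset
it = resume_stream(data, consumed_samples=4)
assert next(it) == 4                         # picks up where the old run stopped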
Showing 3 changed files with 30 additions and 5 deletions (+30 -5)

tools/checkpoint_loader_megatron.py   +18 -5
tools/checkpoint_saver_megatron.py    +8  -0
tools/checkpoint_util.py              +4  -0
tools/checkpoint_loader_megatron.py

@@ -92,10 +92,11 @@ def _load_checkpoint(queue, args):
     # supress warning about torch.distributed not being initialized
     module.MegatronModule.embedding_warning_printed = True
 
+    consumed_train_samples = None
+    consumed_valid_samples = None
     def get_models(count, dtype, pre_process, post_process):
-        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
-        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
-        #     models = [f.result().bfloat16() for f in futures]
+        nonlocal consumed_train_samples
+        nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
             mpu.initialize.set_tensor_model_parallel_rank(rank)
@@ -104,7 +105,16 @@ def _load_checkpoint(queue, args):
             margs.consumed_valid_samples = 0
             load_checkpoint(model_, None, None)
             assert(len(model_) == 1)
-            models.append(model_[0])
+            model_ = model_[0]
+            if consumed_train_samples is not None:
+                assert(margs.consumed_train_samples == consumed_train_samples)
+            else:
+                consumed_train_samples = margs.consumed_train_samples
+            if consumed_valid_samples is not None:
+                assert(margs.consumed_valid_samples == consumed_valid_samples)
+            else:
+                consumed_valid_samples = margs.consumed_valid_samples
+            models.append(model_)
         return models
 
     if margs.num_layers_per_virtual_pipeline_stage is not None:
@@ -150,13 +160,16 @@ def _load_checkpoint(queue, args):
     md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
     md.true_vocab_size = true_vocab_size
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
-    queue.put(md)
 
     # Get first pipe stage
     mpu.initialize.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)
 
+    md.consumed_train_samples = consumed_train_samples
+    md.consumed_valid_samples = consumed_valid_samples
+    queue.put(md)
+
     # Send embeddings
     word_embed = []
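A note on the pattern above: get_models is a closure, and the two nonlocal declarations let it record the consumed-sample counters from the first tensor-parallel shard it loads, then assert that every subsequent shard agrees. A minimal, self-contained sketch of that consistency check, with the hypothetical load_shard_args standing in for Megatron's per-rank checkpoint load:

from types import SimpleNamespace

# Minimal sketch of the nonlocal consistency-check pattern.
# `load_shard_args` is a hypothetical stand-in that returns the
# args object saved with one tensor-parallel shard.

def collect_consumed_samples(num_ranks, load_shard_args):
    consumed_train_samples = None  # unknown until the first shard is read

    def check_shard(rank):
        nonlocal consumed_train_samples
        shard_args = load_shard_args(rank)
        if consumed_train_samples is not None:
            # every shard of one checkpoint must report the same counter
            assert shard_args.consumed_train_samples == consumed_train_samples
        else:
            consumed_train_samples = shard_args.consumed_train_samples

    for rank in range(num_ranks):
        check_shard(rank)
    return consumed_train_samples

# Usage with fake shards that all agree:
shards = [SimpleNamespace(consumed_train_samples=1024) for _ in range(4)]
print(collect_consumed_samples(4, lambda rank: shards[rank]))  # -> 1024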
tools/checkpoint_saver_megatron.py

@@ -110,6 +110,14 @@ def save_checkpoint(queue, args):
     # margs = megatron args
     margs = get_args()
 
+    if hasattr(md, 'consumed_train_samples'):
+        margs.consumed_train_samples = md.consumed_train_samples
+        margs.consumed_valid_samples = md.consumed_valid_samples
+        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
+              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
+    else:
+        print("consumed_train_samples not provided.")
+
     # Determine how to make our models
     if md.model_type == 'GPT':
         from pretrain_gpt import model_provider
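The hasattr guard keeps the saver compatible with metadata produced by an older loader that does not set the new fields. A small sketch of that defensive pattern, using types.SimpleNamespace as a stand-in for the md metadata object:

from types import SimpleNamespace

def apply_consumed_samples(margs, md):
    # Older loaders may not send these fields; degrade gracefully.
    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
    else:
        print("consumed_train_samples not provided.")

new_md = SimpleNamespace(consumed_train_samples=1024, consumed_valid_samples=128)
old_md = SimpleNamespace()                 # metadata from an older loader
margs = SimpleNamespace()
apply_consumed_samples(margs, new_md)      # copies both counters
apply_consumed_samples(margs, old_md)      # prints the fallback message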
tools/checkpoint_util.py

@@ -32,6 +32,10 @@ import os
 # bert_binary_head - Used only if model_type is BERT
 # previous_tensor_parallel_size - Optional
 # previous_pipeline_parallel_size - Optional
+# true_vocab_size
+# make_vocab_size_divisble_by
+# consumed_train_samples
+# consumed_valid_samples
 # - Position embeddings
 # - Word embeddings
 # - For each transformer layer:
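The comment block above documents the metadata message that flows from the loader process to the saver. A hedged sketch of that handoff, assuming (as the tools' structure suggests) a picklable namespace sent over a multiprocessing queue; field values are illustrative only:

import multiprocessing as mp
from types import SimpleNamespace

def loader(queue):
    # Build the metadata message, including the two new fields.
    md = SimpleNamespace()
    md.true_vocab_size = 50257             # illustrative value
    md.make_vocab_size_divisible_by = 128  # illustrative value
    md.consumed_train_samples = 1024
    md.consumed_valid_samples = 128
    queue.put(md)

if __name__ == '__main__':
    queue = mp.Queue()
    proc = mp.Process(target=loader, args=(queue,))
    proc.start()
    md = queue.get()
    print(md.consumed_train_samples, md.consumed_valid_samples)  # 1024 128
    proc.join()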