OpenDAS / Megatron-LM / Commits

Commit be1a575e, authored Apr 19, 2023 by Jared Casper

    Some quick fixes to checkpoint_util.

Parent: 8dbd0757
Showing 3 changed files with 9 additions and 9 deletions.
megatron/model/transformer.py          +2 −1
tools/checkpoint_loader_megatron.py    +6 −8
tools/checkpoint_saver_megatron.py     +1 −0
megatron/model/transformer.py

```diff
@@ -1012,8 +1012,9 @@ class ParallelTransformer(MegatronModule):
             import transformer_engine
         self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
         self.fp8_recipe = None
-        self.fp8_group = mpu.get_data_parallel_group()
+        self.fp8_group = None
         if self.use_fp8:
+            self.fp8_group = mpu.get_data_parallel_group()
             if args.fp8_e4m3:
                 fp8_format = transformer_engine.common.recipe.Format.E4M3
             elif args.fp8_hybrid:
```
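Note: this hunk defers the data-parallel group lookup until FP8 is actually enabled. The sketch below is illustrative, not Megatron source; the guard and assertion message are modeled on `megatron.core.parallel_state.get_data_parallel_group()`, which asserts that the group exists. Calling it unconditionally in the constructor would presumably fail in any setting where model-parallel state was never initialized, such as these checkpoint-conversion tools.

```python
# Minimal sketch of the failure mode; names mirror megatron.core.parallel_state
# but this is a stand-in, not the real module.
_DATA_PARALLEL_GROUP = None  # populated only by initialize_model_parallel()

def get_data_parallel_group():
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP

class FakeArgs:          # hypothetical stand-in for Megatron's parsed args
    fp8_e4m3 = False
    fp8_hybrid = False

args = FakeArgs()

# Before the fix, the lookup ran during construction regardless of the FP8
# flags, tripping the assertion whenever no distributed groups exist.
#   fp8_group = get_data_parallel_group()

# After the fix, the lookup only happens when FP8 is in use.
use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
fp8_group = None
if use_fp8:
    fp8_group = get_data_parallel_group()

print(fp8_group)  # None when FP8 is off; no assertion is triggered
```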
tools/checkpoint_loader_megatron.py

```diff
@@ -43,6 +43,7 @@ def _load_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
@@ -101,7 +102,7 @@ def _load_checkpoint(queue, args):
         nonlocal consumed_valid_samples
         models = []
         for rank in range(count):
-            mpu.parallel_state.set_tensor_model_parallel_rank(rank)
+            mpu.set_tensor_model_parallel_rank(rank)
             model_ = [model_provider(pre_process, post_process).to(dtype)]
             margs.consumed_train_samples = 0
             margs.consumed_valid_samples = 0
@@ -125,8 +126,8 @@ def _load_checkpoint(queue, args):
         exit(1)

     set_global_variables(margs)
-    mpu.parallel_state.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
-    mpu.parallel_state.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
+    mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
+    mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
     fused_kernels.load(margs)

     # Get true (non-padded) vocab size
@@ -164,7 +165,7 @@ def _load_checkpoint(queue, args):
     md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

     # Get first pipe stage
-    mpu.parallel_state.set_pipeline_model_parallel_rank(0)
+    mpu.set_pipeline_model_parallel_rank(0)
     post_process = pp_size == 1
     models = get_models(tp_size, md.params_dtype, True, post_process)

@@ -190,7 +191,7 @@ def _load_checkpoint(queue, args):
     total_layer_num = 0
     for pp_rank in range(pp_size):
         if pp_rank > 0:
-            mpu.parallel_state.set_pipeline_model_parallel_rank(pp_rank)
+            mpu.set_pipeline_model_parallel_rank(pp_rank)
             post_process = pp_rank == pp_size - 1
             models = get_models(tp_size, md.params_dtype, False, post_process)
         for layer_num in range(len(models[0].language_model.encoder.layers)):
@@ -242,7 +243,6 @@ def _load_checkpoint(queue, args):
     # Send BERT lm head and binary head if it exists
     if md.model_type == 'BERT':
         print("Sending LM Pooler")
-
         message = {
             "weight": models[0].language_model.pooler.dense.weight.data,
             "bias": models[0].language_model.pooler.dense.bias.data
@@ -258,8 +258,6 @@ def _load_checkpoint(queue, args):
         queue_put("lm head", message)
         if md.bert_binary_head:
             print("Sending BERT Binary head")
-            queue.put("binary head")
-
             message = {
                 "weight": models[0].binary_head.weight.data,
                 "bias": models[0].binary_head.bias.data
```
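The five `mpu` edits in this file are all the same fix. In these tools `mpu` is already the parallel-state module (megatron/core aliases `parallel_state` as `mpu`), so the old spelling `mpu.parallel_state.set_…()` would raise `AttributeError` before any checkpoint was touched. A minimal sketch of the corrected calling pattern, assuming that alias; the values and the rank loop are illustrative, mirroring `get_models()` above:

```python
# Sketch only: assumes `mpu` is megatron.core's parallel_state alias, as the
# checkpoint tools use it.
from megatron.core import mpu

tp_size = 4   # tensor-model-parallel width of the checkpoint being read
pp_size = 2   # pipeline-model-parallel depth

# The fixed spelling: the setters live directly on mpu,
# not on a (nonexistent) mpu.parallel_state attribute.
mpu.set_tensor_model_parallel_world_size(tp_size)
mpu.set_pipeline_model_parallel_world_size(pp_size)

# Walk every (pp, tp) shard the way _load_checkpoint does: fix the pipeline
# rank, then instantiate one model per tensor-parallel rank.
for pp_rank in range(pp_size):
    mpu.set_pipeline_model_parallel_rank(pp_rank)
    for tp_rank in range(tp_size):
        mpu.set_tensor_model_parallel_rank(tp_rank)
        # ... model_provider(...) would build the shard for this rank ...
```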
tools/checkpoint_saver_megatron.py

```diff
@@ -102,6 +102,7 @@ def save_checkpoint(queue, args):
                 '--no-masked-softmax-fusion',
                 '--no-bias-gelu-fusion',
                 '--no-bias-dropout-fusion',
+                '--no-async-tensor-model-parallel-allreduce',
                 '--use-cpu-initialization',
                 '--micro-batch-size', '1',
                 '--no-load-optim',
```
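The saver gains the same flag as the loader. Both tools fabricate a `sys.argv` before invoking Megatron's argument parser, so the conversion process, not the user's command line, pins every behavior; adding `'--no-async-tensor-model-parallel-allreduce'` presumably keeps the converter off the overlapped-allreduce path, which has no place in this kind of standalone conversion run. A condensed sketch of that shared pattern (the helper name is hypothetical and the flag list abridged to what this diff shows):

```python
import sys

def install_synthetic_argv():
    # Hypothetical helper showing the shared pattern: overwrite sys.argv so
    # Megatron's parser sees only converter-controlled flags.
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                # Added by this commit in both loader and saver: stay off the
                # async tensor-model-parallel allreduce path during conversion.
                '--no-async-tensor-model-parallel-allreduce',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                ]  # remaining flags elided; see the diffs above

install_synthetic_argv()
print(sys.argv)
```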