OpenDAS / Megatron-LM

Commit 3b9dc880, authored Nov 01, 2021 by Jared Casper
Parent: 67aa8619

Update to work with latest main branch.
Showing 3 changed files with 68 additions and 58 deletions.
megatron/model/module.py              +27 -24
tools/checkpoint_loader_megatron.py   +26 -22
tools/checkpoint_saver_megatron.py    +15 -12
megatron/model/module.py
...
@@ -96,6 +96,16 @@ class MegatronModule(torch.nn.Module):
            self.word_embeddings.weight.data.fill_(0)
            self.word_embeddings.weight.shared = True

        if not torch.distributed.is_initialized():
            if not getattr(MegatronModule, "embedding_warning_printed", False):
                print("WARNING! Distributed processes aren't initialized, so "
                      "word embeddings in the last layer are not initialized. "
                      "If you are just manipulating a model this is fine, but "
                      "this needs to be handled manually. If you are training "
                      "something is definitely wrong.")
                MegatronModule.embedding_warning_printed = True
            return

        # Zero out initial weights for decoder embedding.
        # NOTE: We don't currently support T5 with the interleaved schedule.
        if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
...
@@ -105,7 +115,6 @@ class MegatronModule(torch.nn.Module):
        # Ensure that first and last stages have the same initial parameter
        # values.
        if torch.distributed.is_initialized():
            if mpu.is_rank_in_embedding_group():
                torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                             group=mpu.get_embedding_group())
...
@@ -124,12 +133,6 @@ class MegatronModule(torch.nn.Module):
                position_embeddings = self.language_model.embedding.position_embeddings
                torch.distributed.all_reduce(position_embeddings.weight.data,
                                             group=mpu.get_embedding_group())
        else:
            print("WARNING! Distributed processes aren't initialized, so "
                  "word embeddings in the last layer are not initialized. "
                  "If you are just manipulating a model this is fine, but "
                  "this needs to be handled manually. If you are training "
                  "something is definitely wrong.")


def conversion_helper(val, conversion):
...
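Not part of the commit: the snippet below is a minimal standalone sketch of the guard this hunk introduces. When torch.distributed has not been initialized, the shared last-stage word embeddings stay zero-filled and a one-time warning is printed; otherwise an all_reduce keeps the first- and last-stage copies in sync. The DummyModule class and sync_shared_embeddings method are hypothetical names used only for illustration.

    # Hypothetical sketch; only `torch` is assumed to be installed.
    import torch

    class DummyModule:
        embedding_warning_printed = False  # class-level flag, as in MegatronModule

        def __init__(self, vocab_size=8, hidden=4):
            self.word_embeddings = torch.nn.Embedding(vocab_size, hidden)
            # Zero-fill so a later all_reduce would leave every rank holding
            # the first stage's values, mirroring the hunk above.
            self.word_embeddings.weight.data.fill_(0)

        def sync_shared_embeddings(self):
            if not torch.distributed.is_initialized():
                if not getattr(DummyModule, "embedding_warning_printed", False):
                    print("WARNING! Distributed processes aren't initialized; "
                          "shared word embeddings were left zero-initialized.")
                    DummyModule.embedding_warning_printed = True
                return
            # With a real process group this sums the weights across ranks,
            # propagating the first stage's initialization to the last stage.
            torch.distributed.all_reduce(self.word_embeddings.weight.data)

    if __name__ == "__main__":
        DummyModule().sync_shared_embeddings()  # prints the warning exactly once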
tools/checkpoint_loader_megatron.py
...
@@ -23,34 +23,13 @@ def _load_checkpoint(queue, args):
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables, rebuild_tokenizer
        from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.model import ModelType
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    def get_models(count, dtype, pre_process, post_process):
        if args.model_type == 'GPT':
            from pretrain_gpt import model_provider
        elif args.model_type == 'BERT':
            from pretrain_bert import model_provider
        else:
            raise Exception(f'unrecognized model type: {args.model_type}')
        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
        #     models = [f.result().bfloat16() for f in futures]
        models = []
        for rank in range(count):
            mpu.initialize.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert(len(model_) == 1)
            models.append(model_[0])
        return models

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
...
@@ -95,6 +74,31 @@ def _load_checkpoint(queue, args):
    check_for_arg('params_dtype')

    # Determine how to make our models
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    def get_models(count, dtype, pre_process, post_process):
        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
        #     models = [f.result().bfloat16() for f in futures]
        models = []
        for rank in range(count):
            mpu.initialize.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert(len(model_) == 1)
            models.append(model_[0])
        return models

    set_args(margs)

    if margs.num_layers_per_virtual_pipeline_stage is not None:
...
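Not part of the commit: the relocated get_models above builds one model object per tensor-model-parallel rank, resets the consumed-sample counters, and loads the matching checkpoint shard into it. The sketch below mirrors that loop with hypothetical stand-ins for mpu.initialize, model_provider, and load_checkpoint so it runs without Megatron installed; the dtype cast is omitted because the stand-in models are plain namespaces.

    # Hypothetical sketch of the per-rank loading loop; all names are stand-ins.
    from types import SimpleNamespace

    def set_tensor_model_parallel_rank(rank):               # stands in for mpu.initialize
        print(f"pretending to be tensor-parallel rank {rank}")

    def model_provider(pre_process, post_process):          # stands in for pretrain_gpt/pretrain_bert
        return SimpleNamespace(pre_process=pre_process, post_process=post_process)

    def load_checkpoint(model_, optimizer, lr_scheduler):   # stands in for megatron.checkpointing
        assert len(model_) == 1                              # same invariant as the real loop

    def get_models(count, pre_process, post_process, margs):
        models = []
        for rank in range(count):
            set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process)]
            # Reset the counters before each load, mirroring the real loop.
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            models.append(model_[0])
        return models

    if __name__ == "__main__":
        print(len(get_models(2, True, True, SimpleNamespace())))  # -> 2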
tools/checkpoint_saver_megatron.py
...
@@ -30,6 +30,7 @@ def save_checkpoint(queue, args):
    try:
        from megatron.checkpointing import save_checkpoint
        from megatron.global_vars import set_global_variables, get_args
        from megatron.model import ModelType
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
...
@@ -44,18 +45,6 @@ def save_checkpoint(queue, args):
    md = queue_get()

    def get_models(count, dtype, pre_process, post_process):
        if md.model_type == 'GPT':
            from pretrain_gpt import model_provider
        elif md.model_type == 'BERT':
            from pretrain_bert import model_provider
        else:
            raise Exception(f'unrecognized model type: {md.model_type}')
        # with concurrent.futures.ThreadPoolExecutor(max_workers=count) as executor:
        #     futures = [executor.submit(model_provider, pre_process, post_process) for _ in range(count)]
        #     models = [f.result().bfloat16() for f in futures]
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

    if args.target_tensor_parallel_size is None:
        if hasattr(md, 'previous_tensor_parallel_size'):
...
@@ -114,6 +103,20 @@ def save_checkpoint(queue, args):
    # margs = megatron args
    margs = get_args()

    # Determine how to make our models
    if md.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif md.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    def get_models(count, dtype, pre_process, post_process):
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

    # fake initializing distributed
    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
...
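Not part of the commit: on the saver side, get_models only instantiates count fresh replicas and casts them to the target dtype; their parameters are then filled in from the tensors received over the queue. A minimal sketch of that call pattern, with a hypothetical model_provider standing in for pretrain_gpt/pretrain_bert:

    # Hypothetical sketch; only `torch` is assumed to be installed.
    import torch

    def model_provider(pre_process=True, post_process=True):  # stand-in
        return torch.nn.Linear(4, 4)

    def get_models(count, dtype, pre_process, post_process):
        # One fresh model per rank, cast to the requested dtype.
        return [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]

    if __name__ == "__main__":
        models = get_models(2, torch.float16, True, True)
        print([m.weight.dtype for m in models])  # -> [torch.float16, torch.float16]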