OpenDAS / Megatron-LM · Commit 82c7ba57

Initial commit for untied embeddings

Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>

Authored Dec 07, 2022 by MaximumEntropy
Committed by Jimmy Zhang, Apr 03, 2023
Parent: fdd34a82

Showing 6 changed files with 41 additions and 8 deletions (+41, -8)
megatron/arguments.py               +2   -0
megatron/core/parallel_state.py     +14  -4
megatron/initialize.py              +2   -1
megatron/model/gpt_model.py         +2   -1
megatron/model/language_model.py    +20  -1
megatron/model/module.py            +1   -1
megatron/arguments.py

@@ -530,6 +530,8 @@ def _add_network_size_args(parser):
                        dest='bert_binary_head')
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
+    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
+                       help='Untie embeddings and output weights.'),
     return parser
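For context, action='store_true' means the option defaults to tied weights and flips to untied only when the flag is passed. A minimal standalone sketch of how the flag resolves (plain argparse here, not the real Megatron parser group; the script is illustrative only):

# Sketch: how --untie-embeddings-and-output-weights resolves to a boolean.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                    help='Untie embeddings and output weights.')

print(parser.parse_args([]).untie_embeddings_and_output_weights)
# False -> embeddings and output weights stay tied (default)
print(parser.parse_args(['--untie-embeddings-and-output-weights'])
      .untie_embeddings_and_output_weights)
# True -> separate output projection is created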
megatron/core/parallel_state.py

@@ -53,6 +53,7 @@ def initialize_model_parallel(
     pipeline_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: Optional[int] = None,
     pipeline_model_parallel_split_rank: Optional[int] = None,
+    untie_embeddings_and_output_weights: bool = False,
 ) -> None:
     """
     Initialize model data parallel groups.

@@ -93,6 +94,9 @@ def initialize_model_parallel(
         pipeline_model_parallel_split_rank is 3, then ranks 0-2
         will be the encoder and ranks 3-7 will be the decoder.
+        untie_embeddings_and_output_weights: whether to use separate embedding and
+            output layers. This affects the computation of the embedding groups.
+
     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
     the model pipeline. The present function will

@@ -200,13 +204,19 @@ def initialize_model_parallel(
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
-            embedding_ranks = [ranks[0], ranks[-1]]
+            if untie_embeddings_and_output_weights:
+                embedding_ranks = [ranks[0]]
+            else:
+                embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
             if pipeline_model_parallel_split_rank is not None:
                 if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
-                    embedding_ranks = [ranks[0],
-                                       ranks[pipeline_model_parallel_split_rank],
-                                       ranks[-1]]
+                    if untie_embeddings_and_output_weights:
+                        embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
+                    else:
+                        embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]
                 if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                     position_embedding_ranks = [ranks[0],
                                                 ranks[pipeline_model_parallel_split_rank]]
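The intent of the new branches: with tied weights, the first and last pipeline stages must share an embedding group so their word-embedding gradients can be all-reduced; with untied weights the last stage no longer holds a copy of the word embeddings, so it drops out of that group (the split rank is still included for encoder-decoder models). A small pure-Python sketch of the selection logic (hypothetical helper, not code from this commit):

# Hypothetical helper mirroring the rank selection in the hunk above: given
# the ranks of one pipeline group, pick the ranks that join the embedding group.
def select_embedding_ranks(ranks, split_rank=None, untie=False):
    if len(ranks) <= 1:
        return list(ranks)
    # Untied: only the first stage owns word embeddings, so the last stage
    # no longer needs to sync embedding gradients.
    embedding_ranks = [ranks[0]] if untie else [ranks[0], ranks[-1]]
    if split_rank is not None and ranks[split_rank] not in embedding_ranks:
        embedding_ranks = ([ranks[0], ranks[split_rank]] if untie
                           else [ranks[0], ranks[split_rank], ranks[-1]])
    return embedding_ranks

# Example: a 4-stage pipeline over global ranks [0, 4, 8, 12].
print(select_embedding_ranks([0, 4, 8, 12]))              # [0, 12]  (tied)
print(select_embedding_ranks([0, 4, 8, 12], untie=True))  # [0]      (untied)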
megatron/initialize.py

@@ -185,7 +185,8 @@ def _initialize_distributed():
             mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                           args.pipeline_model_parallel_size,
                                           args.virtual_pipeline_model_parallel_size,
-                                          args.pipeline_model_parallel_split_rank)
+                                          args.pipeline_model_parallel_split_rank,
+                                          args.untie_embeddings_and_output_weights)
             if args.rank == 0:
                 print(f'> initialized tensor model parallel with size '
                       f'{mpu.get_tensor_model_parallel_world_size()}')
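Because the new argument is passed positionally here, it has to stay last in the initialize_model_parallel signature. An equivalent call using keyword arguments, so the order cannot silently drift, would look like the sketch below (assuming only the call site and parameter names shown in the hunks above):

# Sketch: same call as above, with explicit keywords for the optional arguments.
mpu.initialize_model_parallel(
    args.tensor_model_parallel_size,
    args.pipeline_model_parallel_size,
    virtual_pipeline_model_parallel_size=args.virtual_pipeline_model_parallel_size,
    pipeline_model_parallel_split_rank=args.pipeline_model_parallel_split_rank,
    untie_embeddings_and_output_weights=args.untie_embeddings_and_output_weights)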
megatron/model/gpt_model.py

@@ -57,6 +57,7 @@ class GPTModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,

@@ -90,7 +91,7 @@ class GPTModel(MegatronModule):
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
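The second hunk only changes which weight matrix is handed to post_language_model_processing: the dedicated output-layer weight when untied, otherwise the shared word-embedding weight as before. In isolation, the selection amounts to the following sketch (made-up tensor shapes; the real logits are computed inside post_language_model_processing):

# Minimal sketch of the weight selection above. Either way, the LM head is a
# matmul of the hidden states against a [vocab_size, hidden_size] matrix.
import torch

hidden_size, vocab_size, seq_len = 8, 32, 4
lm_output = torch.randn(seq_len, hidden_size)

word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)          # input embedding
output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False)    # untied head

untie_embeddings_and_output_weights = True
logits_weight = (output_layer.weight if untie_embeddings_and_output_weights
                 else word_embeddings.weight)

logits = lm_output @ logits_weight.t()
print(logits.shape)  # torch.Size([4, 32])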
megatron/model/language_model.py

@@ -326,8 +326,9 @@ class TransformerLanguageModel(MegatronModule):
                  add_pooler=False,
                  pre_process=True,
                  post_process=True):
-        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)

         self.pre_process = pre_process
         self.post_process = post_process

@@ -340,6 +341,7 @@ class TransformerLanguageModel(MegatronModule):
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         # Embeddings.
         if self.pre_process:

@@ -408,6 +410,14 @@ class TransformerLanguageModel(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

+        if self.untie_embeddings_and_output_weights:
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                args.hidden_size,
+                args.padded_vocab_size,
+                bias=False,  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                init_method=output_layer_init_method)
+            self._output_layer_key = 'output_layer'
+
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""

@@ -529,6 +539,10 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._pooler_key] \
                 = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
+        if self.untie_embeddings_and_output_weights:
+            state_dict_[self._output_layer_key] \
+                = self.output_layer.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                   keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
                 = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,

@@ -584,6 +598,11 @@ class TransformerLanguageModel(MegatronModule):
                 'could not find data for pooler in the checkpoint'
             self.pooler.load_state_dict(state_dict[self._pooler_key],
                                         strict=strict)
+        if self.untie_embeddings_and_output_weights:
+            assert 'output_layer' in state_dict, \
+                'could not find data for output_layer in the checkpoint'
+            self.output_layer.load_state_dict(state_dict[self._output_layer_key],
+                                              strict=strict)
         # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
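Taken together, these hunks add a bias-free output projection that lives alongside the embedding and is checkpointed under its own 'output_layer' key instead of reusing the word-embedding weight. On a single GPU the same idea reduces to a plain torch.nn.Linear; a minimal non-parallel sketch (tensor_parallel.ColumnParallelLinear replaced by torch.nn.Linear purely for illustration):

# Single-GPU sketch of the untied output layer: a bias-free linear head
# stored under its own checkpoint key when untying is enabled.
import torch

class TinyLM(torch.nn.Module):
    def __init__(self, hidden_size, vocab_size, untie=False):
        super().__init__()
        self.untie_embeddings_and_output_weights = untie
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        if untie:
            # Bias kept off, matching the comment in the diff: the tied path
            # has no bias either.
            self.output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def head_weight(self):
        return (self.output_layer.weight if self.untie_embeddings_and_output_weights
                else self.embedding.weight)

model = TinyLM(hidden_size=8, vocab_size=32, untie=True)
print('output_layer.weight' in model.state_dict())  # True: saved under its own key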
megatron/model/module.py

@@ -71,7 +71,7 @@ class MegatronModule(torch.nn.Module):
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
         if mpu.is_pipeline_last_stage() and \
-                not self.pre_process:
+                not self.pre_process and not self.untie_embeddings_and_output_weights:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
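The tightened condition means the last pipeline stage only allocates its duplicated word_embeddings_for_head (and only participates in the embedding-gradient sync) while the head is still tied to the input embedding. A tiny sketch of the resulting behaviour (hypothetical boolean flags mirroring the condition above, not code from this commit):

# Sketch: when does the last pipeline stage build word_embeddings_for_head?
def needs_word_embeddings_for_head(is_last_stage, pre_process, untie):
    return is_last_stage and not pre_process and not untie

print(needs_word_embeddings_for_head(True, False, untie=False))  # True: tied head
print(needs_word_embeddings_for_head(True, False, untie=True))   # False: untied head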