OpenDAS / Megatron-LM / Commits / 018391a6

Commit 018391a6 authored Apr 06, 2023 by Jared Casper
Merge branch 'untie_embeddings' into 'main'

Untie Embeddings

See merge request ADLR/megatron-lm!558

Parents: 4e891fe9 1e2a0405
Showing 5 changed files with 35 additions and 9 deletions:
megatron/arguments.py             +2   -0
megatron/model/bert_model.py      +3   -0
megatron/model/gpt_model.py       +8   -6
megatron/model/language_model.py  +21  -1
megatron/model/module.py          +1   -2
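For context: by default the GPT output head reuses the input word-embedding matrix to project hidden states back to vocabulary logits; this merge adds an option to give the head its own, independent weight matrix. A minimal generic PyTorch sketch of the two configurations (illustrative names only, not Megatron-LM code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    # Illustrative only: contrast a tied and an untied output projection.
    def __init__(self, vocab_size, hidden_size, untie_embeddings_and_output_weights=False):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.untied = untie_embeddings_and_output_weights
        if self.untied:
            # Independent projection, no bias (matching the bias=False choice in this merge).
            self.output_layer = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states):
        if self.untied:
            return self.output_layer(hidden_states)
        # Tied: reuse the embedding matrix as the output projection.
        return F.linear(hidden_states, self.embed.weight)

logits = TinyLM(vocab_size=8, hidden_size=4, untie_embeddings_and_output_weights=True)(torch.randn(2, 4))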
megatron/arguments.py

@@ -557,6 +557,8 @@ def _add_network_size_args(parser):
                        dest='bert_binary_head')
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
+    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
+                       help='Untie embeddings and output weights.'),
     return parser
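The new flag is opt-in: when it is omitted, args.untie_embeddings_and_output_weights stays False and behavior is unchanged. A standalone sketch of the parsing behavior (plain argparse, outside Megatron-LM's argument pipeline):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                    help='Untie embeddings and output weights.')

assert parser.parse_args([]).untie_embeddings_and_output_weights is False   # default: weights stay tied
assert parser.parse_args(
    ['--untie-embeddings-and-output-weights']).untie_embeddings_and_output_weights is True   # opt in to untied weights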
megatron/model/bert_model.py

@@ -132,6 +132,9 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
 
+        # TODO this option is not yet implemented in BERT
+        assert args.untie_embeddings_and_output_weights is False
+
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
megatron/model/gpt_model.py

@@ -50,13 +50,14 @@ class GPTModel(MegatronModule):
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
-        super(GPTModel, self).__init__()
         args = get_args()
+        super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
 
         self.parallel_output = parallel_output
         self.pre_process = pre_process
         self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
 
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -68,6 +69,7 @@ class GPTModel(MegatronModule):
             pre_process=self.pre_process,
             post_process=self.post_process)
 
-        self.initialize_word_embeddings(init_method_normal)
+        if not args.untie_embeddings_and_output_weights:
+            self.initialize_word_embeddings(init_method_normal)
 
     def set_input_tensor(self, input_tensor):
@@ -90,7 +92,7 @@ class GPTModel(MegatronModule):
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
@@ -103,7 +105,7 @@ class GPTModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
@@ -113,7 +115,7 @@ class GPTModel(MegatronModule):
         """Customized load."""
 
         # Load word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
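Taken together with the language_model.py change below, these checkpoint hunks mean the duplicated word_embeddings_for_head entry is only written and read while the weights are still tied; when untied, the head weight travels inside the language model's own state dict instead. A rough sketch of the resulting last-stage layout (keys simplified and partly assumed, not the exact Megatron-LM checkpoint format):

# Hypothetical, simplified view of what a post_process-only pipeline stage saves.
tied_last_stage_state = {
    'language_model': {'encoder': '...'},
    'word_embeddings_for_head': {'weight': '...'},   # duplicate of the stage-0 embedding, kept in sync
}
untied_last_stage_state = {
    'language_model': {
        'encoder': '...',
        'output_layer': {'weight': '...'},           # independent head saved under its own key
    },
    # no word_embeddings_for_head: nothing is tied, so there is no duplicate to synchronize
}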
megatron/model/language_model.py

@@ -336,8 +336,10 @@ class TransformerLanguageModel(MegatronModule):
                  add_pooler=False,
                  pre_process=True,
                  post_process=True):
-        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        if args.untie_embeddings_and_output_weights: assert not add_decoder
+        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
 
         self.pre_process = pre_process
         self.post_process = post_process
@@ -350,6 +352,7 @@ class TransformerLanguageModel(MegatronModule):
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
 
         # Embeddings.
         if self.pre_process:
@@ -434,6 +437,14 @@ class TransformerLanguageModel(MegatronModule):
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'
 
+            if self.untie_embeddings_and_output_weights:
+                self.output_layer = tensor_parallel.ColumnParallelLinear(
+                    args.hidden_size,
+                    args.padded_vocab_size,
+                    bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                    init_method=self.init_method)
+                self._output_layer_key = 'output_layer'
+
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
@@ -566,6 +577,10 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._pooler_key] \
                 = self.pooler.state_dict_for_save_checkpoint(prefix=prefix, keep_vars=keep_vars)
+        if self.untie_embeddings_and_output_weights:
+            state_dict_[self._output_layer_key] \
+                = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)
+
         if self.add_decoder:
             state_dict_[self._decoder_key] \
                 = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
@@ -621,6 +636,11 @@ class TransformerLanguageModel(MegatronModule):
                 'could not find data for pooler in the checkpoint'
             self.pooler.load_state_dict(state_dict[self._pooler_key], strict=strict)
+        if self.untie_embeddings_and_output_weights:
+            assert 'output_layer' in state_dict, \
+                'could not find data for output_layer in the checkpoint'
+            self.output_layer.load_state_dict(state_dict[self._output_layer_key], strict=strict)
+
         # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
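The new ColumnParallelLinear is a full hidden_size x padded_vocab_size matrix, so untying is not free in parameters or memory. A back-of-the-envelope estimate with illustrative GPT-scale sizes (values assumed, not taken from this commit):

hidden_size = 4096
padded_vocab_size = 50304   # vocabulary padded for tensor-model parallelism

extra_params = hidden_size * padded_vocab_size
print(f"untied output layer adds ~{extra_params / 1e6:.0f}M parameters")   # ~206M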
megatron/model/module.py

@@ -70,8 +70,7 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
-        if mpu.is_pipeline_last_stage() and \
-                not self.pre_process:
+        if mpu.is_pipeline_last_stage() and not self.pre_process:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
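The simplified branch is part of the tied-embedding bookkeeping for pipeline parallelism: the last stage holds its own copy of the word embeddings, and the gradients of the two copies are all-reduced so both stages apply identical updates, as the surrounding comment describes. A stripped-down sketch of that idea using plain torch.distributed (not the Megatron-LM implementation; names are illustrative):

import torch.distributed as dist

def sync_tied_embedding_grads(word_embeddings, embedding_group):
    # The first and last pipeline stages each hold a copy of the tied embedding;
    # summing their gradients before the optimizer step keeps the two copies identical.
    grad = word_embeddings.weight.grad
    if grad is not None:
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=embedding_group)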