OpenDAS / Megatron-LM

Commit 7be2648a authored Jan 25, 2021 by Jared Casper

Clarify module.initialize_word_embeddings.

parent c4c68dce
Showing 1 changed file with 15 additions and 10 deletions.

megatron/model/module.py  +15  -10
@@ -60,8 +60,13 @@ class MegatronModule(torch.nn.Module):
         if not self.share_word_embeddings:
             raise Exception('initialize_word_embeddings() was called but '
                             'share_word_embeddings is false')
+
+        # This function just initializes the word embeddings in the final stage
+        # when we are using pipeline parallelism. If we aren't using pipeline
+        # parallelism there is nothing to do.
         if args.pipeline_model_parallel_size == 1:
             return
+
         # Parameters are shared between the word embeddings layer, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
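The new comment in this hunk spells out why the single-stage case can return early: without pipeline parallelism the input embedding and the output head live in the same process, so the head can read the embedding's weight tensor directly and no separate copy or synchronization is needed. A minimal illustration of that tied-weight pattern in plain PyTorch (not Megatron-LM code; the names embedding and lm_head and the sizes are made up for the example):

import torch

# Hypothetical single-stage setup: one process owns both the embedding and the
# output projection, so "sharing" is just reusing the same parameter object.
embedding = torch.nn.Embedding(num_embeddings=50304, embedding_dim=1024)

def lm_head(hidden_states):
    # logits = hidden_states @ E^T, computed with the embedding's own weight;
    # any optimizer update to the embedding is automatically seen by the head.
    return torch.nn.functional.linear(hidden_states, embedding.weight)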
@@ -75,16 +80,16 @@ class MegatronModule(torch.nn.Module):
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
         if mpu.is_pipeline_last_stage():
-            if not mpu.is_pipeline_first_stage():
-                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
-                # If first and last stages are different, set word_embeddings
-                # weights to 0 here, then copy first stage's weights using
-                # all_reduce below.
-                self.word_embeddings = mpu.VocabParallelEmbedding(
-                    args.padded_vocab_size, args.hidden_size,
-                    init_method=init_method_normal(args.init_method_std))
-                self.word_embeddings.weight.data.fill_(0)
-                self.word_embeddings.weight.shared = True
+            assert not mpu.is_pipeline_first_stage()
+            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
+            # set word_embeddings weights to 0 here, then copy first
+            # stage's weights using all_reduce below.
+            self.word_embeddings = mpu.VocabParallelEmbedding(
+                args.padded_vocab_size, args.hidden_size,
+                init_method=init_method_normal(args.init_method_std))
+            self.word_embeddings.weight.data.fill_(0)
+            self.word_embeddings.weight.shared = True
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
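Taken together, the comments describe a three-step scheme for keeping the two copies of the word embeddings tied across pipeline stages: the last stage allocates its copy filled with zeros, an all-reduce between the first and last stage then gives both copies the first stage's initial values (0 + W = W), and during training the gradients of the two copies are all-reduced so every applied weight update is identical. A minimal sketch of those synchronization steps with torch.distributed (illustrative only, not the Megatron-LM implementation; the embedding_group argument stands in for the group over the first and last pipeline ranks, which the real code obtains via mpu.get_embedding_group()):

import torch
import torch.distributed as dist

def sync_initial_embeddings(word_embeddings, embedding_group):
    # Step 2: the last stage's copy is all zeros and the first stage holds the
    # randomly initialized weights, so a SUM all-reduce leaves both stages with
    # the first stage's values.
    dist.all_reduce(word_embeddings.weight.data, group=embedding_group)

def sync_embedding_grads(word_embeddings, embedding_group):
    # Step 3: run once per training step, before the optimizer update, so both
    # copies apply the same gradient and stay tied throughout training.
    dist.all_reduce(word_embeddings.weight.grad, group=embedding_group)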