Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
f11b4c99
Commit
f11b4c99
authored
Jan 31, 2023
by
Jimmy Zhang
Browse files
disable embedding addreduce if untie_embeddings_and_output_weights
parent
a3fbac58
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
5 deletions
+4
-5
megatron/arguments.py
megatron/arguments.py
+0
-2
megatron/model/gpt_model.py
megatron/model/gpt_model.py
+4
-3
No files found.
megatron/arguments.py
View file @
f11b4c99
...
...
@@ -349,7 +349,6 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1"
)
# Load retro args.
if
args
.
retro_workdir
:
retro_args_path
=
get_retro_args_path
(
args
.
retro_workdir
)
...
...
@@ -368,7 +367,6 @@ def validate_args(args, defaults={}):
if
retro_args
and
args
!=
retro_args
:
_print_args
(
"retro arguments"
,
types
.
SimpleNamespace
(
**
{
k
:
v
for
k
,
v
in
vars
(
retro_args
).
items
()
if
k
.
startswith
(
"retro"
)},
rank
=
args
.
rank
))
return
args
...
...
megatron/model/gpt_model.py
View file @
f11b4c99
...
...
@@ -50,8 +50,8 @@ class GPTModel(MegatronModule):
parallel_output
=
True
,
pre_process
=
True
,
post_process
=
True
):
super
(
GPTModel
,
self
).
__init__
()
args
=
get_args
()
super
(
GPTModel
,
self
).
__init__
(
share_word_embeddings
=
not
args
.
untie_embeddings_and_output_weights
)
self
.
parallel_output
=
parallel_output
self
.
pre_process
=
pre_process
...
...
@@ -68,8 +68,9 @@ class GPTModel(MegatronModule):
args
.
num_layers
),
pre_process
=
self
.
pre_process
,
post_process
=
self
.
post_process
)
self
.
initialize_word_embeddings
(
init_method_normal
)
if
not
args
.
untie_embeddings_and_output_weights
:
self
.
initialize_word_embeddings
(
init_method_normal
)
def
set_input_tensor
(
self
,
input_tensor
):
"""See megatron.model.transformer.set_input_tensor()"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment