Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
f772fbc9
Commit
f772fbc9
authored
Jan 05, 2021
by
Jared Casper
Browse files
Only create task heads on last pipeline stage.
parent
6fa36844
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
31 deletions
+37
-31
megatron/model/classification.py
megatron/model/classification.py
+19
-16
megatron/model/multiple_choice.py
megatron/model/multiple_choice.py
+18
-15
No files found.
megatron/model/classification.py
View file @
f772fbc9
...
...
@@ -17,7 +17,7 @@
import
torch
from
megatron
import
get_args
,
print_rank_
0
from
megatron
import
get_args
,
print_rank_
last
from
megatron
import
mpu
from
megatron.model.bert_model
import
bert_attention_mask_func
,
bert_extended_attention_mask
,
bert_position_ids
from
megatron.model.language_model
import
get_language_model
...
...
@@ -45,6 +45,7 @@ class ClassificationBase(PipelinedMegatronModule):
args
.
num_layers
))
# Multi-choice head.
if
mpu
.
is_pipeline_last_stage
():
self
.
classification_dropout
=
torch
.
nn
.
Dropout
(
args
.
hidden_dropout
)
self
.
classification_head
=
get_linear_layer
(
args
.
hidden_size
,
self
.
num_classes
,
...
...
@@ -85,6 +86,7 @@ class ClassificationBase(PipelinedMegatronModule):
state_dict_
[
self
.
_language_model_key
]
\
=
self
.
language_model
.
state_dict_for_save_checkpoint
(
destination
,
prefix
,
keep_vars
)
if
mpu
.
is_pipeline_last_stage
():
state_dict_
[
self
.
_classification_head_key
]
\
=
self
.
classification_head
.
state_dict
(
destination
,
prefix
,
keep_vars
)
...
...
@@ -95,11 +97,12 @@ class ClassificationBase(PipelinedMegatronModule):
self
.
language_model
.
load_state_dict
(
state_dict
[
self
.
_language_model_key
],
strict
=
strict
)
if
mpu
.
is_pipeline_last_stage
():
if
self
.
_classification_head_key
in
state_dict
:
self
.
classification_head
.
load_state_dict
(
state_dict
[
self
.
_classification_head_key
],
strict
=
strict
)
else
:
print_rank_
0
(
'***WARNING*** could not find {} in the checkpoint, '
print_rank_
last
(
'***WARNING*** could not find {} in the checkpoint, '
'initializing to random'
.
format
(
self
.
_classification_head_key
))
...
...
megatron/model/multiple_choice.py
View file @
f772fbc9
...
...
@@ -17,7 +17,7 @@
import
torch
from
megatron
import
get_args
,
print_rank_
0
from
megatron
import
get_args
,
print_rank_
last
from
megatron
import
mpu
from
megatron.model.bert_model
import
bert_attention_mask_func
,
bert_extended_attention_mask
,
bert_position_ids
from
megatron.model.language_model
import
get_language_model
...
...
@@ -44,6 +44,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
args
.
num_layers
))
# Multi-choice head.
if
mpu
.
is_pipeline_last_stage
():
self
.
multichoice_dropout
=
torch
.
nn
.
Dropout
(
args
.
hidden_dropout
)
self
.
multichoice_head
=
get_linear_layer
(
args
.
hidden_size
,
1
,
init_method
)
...
...
@@ -97,6 +98,7 @@ class MultipleChoiceBase(PipelinedMegatronModule):
state_dict_
[
self
.
_language_model_key
]
\
=
self
.
language_model
.
state_dict_for_save_checkpoint
(
destination
,
prefix
,
keep_vars
)
if
mpu
.
is_pipeline_last_stage
():
state_dict_
[
self
.
_multichoice_head_key
]
\
=
self
.
multichoice_head
.
state_dict
(
destination
,
prefix
,
keep_vars
)
...
...
@@ -107,11 +109,12 @@ class MultipleChoiceBase(PipelinedMegatronModule):
self
.
language_model
.
load_state_dict
(
state_dict
[
self
.
_language_model_key
],
strict
=
strict
)
if
mpu
.
is_pipeline_last_stage
():
if
self
.
_multichoice_head_key
in
state_dict
:
self
.
multichoice_head
.
load_state_dict
(
state_dict
[
self
.
_multichoice_head_key
],
strict
=
strict
)
else
:
print_rank_
0
(
'***WARNING*** could not find {} in the checkpoint, '
print_rank_
last
(
'***WARNING*** could not find {} in the checkpoint, '
'initializing to random'
.
format
(
self
.
_multichoice_head_key
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment