OpenDAS / Megatron-LM / Commits

Commit 4b3519cb, authored Jan 12, 2021 by Vijay Korthikanti

    address review comments

parent 834d6dd5
Showing 1 changed file with 46 additions and 57 deletions.

megatron/model/language_model.py  (+46 −57)
The first hunk moves decoder construction out of the pipeline-last-stage branch and instead asserts that pipeline parallelism is disabled whenever a decoder is present; only the pooler remains gated on the last stage:

```diff
@@ -316,19 +316,19 @@ class TransformerLanguageModelBase(MegatronModule):
             self_attn_mask_type=self_attn_mask_type)
         self._encoder_key = 'encoder'
 
-        # assuming pooler and decoder are in the last stage
-        # of the pipeline(to be revised)
-        if mpu.is_pipeline_last_stage():
-            # decoder
-            if self.add_decoder:
-                self.decoder = ParallelTransformer(
-                    attention_mask_func,
-                    self.init_method,
-                    output_layer_init_method,
-                    layer_type=LayerType.decoder,
-                    self_attn_mask_type=AttnMaskType.causal)
-                self._decoder_key = 'decoder'
+        # Decoder
+        if self.add_decoder:
+            assert args.pipeline_model_parallel_size == 1, \
+                'pipeline parallelism is not supported in the presence of decoder'
+            self.decoder = ParallelTransformer(
+                attention_mask_func,
+                self.init_method,
+                output_layer_init_method,
+                layer_type=LayerType.decoder,
+                self_attn_mask_type=AttnMaskType.causal)
+            self._decoder_key = 'decoder'
 
+        if mpu.is_pipeline_last_stage():
             # Pooler.
             if self.add_pooler:
                 self.pooler = Pooler(self.hidden_size, self.init_method)
```
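The new assert makes the constraint explicit: with a decoder present, `pipeline_model_parallel_size` must be 1, in which case the single stage is trivially also the last stage, so the pooler is still built alongside the decoder. A minimal sketch of the predicate being relied on, under assumed arguments (Megatron's real `mpu` module derives the pipeline rank from its process groups rather than taking it as a parameter):

```python
# Assumed, simplified stand-in for mpu.is_pipeline_last_stage(); the real
# implementation reads the pipeline rank from torch.distributed process
# groups instead of taking it as an argument.
def is_pipeline_last_stage(pipeline_rank, pipeline_model_parallel_size):
    return pipeline_rank == pipeline_model_parallel_size - 1

# With pipeline_model_parallel_size == 1 (what the new assert requires when
# add_decoder is True), the single stage, rank 0, is also the last stage,
# so the decoder and the pooler end up on the same rank.
assert is_pipeline_last_stage(0, 1)
```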
The second hunk mirrors this in `forward()`: `pooled_output` joins the returned tuple only where the pooler actually exists, i.e. on the last pipeline stage:

```diff
@@ -363,33 +363,31 @@ class TransformerLanguageModelBase(MegatronModule):
             pooled_output = self.pooler(encoder_output,
                                         pooling_sequence_index)
 
         # output_enc_hidden refers to when we just need the encoder's
         # output. For example, it is helpful to compute
         # similarity between two sequences by average pooling
         if not self.add_decoder or output_enc_hidden:
-            if self.add_pooler:
+            if self.add_pooler and mpu.is_pipeline_last_stage():
                 return encoder_output, pooled_output
             else:
                 return encoder_output
 
         # Decoder Embedding
         (dec_input_ids,
          dec_position_ids) = dec_language_model_input
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
 
         # decoder
         decoder_output = self.decoder(dec_embedding_output,
                                       dec_attn_mask,
                                       layer_past=layer_past,
                                       get_key_value=get_key_value,
                                       encoder_output=encoder_output,
                                       enc_dec_attn_mask=enc_dec_attn_mask)
 
-        if self.add_pooler:
+        if self.add_pooler and mpu.is_pipeline_last_stage():
             return decoder_output, encoder_output, pooled_output
         else:
             return decoder_output, encoder_output
 
     def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                        keep_vars=False):
```
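The net effect is a four-way return contract depending on the flags. A hedged sketch that enumerates it (the helper below is illustrative only, not part of the diff or of Megatron-LM):

```python
# Illustrative only (not Megatron-LM code): enumerates the tuples forward()
# can return after this change, as a function of the flags in the diff.
def expected_return_fields(add_decoder, add_pooler,
                           output_enc_hidden, is_last_stage):
    if not add_decoder or output_enc_hidden:
        if add_pooler and is_last_stage:
            return ('encoder_output', 'pooled_output')
        return ('encoder_output',)
    if add_pooler and is_last_stage:
        return ('decoder_output', 'encoder_output', 'pooled_output')
    return ('decoder_output', 'encoder_output')

# On a non-last stage, the pooled output is never part of the result:
assert expected_return_fields(True, True, False, False) == \
    ('decoder_output', 'encoder_output')
```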
In checkpoint loading, the decoder branch moves out of the pipeline-last-stage block to match the constructor. The extracted text is identical on both sides of the hunk and the line counts match, so the change appears to be indentation only:

```diff
@@ -462,12 +460,12 @@ class TransformerLanguageModelBase(MegatronModule):
                 'could not find data for pooler in the checkpoint'
             self.pooler.load_state_dict(state_dict[self._pooler_key],
                                         strict=strict)
 
-            # decoder
-            if self.add_decoder:
-                assert 'decoder' in state_dict, \
-                    'could not find data for pooler in the checkpoint'
-                self.decoder.load_state_dict(state_dict[self._decoder_key],
-                                             strict=strict)
+        # decoder
+        if self.add_decoder:
+            assert 'decoder' in state_dict, \
+                'could not find data for pooler in the checkpoint'
+            self.decoder.load_state_dict(state_dict[self._decoder_key],
+                                         strict=strict)
```
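Note that both sides of this hunk test the literal key 'decoder' but report "could not find data for pooler", an apparent copy-paste from the pooler branch that the commit carries over unchanged. A hedged sketch of a tightened variant, written as a standalone helper for illustration (this is not what the commit does; `decoder_key` here mirrors the `self._decoder_key` attribute in the diff):

```python
from torch import nn

# Sketch only, not part of the commit: the same decoder branch with the
# stored key reused in the membership test and an error message that names
# the right module.
def load_decoder(decoder: nn.Module, state_dict: dict,
                 decoder_key: str = 'decoder', strict: bool = True) -> None:
    assert decoder_key in state_dict, \
        'could not find data for decoder in the checkpoint'
    decoder.load_state_dict(state_dict[decoder_key], strict=strict)
```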
Unchanged context between the hunks: `class TransformerLanguageModel(TransformerLanguageModelBase):`.
Finally, `TransformerLanguageModelLastStage.forward()` drops its decoder-related parameters, since the decoder no longer lives only on the last stage; the constructor above it is unchanged context:

```diff
@@ -577,30 +575,21 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  init_method,
                  output_layer_init_method,
                  self_attn_mask_type=AttnMaskType.padding,
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
             self_attn_mask_type=AttnMaskType.padding,
             add_decoder=add_decoder,
             add_pooler=add_pooler)
 
-    def forward(self, hidden_states, enc_attention_mask,
-                dec_input_ids=None, dec_position_ids=None,
-                dec_attn_mask=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False,
-                pooling_sequence_index=0, enc_hidden_states=None,
-                output_enc_hidden=False):
+    def forward(self, hidden_states, attention_mask,
+                layer_past=None, get_key_value=False,
+                pooling_sequence_index=0):
         return super(TransformerLanguageModelLastStage, self).forward(
             hidden_states,
-            enc_attention_mask,
-            dec_language_model_input=(dec_input_ids, dec_position_ids),
-            dec_attn_mask=dec_attn_mask,
-            enc_dec_attn_mask=enc_dec_attn_mask,
+            attention_mask,
             layer_past=layer_past,
             get_key_value=get_key_value,
-            pooling_sequence_index=pooling_sequence_index,
-            enc_hidden_states=enc_hidden_states,
-            ouput_enc_hidden=output_enc_hidden)
+            pooling_sequence_index=pooling_sequence_index)
```
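The removed call also misspelled its last keyword (`ouput_enc_hidden` instead of the base class's `output_enc_hidden`, which hunk two shows is the real parameter name), so any caller exercising that path would presumably have hit a TypeError. A self-contained illustration of that failure mode (hypothetical function, not Megatron code):

```python
# Hypothetical minimal reproduction (not Megatron code) of why the removed
# 'ouput_enc_hidden=...' keyword was a latent bug: Python rejects unknown
# keyword arguments at call time with a TypeError.
def forward(hidden_states=None, output_enc_hidden=False):
    return output_enc_hidden

try:
    forward(ouput_enc_hidden=True)   # note the misspelling
except TypeError as exc:
    print(exc)  # forward() got an unexpected keyword argument 'ouput_enc_hidden'
```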