Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
6aaafee6
You need to sign in or sign up before continuing.
Commit
6aaafee6
authored
May 20, 2022
by
Vijay Korthikanti
Browse files
bert regression fixes
parent
5d647381
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
4 deletions
+23
-4
megatron/model/bert_model.py
megatron/model/bert_model.py
+7
-2
megatron/model/language_model.py
megatron/model/language_model.py
+16
-2
No files found.
megatron/model/bert_model.py
View file @
6aaafee6
...
...
@@ -78,7 +78,12 @@ class BertLMHead(MegatronModule):
self
.
parallel_output
=
parallel_output
self
.
dense
=
get_linear_layer
(
hidden_size
,
hidden_size
,
init_method
)
self
.
layernorm
=
LayerNorm
(
hidden_size
,
eps
=
layernorm_epsilon
)
setattr
(
self
.
dense
.
weight
,
'sequence_parallel'
,
args
.
sequence_parallel
)
setattr
(
self
.
dense
.
bias
,
'sequence_parallel'
,
args
.
sequence_parallel
)
self
.
layernorm
=
LayerNorm
(
hidden_size
,
eps
=
layernorm_epsilon
,
sequence_parallel
=
args
.
sequence_parallel
)
self
.
gelu
=
torch
.
nn
.
functional
.
gelu
if
args
.
openai_gelu
:
self
.
gelu
=
openai_gelu
...
...
@@ -106,7 +111,7 @@ def post_language_model_processing(lm_output, pooled_output,
lm_output
,
logit_weights
)
binary_logits
=
None
if
binary_head
is
not
None
:
if
binary_head
is
not
None
and
pooled_output
is
not
None
:
binary_logits
=
binary_head
(
pooled_output
)
if
lm_labels
is
None
:
...
...
megatron/model/language_model.py
View file @
6aaafee6
...
...
@@ -412,6 +412,7 @@ class TransformerLanguageModel(MegatronModule):
pooling_sequence_index
=
0
,
enc_hidden_states
=
None
,
output_enc_hidden
=
False
):
args
=
get_args
()
# Encoder embedding.
if
self
.
pre_process
:
encoder_input
=
self
.
embedding
(
enc_input_ids
,
enc_position_ids
,
...
...
@@ -433,8 +434,21 @@ class TransformerLanguageModel(MegatronModule):
if
self
.
post_process
:
if
self
.
add_pooler
:
pooled_output
=
self
.
pooler
(
encoder_output
,
pooling_sequence_index
)
if
args
.
sequence_parallel
:
# encoder output is split along sequence dimension
# consider appropriate rank based on pooling sequence index
# binary head loss is only computed in just one rank.
seq_denom
=
args
.
seq_length
//
args
.
tensor_model_parallel_size
seq_rank
=
mpu
.
get_tensor_model_parallel_rank
()
if
pooling_sequence_index
//
seq_denom
==
seq_rank
:
pooled_output
=
self
.
pooler
(
encoder_output
,
pooling_sequence_index
%
seq_denom
)
else
:
pooled_output
=
None
else
:
pooled_output
=
self
.
pooler
(
encoder_output
,
pooling_sequence_index
)
# output_enc_hidden refers to when we just need the encoder's
# output. For example, it is helpful to compute
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment