Commit 6aaafee6 authored by Vijay Korthikanti

bert regression fixes

parent 5d647381
@@ -78,7 +78,12 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output
 
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
+        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
+
+        self.layernorm = LayerNorm(hidden_size,
+                                   eps=layernorm_epsilon,
+                                   sequence_parallel=args.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
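For context on the `sequence_parallel` attribute: when Megatron-LM runs with sequence parallelism, layers such as this dense head and LayerNorm see only a slice of the sequence on each tensor-parallel rank, so each rank holds a partial gradient for these parameters that must be summed across the tensor-model-parallel group before the optimizer step. A minimal sketch of that reduction, assuming a hypothetical `params` list and an already-initialized `tp_group` process group (not the exact Megatron-LM code path):

import torch
import torch.distributed as dist

def allreduce_sequence_parallel_grads(params, tp_group):
    # Hypothetical helper: sum the gradients of parameters tagged
    # sequence_parallel (e.g. the dense weight/bias and LayerNorm above)
    # across the tensor-model-parallel group, mirroring the reduction
    # Megatron-LM performs before the optimizer step.
    for p in params:
        if getattr(p, 'sequence_parallel', False) and p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=tp_group)

Tagging the parameters here is what lets that reduction find them later by attribute rather than by name.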
@@ -106,7 +111,7 @@ def post_language_model_processing(lm_output, pooled_output,
         lm_output, logit_weights)
 
     binary_logits = None
-    if binary_head is not None:
+    if binary_head is not None and pooled_output is not None:
         binary_logits = binary_head(pooled_output)
 
     if lm_labels is None:
...
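Why the stricter guard: with sequence parallelism enabled, the pooler (see the TransformerLanguageModel hunk below) now runs only on the tensor-parallel rank that owns the pooling index, so `pooled_output` arrives as `None` everywhere else and calling the binary head on it would raise. A minimal single-process sketch of the guard's effect, with made-up layer sizes:

import torch
import torch.nn as nn

binary_head = nn.Linear(8, 2)  # stand-in for the BERT binary head

# The owning rank gets a real tensor; every other rank now passes None.
for pooled_output in (torch.randn(1, 8), None):
    binary_logits = None
    if binary_head is not None and pooled_output is not None:
        binary_logits = binary_head(pooled_output)
    print(type(binary_logits).__name__)  # Tensor, then NoneType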
@@ -412,6 +412,7 @@ class TransformerLanguageModel(MegatronModule):
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
+        args = get_args()
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -433,8 +434,21 @@ class TransformerLanguageModel(MegatronModule):
 
         if self.post_process:
             if self.add_pooler:
-                pooled_output = self.pooler(encoder_output,
-                                            pooling_sequence_index)
+                if args.sequence_parallel:
+                    # The encoder output is split along the sequence dimension,
+                    # so pick the rank that holds the pooling sequence index;
+                    # the binary head loss is computed on only that one rank.
+                    seq_denom = args.seq_length // args.tensor_model_parallel_size
+                    seq_rank = mpu.get_tensor_model_parallel_rank()
+                    if pooling_sequence_index // seq_denom == seq_rank:
+                        pooled_output = self.pooler(
+                            encoder_output,
+                            pooling_sequence_index % seq_denom)
+                    else:
+                        pooled_output = None
+                else:
+                    pooled_output = self.pooler(encoder_output,
+                                                pooling_sequence_index)
 
         # output_enc_hidden refers to when we just need the encoder's
         # output. For example, it is helpful to compute
...
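To make the rank arithmetic concrete: with assumed values `seq_length=512` and `tensor_model_parallel_size=4`, each rank holds a contiguous 128-token slice, so pooling index 0 (the usual [CLS] position) falls on rank 0 at local offset 0, while index 300 would fall on rank 300 // 128 = 2 at offset 44. A standalone sketch of the same arithmetic (the helper name is hypothetical):

def owning_rank_and_offset(pooling_sequence_index, seq_length, tp_size):
    # Mirrors the diff: the sequence dimension is split into tp_size
    # contiguous slices of seq_length // tp_size tokens each.
    seq_denom = seq_length // tp_size
    return (pooling_sequence_index // seq_denom,
            pooling_sequence_index % seq_denom)

print(owning_rank_and_offset(0, 512, 4))    # (0, 0)
print(owning_rank_and_offset(300, 512, 4))  # (2, 44)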