Commit 6aaafee6 authored by Vijay Korthikanti's avatar Vijay Korthikanti

bert regression fixes

parent 5d647381
@@ -78,7 +78,12 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
+        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
+        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
+
+        self.layernorm = LayerNorm(hidden_size,
+                                   eps=layernorm_epsilon,
+                                   sequence_parallel=args.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
             self.gelu = openai_gelu
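
The setattr calls and the new LayerNorm argument tag these parameters as sequence-parallel: with sequence parallelism each tensor-parallel rank sees only a slice of the sequence, so the gradients of these shared parameters have to be summed across the tensor-model-parallel group before the optimizer step. A minimal sketch of how such a tag is typically consumed, assuming plain torch.distributed and a hypothetical tp_group / helper name (not part of this commit):

import torch.distributed as dist

def allreduce_sequence_parallel_grads(model, tp_group):
    # Hypothetical helper: sum the gradients of every parameter tagged with
    # `sequence_parallel`, because each tensor-parallel rank computed them
    # from only its own shard of the sequence.
    for param in model.parameters():
        if getattr(param, 'sequence_parallel', False) and param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM, group=tp_group)
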
@@ -106,7 +111,7 @@ def post_language_model_processing(lm_output, pooled_output,
                                        lm_output, logit_weights)

     binary_logits = None
-    if binary_head is not None:
+    if binary_head is not None and pooled_output is not None:
         binary_logits = binary_head(pooled_output)

     if lm_labels is None:
@@ -412,6 +412,7 @@ class TransformerLanguageModel(MegatronModule):
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
+        args = get_args()

         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -433,8 +434,21 @@ class TransformerLanguageModel(MegatronModule):

         if self.post_process:
             if self.add_pooler:
-                pooled_output = self.pooler(encoder_output,
-                                            pooling_sequence_index)
+                if args.sequence_parallel:
+                    # The encoder output is split along the sequence dimension,
+                    # so pick the rank that holds the pooling sequence index;
+                    # the binary head loss is computed on only that rank.
+                    seq_denom = args.seq_length // args.tensor_model_parallel_size
+                    seq_rank = mpu.get_tensor_model_parallel_rank()
+                    if pooling_sequence_index // seq_denom == seq_rank:
+                        pooled_output = self.pooler(
+                            encoder_output,
+                            pooling_sequence_index % seq_denom)
+                    else:
+                        pooled_output = None
+                else:
+                    pooled_output = self.pooler(encoder_output,
+                                                pooling_sequence_index)

         # output_enc_hidden refers to when we just need the encoder's
         # output. For example, it is helpful to compute
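
Because sequence parallelism shards the encoder output into seq_length // tensor_model_parallel_size tokens per rank, the global pooling_sequence_index has to be mapped to the rank that owns it plus a local offset within that rank's shard; every other rank leaves pooled_output as None, which is why the binary-head guard above was added. A small standalone sketch of that mapping, with sizes made up for illustration:

def locate_pooling_index(pooling_sequence_index, seq_length, tp_size):
    # Map a global sequence index to (owning tensor-parallel rank, local index
    # within that rank's sequence shard), mirroring the arithmetic in the diff.
    sub_seq_length = seq_length // tp_size
    return (pooling_sequence_index // sub_seq_length,
            pooling_sequence_index % sub_seq_length)

# Example: seq_length=512, tp_size=4 -> 128 tokens per rank.
print(locate_pooling_index(0, 512, 4))    # (0, 0)   e.g. BERT's [CLS] position
print(locate_pooling_index(300, 512, 4))  # (2, 44)
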