Commit 3f91f09b authored by Vijay Korthikanti

avoiding sequence parallelism on the pooler

parent 6aaafee6
@@ -81,7 +81,7 @@ class BertLMHead(MegatronModule):
         setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
         setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
 
-        self.layernorm = LayerNorm(hidden_size,
+        self.layernorm = LayerNorm(hidden_size,
                                    eps=layernorm_epsilon,
                                    sequence_parallel=args.sequence_parallel)
         self.gelu = torch.nn.functional.gelu
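
Note on the sequence_parallel attribute set in the context above: as I understand Megatron's sequence parallelism, it tags parameters that are replicated on every tensor-parallel rank but only see their rank's slice of the sequence, so their gradients have to be summed across the tensor-parallel group. The single-process sketch below imitates that reduction with explicit replicas; the sizes and the names replicas and shards are invented for illustration.

import copy
import torch

torch.manual_seed(0)
seq_len, batch, hidden, tp_size = 8, 2, 4, 2

# The layer is replicated on every "rank"...
layer = torch.nn.LayerNorm(hidden)
replicas = [copy.deepcopy(layer) for _ in range(tp_size)]

# ...but with sequence parallelism each rank only sees its sequence shard.
full = torch.randn(seq_len, batch, hidden)
shards = torch.chunk(full, tp_size, dim=0)
for replica, shard in zip(replicas, shards):
    replica(shard).sum().backward()

# Summing the per-rank gradients (the all-reduce applied to the tagged
# parameters) recovers the gradient of the unsharded computation.
summed = sum(replica.weight.grad for replica in replicas)
reference = copy.deepcopy(layer)
reference(full).sum().backward()
assert torch.allclose(summed, reference.weight.grad, atol=1e-5)
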
@@ -111,7 +111,7 @@ def post_language_model_processing(lm_output, pooled_output,
         lm_output, logit_weights)
 
     binary_logits = None
-    if binary_head is not None and pooled_output is not None:
+    if binary_head is not None:
         binary_logits = binary_head(pooled_output)
 
     if lm_labels is None:
......
@@ -104,11 +104,20 @@ class Pooler(MegatronModule):
     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
+        args = get_args()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        self.sequence_parallel = args.sequence_parallel
 
     def forward(self, hidden_states, sequence_index=0):
         # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
+
+        # gather data along sequence dimensions
+        # same pooler is run on all tensor parallel nodes
+        if self.sequence_parallel:
+            hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)
+
         pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
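
A minimal, single-process sketch of what the gather added to Pooler.forward accomplishes: under sequence parallelism the [s, b, h] activation is sharded along the sequence dimension across tensor-parallel ranks, so a global sequence_index is only meaningful once the shards are reassembled. The torch.cat below is only a stand-in for mpu.gather_from_sequence_parallel_region, and the sizes are made up.

import torch

# Illustrative sizes; the real ones come from the model/training config.
seq_len, batch, hidden, tp_size = 8, 2, 4, 4
sequence_index = 5

full = torch.randn(seq_len, batch, hidden)        # [s, b, h] activation
shards = torch.chunk(full, tp_size, dim=0)        # what each TP rank holds locally

# Stand-in for the sequence-parallel all-gather: re-assemble along dim 0
# so every rank can index any token of the sequence.
gathered = torch.cat(shards, dim=0)

assert torch.equal(gathered[sequence_index, :, :], full[sequence_index, :, :])
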
@@ -412,7 +421,6 @@ class TransformerLanguageModel(MegatronModule):
                 pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
-        args = get_args()
 
         # Encoder embedding.
         if self.pre_process:
             encoder_input = self.embedding(enc_input_ids, enc_position_ids,
@@ -434,21 +442,8 @@ class TransformerLanguageModel(MegatronModule):
 
         if self.post_process:
             if self.add_pooler:
-                if args.sequence_parallel:
-                    # encoder output is split along sequence dimension
-                    # consider appropriate rank based on pooling sequence index
-                    # binary head loss is only computed in just one rank.
-                    seq_denom = args.seq_length // args.tensor_model_parallel_size
-                    seq_rank = mpu.get_tensor_model_parallel_rank()
-                    if pooling_sequence_index // seq_denom == seq_rank:
-                        pooled_output = self.pooler(
-                            encoder_output,
-                            pooling_sequence_index % seq_denom)
-                    else:
-                        pooled_output = None
-                else:
-                    pooled_output = self.pooler(encoder_output,
-                                                pooling_sequence_index)
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)
 
         # output_enc_hidden refers to when we just need the encoder's
         # output. For example, it is helpful to compute
......
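
For reference, the index arithmetic that the removed branch relied on, written out with made-up numbers: each tensor-parallel rank held seq_length // tensor_model_parallel_size tokens, so only the rank owning pooling_sequence_index could run the pooler and every other rank returned None, which is what the extra pooled_output check in post_language_model_processing guarded against.

# Hypothetical values, for illustration only.
seq_length = 512
tensor_model_parallel_size = 8
pooling_sequence_index = 100

seq_denom = seq_length // tensor_model_parallel_size    # tokens per rank: 64
owning_rank = pooling_sequence_index // seq_denom       # token 100 lives on rank 1
local_index = pooling_sequence_index % seq_denom        # at position 36 of that shard

assert owning_rank * seq_denom + local_index == pooling_sequence_index
# Previously only owning_rank produced pooled_output; after this change the
# pooler gathers the full sequence, so every rank produces it.
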