Commit 3f91f09b authored by Vijay Korthikanti

avoiding sequence parallelism on the pooler

parent 6aaafee6
...
@@ -81,7 +81,7 @@ class BertLMHead(MegatronModule):
        setattr(self.dense.weight, 'sequence_parallel', args.sequence_parallel)
        setattr(self.dense.bias, 'sequence_parallel', args.sequence_parallel)
        self.layernorm = LayerNorm(hidden_size,
                                   eps=layernorm_epsilon,
                                   sequence_parallel=args.sequence_parallel)
        self.gelu = torch.nn.functional.gelu
...
@@ -111,7 +111,7 @@ def post_language_model_processing(lm_output, pooled_output,
        lm_output, logit_weights)

    binary_logits = None
-    if binary_head is not None and pooled_output is not None:
+    if binary_head is not None:
        binary_logits = binary_head(pooled_output)

    if lm_labels is None:
...
...
@@ -104,11 +104,20 @@ class Pooler(MegatronModule):
    def __init__(self, hidden_size, init_method):
        super(Pooler, self).__init__()
+        args = get_args()
        self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        self.sequence_parallel = args.sequence_parallel

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [s, b, h]
        # sequence_index: index of the token to pool.

+        # gather data along sequence dimensions
+        # same pooler is run on all tensor parallel nodes
+        if self.sequence_parallel:
+            hidden_states = mpu.gather_from_sequence_parallel_region(hidden_states)

        pooled = hidden_states[sequence_index, :, :]
        pooled = self.dense(pooled)
        pooled = torch.tanh(pooled)
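
(Context, not part of the commit: with sequence parallelism enabled, the activations reaching the pooler are sharded along the sequence dimension across tensor-parallel ranks, so the token at sequence_index may live on a different rank. The new mpu.gather_from_sequence_parallel_region call reassembles the full [s, b, h] tensor on every rank before indexing. Below is a minimal forward-only sketch of that gather using plain torch.distributed.all_gather as a stand-in; the function name and group argument are illustrative assumptions, and Megatron's real utility is also autograd-aware.)

import torch
import torch.distributed as dist

def gather_along_sequence_dim(hidden_states, group=None):
    # Stand-in sketch for mpu.gather_from_sequence_parallel_region (forward pass only).
    # hidden_states: the [s/p, b, h] shard held by this tensor-parallel rank,
    # where p is the tensor-model-parallel world size.
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return hidden_states
    chunks = [torch.empty_like(hidden_states) for _ in range(world_size)]
    dist.all_gather(chunks, hidden_states.contiguous(), group=group)
    # Reassemble the full sequence on every rank.
    return torch.cat(chunks, dim=0)

With the full sequence available on every rank, hidden_states[sequence_index, :, :] is valid no matter which shard originally held that token, which is what allows the rank-selection branch below to be deleted.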
...
@@ -412,7 +421,6 @@ class TransformerLanguageModel(MegatronModule):
                pooling_sequence_index=0,
                enc_hidden_states=None, output_enc_hidden=False):
-        args = get_args()

        # Encoder embedding.
        if self.pre_process:
            encoder_input = self.embedding(enc_input_ids, enc_position_ids,
...
@@ -434,21 +442,8 @@ class TransformerLanguageModel(MegatronModule):
        if self.post_process:
            if self.add_pooler:
-                if args.sequence_parallel:
-                    # encoder output is split along sequence dimension
-                    # consider appropriate rank based on pooling sequence index
-                    # binary head loss is only computed in just one rank.
-                    seq_denom = args.seq_length // args.tensor_model_parallel_size
-                    seq_rank = mpu.get_tensor_model_parallel_rank()
-                    if pooling_sequence_index // seq_denom == seq_rank:
-                        pooled_output = self.pooler(
-                            encoder_output,
-                            pooling_sequence_index % seq_denom)
-                    else:
-                        pooled_output = None
-                else:
-                    pooled_output = self.pooler(encoder_output,
-                                                pooling_sequence_index)
+                pooled_output = self.pooler(encoder_output,
+                                            pooling_sequence_index)

        # output_enc_hidden refers to when we just need the encoder's
        # output. For example, it is helpful to compute
...
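
(Reference note, not part of the commit: the deleted branch in TransformerLanguageModel ran the pooler only on the tensor-parallel rank that owned the pooling token and left pooled_output = None on every other rank, which is also why the "pooled_output is not None" guard in post_language_model_processing above could be dropped. A toy walk-through of the removed index arithmetic, with assumed values seq_length = 8 and tensor_model_parallel_size = 4; the numbers are illustrative only.)

# Each rank holds seq_length // tp_size consecutive tokens of the sequence.
seq_length = 8
tensor_model_parallel_size = 4
pooling_sequence_index = 5

seq_denom = seq_length // tensor_model_parallel_size    # 2 tokens per rank
owner_rank = pooling_sequence_index // seq_denom        # token 5 lives on rank 2
local_index = pooling_sequence_index % seq_denom        # local offset 1 on that rank

print(owner_rank, local_index)  # before this commit, only rank 2 ran the pooler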