Commit 4ae54b55 authored by Vijay Korthikanti

Addressing more review comments

parent d836d498
......@@ -539,9 +539,10 @@ def _add_data_args(parser):
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--seq-length', type=int, default=None,
help="Maximum sequence length to process.")
help='Maximum sequence length to process.')
group.add_argument('--encoder-seq-length', type=int, default=None,
help="Maximum encoder sequence length to process.")
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq-length')
group.add_argument('--decoder-seq-length', type=int, default=None,
help="Maximum decoder sequence length to process.")
group.add_argument('--mask-prob', type=float, default=0.15,
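
For context, a minimal self-contained sketch of how the encoder/decoder length flags above parse. This uses plain argparse with the flag names from the diff; note the exclusivity with --seq-length is only stated in the help text, not enforced, and the values here are illustrative.

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='data')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--encoder-seq-length', type=int, default=None,
                       help='Maximum encoder sequence length to process.'
                            'This should be exclusive of --seq-length')
    group.add_argument('--decoder-seq-length', type=int, default=None,
                       help='Maximum decoder sequence length to process.')

    # Encoder-decoder style invocation: encoder/decoder lengths instead of --seq-length.
    args = parser.parse_args(['--encoder-seq-length', '512',
                              '--decoder-seq-length', '128'])
    assert args.seq_length is None and args.encoder_seq_length == 512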
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args
from megatron import mpu
+ from megatron.model.enums import AttnMaskType
from megatron.model.language_model import parallel_lm_logits
from megatron.model.language_model import get_language_model
from megatron.model import import_layernorm
......@@ -147,6 +148,7 @@ class BertModelBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=self.add_binary_head,
+ encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
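
The BERT, classification, multiple-choice, and ICT hunks in this commit all add the same keyword at their get_language_model call sites. As a rough, self-contained sketch of the new call shape (a hypothetical stub that only mirrors the reworked argument order; the real function builds the transformer stack):

    import enum

    class AttnMaskType(enum.Enum):  # stand-in for megatron.model.enums.AttnMaskType
        padding = 1
        causal = 2

    # Hypothetical stub with the keyword order introduced by this commit.
    def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                           encoder_attn_mask_type, init_method=None,
                           scaled_init_method=None, add_decoder=False,
                           decoder_attn_mask_type=AttnMaskType.causal):
        return {'mask_type': encoder_attn_mask_type, 'pooler': add_pooler}

    # BERT-style callers now pass the encoder mask type explicitly.
    lm = get_language_model(attention_mask_func=None, num_tokentypes=2,
                            add_pooler=True,
                            encoder_attn_mask_type=AttnMaskType.padding)
    assert lm['mask_type'] is AttnMaskType.padding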
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
+ from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -40,6 +41,7 @@ class ClassificationBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
+ encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -76,7 +76,7 @@ class GPT2ModelBase(MegatronModule):
attention_mask_func=gpt2_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=False,
- self_attn_mask_type=AttnMaskType.causal,
+ encoder_attn_mask_type=AttnMaskType.causal,
init_method=init_method_normal(args.init_method_std),
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
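
GPT-2 keeps causal masking while the BERT-style models above use padding masking. A small torch sketch of what the two AttnMaskType variants imply (illustration only, not Megatron's actual mask construction; True marks positions a token may not attend to):

    import torch

    seq_len = 6
    lengths = torch.tensor([6, 4])     # valid (non-padded) token counts per sample

    # Padding-style mask (AttnMaskType.padding, BERT): block attention to
    # positions past each sample's real length.
    positions = torch.arange(seq_len)
    padding_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)   # [batch, seq]

    # Causal-style mask (AttnMaskType.causal, GPT-2): block attention to
    # future positions, so each token sees only itself and earlier tokens.
    causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

    assert padding_mask.shape == (2, 6)
    assert causal_mask[0, 5] and not causal_mask[5, 0]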
......
......@@ -44,9 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
- add_decoder=False, init_method=None,
- scaled_init_method=None,
- self_attn_mask_type=AttnMaskType.padding):
+ encoder_attn_mask_type, init_method=None,
+ scaled_init_method=None, add_decoder=False,
+ decoder_attn_mask_type=AttnMaskType.causal):
"""Build language model and return along with the key to save."""
args = get_args()
......@@ -58,14 +58,15 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
args.num_layers)
# Language model.
- args = [attention_mask_func, init_method, scaled_init_method]
+ args = [attention_mask_func, init_method,
+ scaled_init_method, encoder_attn_mask_type]
kwargs = {}
cls = None
if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
cls = TransformerLanguageModel
kwargs['num_tokentypes'] = num_tokentypes
- kwargs['self_attn_mask_type'] = self_attn_mask_type
kwargs['add_decoder'] = add_decoder
+ kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
kwargs['add_pooler'] = add_pooler
elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
cls = TransformerLanguageModelFirstStage
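
A hedged, self-contained sketch of the dispatch pattern in this hunk: pick a class per pipeline stage, collect only the kwargs that stage accepts (the full model also takes the decoder mask type), then construct it with cls(*args, **kwargs). The classes and stage flags below are stand-ins, not the real Megatron modules.

    class FullLanguageModel:
        def __init__(self, attention_mask_func, encoder_attn_mask_type,
                     num_tokentypes=0, add_decoder=False,
                     decoder_attn_mask_type='causal', add_pooler=False):
            self.add_pooler = add_pooler

    class FirstStageLanguageModel:
        def __init__(self, attention_mask_func, encoder_attn_mask_type,
                     num_tokentypes=0):
            self.add_pooler = False

    def build_language_model(is_first_stage, is_last_stage):
        args = [None, 'padding']   # attention_mask_func, encoder_attn_mask_type
        kwargs = {}
        if is_first_stage and is_last_stage:
            cls = FullLanguageModel
            kwargs['num_tokentypes'] = 2
            kwargs['add_decoder'] = False
            kwargs['decoder_attn_mask_type'] = 'causal'
            kwargs['add_pooler'] = True
        else:
            cls = FirstStageLanguageModel
            kwargs['num_tokentypes'] = 2
        return cls(*args, **kwargs)

    model = build_language_model(is_first_stage=True, is_last_stage=True)
    assert model.add_pooler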
......@@ -192,6 +193,8 @@ class Embedding(MegatronModule):
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+ else:
+ assert self.tokentype_embeddings is None
# Dropout.
embeddings = self.embedding_dropout(embeddings)
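
The added else-branch tightens the invariant to both-or-neither: tokentype ids must be passed exactly when the module was built with tokentype embeddings. A small illustrative module (not the Megatron Embedding class) showing the same check:

    import torch

    class ToyEmbedding(torch.nn.Module):
        """Illustration only: tokentype_ids must be given iff tokentype
        embeddings were constructed."""
        def __init__(self, vocab_size, hidden_size, num_tokentypes=0):
            super().__init__()
            self.word_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
            self.tokentype_embeddings = (
                torch.nn.Embedding(num_tokentypes, hidden_size)
                if num_tokentypes > 0 else None)

        def forward(self, input_ids, tokentype_ids=None):
            embeddings = self.word_embeddings(input_ids)
            if tokentype_ids is not None:
                assert self.tokentype_embeddings is not None
                embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
            else:
                assert self.tokentype_embeddings is None
            return embeddings

    emb = ToyEmbedding(100, 8, num_tokentypes=2)
    out = emb(torch.zeros(1, 4, dtype=torch.long),
              tokentype_ids=torch.zeros(1, 4, dtype=torch.long))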
......@@ -284,9 +287,10 @@ class TransformerLanguageModelBase(MegatronModule):
attention_mask_func,
init_method,
output_layer_init_method,
+ encoder_attn_mask_type,
num_tokentypes=0,
- self_attn_mask_type=AttnMaskType.padding,
add_decoder=False,
+ decoder_attn_mask_type=AttnMaskType.causal,
add_pooler=False):
super(TransformerLanguageModelBase, self).__init__()
args = get_args()
......@@ -294,8 +298,9 @@ class TransformerLanguageModelBase(MegatronModule):
self.hidden_size = args.hidden_size
self.num_tokentypes = num_tokentypes
self.init_method = init_method
- self.self_attn_mask_type = self_attn_mask_type
+ self.encoder_attn_mask_type = encoder_attn_mask_type
self.add_decoder = add_decoder
+ self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
# Embeddings.
......@@ -313,7 +318,7 @@ class TransformerLanguageModelBase(MegatronModule):
attention_mask_func,
self.init_method,
output_layer_init_method,
- self_attn_mask_type=self_attn_mask_type)
+ self_attn_mask_type=self.encoder_attn_mask_type)
self._encoder_key = 'encoder'
# Decoder
......@@ -325,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.init_method,
output_layer_init_method,
layer_type=LayerType.decoder,
- self_attn_mask_type=AttnMaskType.causal)
+ self_attn_mask_type=self.decoder_attn_mask_type)
self._decoder_key = 'decoder'
if mpu.is_pipeline_last_stage():
......@@ -334,7 +339,7 @@ class TransformerLanguageModelBase(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
- def forward(self, enc_language_model_input, enc_attention_mask,
+ def forward(self, enc_language_model_input, enc_attn_mask,
dec_language_model_input=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
......@@ -352,7 +357,7 @@ class TransformerLanguageModelBase(MegatronModule):
# encoder.
if enc_hidden_states is None:
encoder_output = self.encoder(encoder_input,
- enc_attention_mask,
+ enc_attn_mask,
layer_past=layer_past,
get_key_value=get_key_value)
else:
......@@ -438,8 +443,8 @@ class TransformerLanguageModelBase(MegatronModule):
# for backward compatibility.
state_dict_ = {}
for key in state_dict.keys():
- if 'encoder.' in key:
- state_dict_[key.split('encoder.')[1]] = state_dict[key]
+ if 'transformer.' in key:
+ state_dict_[key.split('transformer.')[1]] = state_dict[key]
# for backward compatibility.
state_dict_self_attention = {}
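
This hunk changes which legacy prefix the backward-compatibility path strips when loading old checkpoints into the encoder module. A self-contained sketch of that remapping (the keys below are made up for illustration):

    # Strip the old 'transformer.' prefix so legacy checkpoint keys load
    # into the current encoder module.
    state_dict = {
        'transformer.layers.0.self_attention.query.weight': 'W_q',
        'transformer.layers.0.mlp.dense_h_to_4h.weight': 'W_mlp',
    }

    state_dict_ = {}
    for key in state_dict.keys():
        if 'transformer.' in key:
            state_dict_[key.split('transformer.')[1]] = state_dict[key]

    assert 'layers.0.mlp.dense_h_to_4h.weight' in state_dict_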
......@@ -477,27 +482,29 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
+ encoder_attn_mask_type,
num_tokentypes=0,
- self_attn_mask_type=AttnMaskType.padding,
+ decoder_attn_mask_type=AttnMaskType.causal,
add_decoder=False,
add_pooler=False):
super(TransformerLanguageModel, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
+ encoder_attn_mask_type,
num_tokentypes=num_tokentypes,
- self_attn_mask_type=self_attn_mask_type,
add_decoder=add_decoder,
+ decoder_attn_mask_type=decoder_attn_mask_type,
add_pooler=add_pooler)
- def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+ def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
get_key_value=False, pooling_sequence_index=0,
enc_hidden_states=None, output_enc_hidden=False):
return super(TransformerLanguageModel, self).forward(
(enc_input_ids, enc_position_ids),
- enc_attention_mask,
+ enc_attn_mask,
dec_language_model_input=(dec_input_ids, dec_position_ids),
dec_attn_mask=dec_attn_mask,
enc_dec_attn_mask=enc_dec_attn_mask,
......@@ -519,14 +526,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
- num_tokentypes=0,
- self_attn_mask_type=AttnMaskType.padding):
+ encoder_attn_mask_type,
+ num_tokentypes=0):
super(TransformerLanguageModelFirstStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
- num_tokentypes=num_tokentypes,
- self_attn_mask_type=self_attn_mask_type)
+ encoder_attn_mask_type,
+ num_tokentypes=num_tokentypes)
def forward(self, input_ids, position_ids, attention_mask,
tokentype_ids=None, layer_past=None, get_key_value=False):
......@@ -548,12 +555,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
- self_attn_mask_type=AttnMaskType.padding):
+ encoder_attn_mask_type):
super(TransformerLanguageModelIntermediateStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
- self_attn_mask_type=self_attn_mask_type)
+ encoder_attn_mask_type)
def forward(self, hidden_states, attention_mask,
layer_past=None, get_key_value=False):
......@@ -574,13 +581,13 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
attention_mask_func,
init_method,
output_layer_init_method,
- self_attn_mask_type=AttnMaskType.padding,
+ encoder_attn_mask_type,
add_pooler=False):
super(TransformerLanguageModelLastStage, self).__init__(
attention_mask_func,
init_method,
output_layer_init_method,
- self_attn_mask_type=AttnMaskType.padding,
+ encoder_attn_mask_type,
add_pooler=add_pooler)
def forward(self, hidden_states, attention_mask,
......
......@@ -19,6 +19,7 @@ import torch
from megatron import get_args, print_rank_last
from megatron import mpu
+ from megatron.model.enums import AttnMaskType
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
from megatron.model.utils import get_linear_layer
......@@ -39,6 +40,7 @@ class MultipleChoiceBase(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
+ encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
......
......@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
from megatron.model import BertModel
from .module import MegatronModule
from megatron import mpu
+ from megatron.model.enums import AttnMaskType
from megatron.model.utils import get_linear_layer
from megatron.model.utils import init_method_normal
from megatron.model.language_model import get_language_model
......@@ -159,6 +160,7 @@ class IREncoderBertModel(MegatronModule):
attention_mask_func=bert_attention_mask_func,
num_tokentypes=num_tokentypes,
add_pooler=True,
+ encoder_attn_mask_type=AttnMaskType.padding,
init_method=init_method,
scaled_init_method=scaled_init_method)
......
......@@ -14,7 +14,6 @@
# limitations under the License.
"""Transformer."""
- import enum
import math
import torch
import torch.nn.functional as F
......