Commit 4ae54b55 authored by Vijay Korthikanti

Addressing more review comments

parent d836d498
@@ -539,9 +539,10 @@ def _add_data_args(parser):
     group.add_argument('--merge-file', type=str, default=None,
                        help='Path to the BPE merge file.')
     group.add_argument('--seq-length', type=int, default=None,
-                       help="Maximum sequence length to process.")
+                       help='Maximum sequence length to process.')
     group.add_argument('--encoder-seq-length', type=int, default=None,
-                       help="Maximum encoder sequence length to process.")
+                       help='Maximum encoder sequence length to process.'
+                       'This should be exclusive of --seq-length')
     group.add_argument('--decoder-seq-length', type=int, default=None,
                        help="Maximum decoder sequence length to process.")
     group.add_argument('--mask-prob', type=float, default=0.15,
...
@@ -19,6 +19,7 @@ import torch
 from megatron import get_args
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
 from megatron.model import import_layernorm
@@ -147,6 +148,7 @@ class BertModelBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=self.add_binary_head,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
...
@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
@@ -40,6 +41,7 @@ class ClassificationBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
...
@@ -76,7 +76,7 @@ class GPT2ModelBase(MegatronModule):
             attention_mask_func=gpt2_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
-            self_attn_mask_type=AttnMaskType.causal,
+            encoder_attn_mask_type=AttnMaskType.causal,
             init_method=init_method_normal(args.init_method_std),
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
...
@@ -44,9 +44,9 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
 def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
-                       add_decoder=False, init_method=None,
-                       scaled_init_method=None,
-                       self_attn_mask_type=AttnMaskType.padding):
+                       encoder_attn_mask_type, init_method=None,
+                       scaled_init_method=None, add_decoder=False,
+                       decoder_attn_mask_type=AttnMaskType.causal):
     """Build language model and return along with the key to save."""
     args = get_args()
@@ -58,14 +58,15 @@ def get_language_model(attention_mask_func, num_tokentypes, add_pooler,
                                                        args.num_layers)
     # Language model.
-    args = [attention_mask_func, init_method, scaled_init_method]
+    args = [attention_mask_func, init_method,
+            scaled_init_method, encoder_attn_mask_type]
     kwargs = {}
     cls = None
     if mpu.is_pipeline_first_stage() and mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModel
         kwargs['num_tokentypes'] = num_tokentypes
-        kwargs['self_attn_mask_type'] = self_attn_mask_type
         kwargs['add_decoder'] = add_decoder
+        kwargs['decoder_attn_mask_type'] = decoder_attn_mask_type
         kwargs['add_pooler'] = add_pooler
     elif mpu.is_pipeline_first_stage() and not mpu.is_pipeline_last_stage():
         cls = TransformerLanguageModelFirstStage
@@ -192,6 +193,8 @@ class Embedding(MegatronModule):
         if tokentype_ids is not None:
             assert self.tokentype_embeddings is not None
             embeddings = embeddings + self.tokentype_embeddings(tokentype_ids)
+        else:
+            assert self.tokentype_embeddings is None
         # Dropout.
         embeddings = self.embedding_dropout(embeddings)
@@ -284,9 +287,10 @@ class TransformerLanguageModelBase(MegatronModule):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
                  add_decoder=False,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_pooler=False):
         super(TransformerLanguageModelBase, self).__init__()
         args = get_args()
@@ -294,8 +298,9 @@ class TransformerLanguageModelBase(MegatronModule):
         self.hidden_size = args.hidden_size
         self.num_tokentypes = num_tokentypes
         self.init_method = init_method
-        self.self_attn_mask_type = self_attn_mask_type
+        self.encoder_attn_mask_type = encoder_attn_mask_type
         self.add_decoder = add_decoder
+        self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         # Embeddings.
@@ -313,7 +318,7 @@ class TransformerLanguageModelBase(MegatronModule):
             attention_mask_func,
             self.init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            self_attn_mask_type=self.encoder_attn_mask_type)
         self._encoder_key = 'encoder'
         # Decoder
@@ -325,7 +330,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self.init_method,
                 output_layer_init_method,
                 layer_type=LayerType.decoder,
-                self_attn_mask_type=AttnMaskType.causal)
+                self_attn_mask_type=self.decoder_attn_mask_type)
             self._decoder_key = 'decoder'
         if mpu.is_pipeline_last_stage():
@@ -334,7 +339,7 @@ class TransformerLanguageModelBase(MegatronModule):
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'
-    def forward(self, enc_language_model_input, enc_attention_mask,
+    def forward(self, enc_language_model_input, enc_attn_mask,
                 dec_language_model_input=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,
@@ -352,7 +357,7 @@ class TransformerLanguageModelBase(MegatronModule):
         # encoder.
         if enc_hidden_states is None:
             encoder_output = self.encoder(encoder_input,
-                                          enc_attention_mask,
+                                          enc_attn_mask,
                                           layer_past=layer_past,
                                           get_key_value=get_key_value)
         else:
@@ -438,8 +443,8 @@ class TransformerLanguageModelBase(MegatronModule):
             # for backward compatibility.
             state_dict_ = {}
             for key in state_dict.keys():
-                if 'encoder.' in key:
-                    state_dict_[key.split('encoder.')[1]] = state_dict[key]
+                if 'transformer.' in key:
+                    state_dict_[key.split('transformer.')[1]] = state_dict[key]
             # for backward compatibility.
             state_dict_self_attention = {}
@@ -477,27 +482,29 @@ class TransformerLanguageModel(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
+                 encoder_attn_mask_type,
                  num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 decoder_attn_mask_type=AttnMaskType.causal,
                  add_decoder=False,
                  add_pooler=False):
         super(TransformerLanguageModel, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
+            encoder_attn_mask_type,
             num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type,
             add_decoder=add_decoder,
+            decoder_attn_mask_type=decoder_attn_mask_type,
             add_pooler=add_pooler)
-    def forward(self, enc_input_ids, enc_position_ids, enc_attention_mask,
+    def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
                 enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
                 get_key_value=False, pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
         return super(TransformerLanguageModel, self).forward(
             (enc_input_ids, enc_position_ids),
-            enc_attention_mask,
+            enc_attn_mask,
             dec_language_model_input=(dec_input_ids, dec_position_ids),
             dec_attn_mask=dec_attn_mask,
             enc_dec_attn_mask=enc_dec_attn_mask,
@@ -519,14 +526,14 @@ class TransformerLanguageModelFirstStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 num_tokentypes=0,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type,
+                 num_tokentypes=0):
         super(TransformerLanguageModelFirstStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            num_tokentypes=num_tokentypes,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type,
+            num_tokentypes=num_tokentypes)
     def forward(self, input_ids, position_ids, attention_mask,
                 tokentype_ids=None, layer_past=None, get_key_value=False):
@@ -548,12 +555,12 @@ class TransformerLanguageModelIntermediateStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 encoder_attn_mask_type):
         super(TransformerLanguageModelIntermediateStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=self_attn_mask_type)
+            encoder_attn_mask_type)
     def forward(self, hidden_states, attention_mask,
                 layer_past=None, get_key_value=False):
@@ -574,13 +581,13 @@ class TransformerLanguageModelLastStage(TransformerLanguageModelBase):
                  attention_mask_func,
                  init_method,
                  output_layer_init_method,
-                 self_attn_mask_type=AttnMaskType.padding,
+                 encoder_attn_mask_type,
                  add_pooler=False):
         super(TransformerLanguageModelLastStage, self).__init__(
             attention_mask_func,
             init_method,
             output_layer_init_method,
-            self_attn_mask_type=AttnMaskType.padding,
+            encoder_attn_mask_type,
             add_pooler=add_pooler)
     def forward(self, hidden_states, attention_mask,
...
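Note on the signature change above: get_language_model() now takes encoder_attn_mask_type as a required positional argument, while add_decoder and decoder_attn_mask_type move to keyword arguments with defaults. The sketch below illustrates the resulting call pattern for an encoder-only model with a padding mask; it is an illustration assembled from the imports and keyword arguments visible in this diff, not code from the commit, and the wrapper function name build_padded_encoder_lm is hypothetical.

# Sketch only: updated get_language_model() call pattern (assumed usage).
from megatron.model.enums import AttnMaskType
from megatron.model.language_model import get_language_model
from megatron.model.bert_model import bert_attention_mask_func
from megatron.model.utils import init_method_normal, scaled_init_method_normal

def build_padded_encoder_lm(args, num_tokentypes=2):
    # Encoder-only model with a padding mask; no decoder is added, so
    # decoder_attn_mask_type keeps its default (AttnMaskType.causal).
    language_model, lm_key = get_language_model(
        attention_mask_func=bert_attention_mask_func,
        num_tokentypes=num_tokentypes,
        add_pooler=True,
        encoder_attn_mask_type=AttnMaskType.padding,
        init_method=init_method_normal(args.init_method_std),
        scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                     args.num_layers))
    return language_model, lm_key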
@@ -19,6 +19,7 @@ import torch
 from megatron import get_args, print_rank_last
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
 from megatron.model.language_model import get_language_model
 from megatron.model.utils import get_linear_layer
@@ -39,6 +40,7 @@ class MultipleChoiceBase(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))
...
@@ -6,6 +6,7 @@ from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoi
 from megatron.model import BertModel
 from .module import MegatronModule
 from megatron import mpu
+from megatron.model.enums import AttnMaskType
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.language_model import get_language_model
@@ -159,6 +160,7 @@ class IREncoderBertModel(MegatronModule):
             attention_mask_func=bert_attention_mask_func,
             num_tokentypes=num_tokentypes,
             add_pooler=True,
+            encoder_attn_mask_type=AttnMaskType.padding,
             init_method=init_method,
             scaled_init_method=scaled_init_method)
...
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Transformer."""
-import enum
 import math
 import torch
 import torch.nn.functional as F
...
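The transformer.py hunk drops the local import enum because the attention-mask and layer-type enums now come from megatron.model.enums, imported throughout this commit. The diff only references the members AttnMaskType.padding, AttnMaskType.causal, and LayerType.decoder; the following is a minimal sketch of what that module is assumed to provide, consistent with those references (the concrete enum values are assumptions).

# Assumed shape of megatron.model.enums, inferred from the members referenced
# in this diff; the numeric values assigned here are illustrative assumptions.
import enum

class LayerType(enum.Enum):
    encoder = 1
    decoder = 2

class AttnMaskType(enum.Enum):
    padding = 1
    causal = 2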