"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "7f22f90e8cb423fdaa35203d41badd734d9c2e86"
Commit 93ec8d0b authored by Jingfei Du, committed by Facebook Github Bot

expose arguments for bias_kv and zero_attn for masked_lm

Summary: the old no_bias_kv argument for masked_lm models is not used. Split it into two arguments (bias_kv and zero_attn) and expose them.

Reviewed By: myleott

Differential Revision: D15266154

fbshipit-source-id: 60b041f8370ca1d8869ed3402fb9a67d1cd8e0e8
parent acb9ab32
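In practice the single --no-bias-kv flag is replaced by two independent store_true flags, and the architecture functions fill in their defaults with getattr. A minimal, self-contained sketch of the new argument flow (the parser calls mirror the diff; everything else is simplified for illustration and is not fairseq code):

import argparse

# Sketch of the two new flags added to MaskedLMModel.add_args.
parser = argparse.ArgumentParser()
parser.add_argument('--bias-kv', action='store_true',
                    help='if set, adding a learnable bias kv')
parser.add_argument('--zero-attn', action='store_true',
                    help='if set, pads attn with zero')

args = parser.parse_args([])  # simulate a command line with neither flag set

# Sketch of base_architecture(): getattr keeps a value that was set on the
# command line and otherwise falls back to False.
args.bias_kv = getattr(args, 'bias_kv', False)
args.zero_attn = getattr(args, 'zero_attn', False)

print(args.bias_kv, args.zero_attn)  # False False unless the flags are passed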
@@ -59,9 +59,10 @@ class MaskedLMModel(BaseFairseqModel):
                             help='num encoder layers')
         parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                             help='num encoder attention heads')
-        parser.add_argument('--no-bias-kv', action='store_true',
-                            help='if set, pads attn with zero instead of'
-                                 ' adding a learnable bias kv')
+        parser.add_argument('--bias-kv', action='store_true',
+                            help='if set, adding a learnable bias kv')
+        parser.add_argument('--zero-attn', action='store_true',
+                            help='if set, pads attn with zero')
         # Arguments related to input and output embeddings
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',

@@ -151,6 +152,8 @@ class MaskedLMEncoder(FairseqEncoder):
             use_gelu=args.gelu,
             apply_bert_init=args.apply_bert_init,
             learned_pos_embedding=args.encoder_learned_pos,
+            add_bias_kv=args.bias_kv,
+            add_zero_attn=args.zero_attn,
         )
         self.share_input_output_embed = args.share_encoder_input_output_embed

@@ -247,7 +250,8 @@ def base_architecture(args):
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', False)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
     args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False)

@@ -280,7 +284,8 @@ def bert_base_architecture(args):
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', True)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.sent_loss = getattr(args, 'sent_loss', True)
     args.sentence_class_num = getattr(args, 'sentence_class_num', 2)

@@ -318,7 +323,8 @@ def xlm_architecture(args):
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', True)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.sent_loss = getattr(args, 'sent_loss', False)
@@ -82,6 +82,8 @@ class TransformerSentenceEncoder(nn.Module):
         use_gelu: bool = True,
         apply_bert_init: bool = False,
         learned_pos_embedding: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
     ) -> None:
         super().__init__()

@@ -128,6 +130,8 @@ class TransformerSentenceEncoder(nn.Module):
                 encoder_normalize_before=encoder_normalize_before,
                 use_bert_layer_norm=use_bert_layer_norm,
                 use_gelu=use_gelu,
+                add_bias_kv=add_bias_kv,
+                add_zero_attn=add_zero_attn,
             )
             for _ in range(num_encoder_layers)
         ]
@@ -37,6 +37,8 @@ class TransformerSentenceEncoderLayer(nn.Module):
         encoder_normalize_before: bool = False,
         use_bert_layer_norm: bool = False,
         use_gelu: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
     ) -> None:
         super().__init__()

@@ -49,7 +51,11 @@ class TransformerSentenceEncoderLayer(nn.Module):
         # Initialize blocks
         self.activation_fn = gelu if use_gelu else F.relu
         self.self_attn = MultiheadAttention(
-            self.embedding_dim, num_attention_heads, dropout=attention_dropout
+            self.embedding_dim,
+            num_attention_heads,
+            dropout=attention_dropout,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
         )
         # layer norm associated with the self attention layer
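To see what the two options change once they reach the attention module, here is a small demonstration using torch.nn.MultiheadAttention, which exposes the same add_bias_kv / add_zero_attn parameters that this commit threads through fairseq's MultiheadAttention. The snippet shows plain PyTorch behaviour and is not code from the commit; tensors use PyTorch's default (seq_len, batch, embed_dim) layout.

import torch
import torch.nn as nn

embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2

# Same two knobs the diff exposes, here on PyTorch's own attention module:
#   add_bias_kv   -> append a learnable bias key/value   (--bias-kv)
#   add_zero_attn -> append an all-zero attention slot   (--zero-attn)
attn = nn.MultiheadAttention(embed_dim, num_heads,
                             add_bias_kv=True, add_zero_attn=True)

x = torch.randn(seq_len, batch, embed_dim)
out, weights = attn(x, x, x)

print(out.shape)      # torch.Size([5, 2, 16]) -- output length is unchanged
print(weights.shape)  # torch.Size([2, 5, 7])  -- 5 real keys + bias kv + zero slot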
@@ -386,7 +386,6 @@ def train_masked_language_model(data_dir, arch):
             # dropout, attention args
             "--dropout",
             "0.1",
-            "--no-bias-kv",
             "--attention-dropout",
             "0.1",
             # MLM args
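With --no-bias-kv removed, passing it now fails argument parsing; callers opt into the new behaviour explicitly. A hypothetical extension of an argument list like the one in this test (not part of the commit):

# Hypothetical, optional flags a caller could append to a training argument
# list after this change; both are store_true flags, so no values follow them.
extra_args = [
    "--bias-kv",    # add a learnable bias key/value to self-attention
    "--zero-attn",  # add an all-zero attention slot
]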