Commit 93ec8d0b authored by Jingfei Du, committed by Facebook Github Bot

expose arguments for bias_kv and zero_attn for masked_lm

Summary: The old no_bias_kv argument for the masked_lm models is not used. Split it into two arguments (bias_kv and zero_attn) and expose them.

Reviewed By: myleott

Differential Revision: D15266154

fbshipit-source-id: 60b041f8370ca1d8869ed3402fb9a67d1cd8e0e8
parent acb9ab32
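
For context before the diff: the new bias_kv and zero_attn options are threaded down to MultiheadAttention's add_bias_kv and add_zero_attn arguments. The minimal sketch below is not part of this commit; it uses torch.nn.MultiheadAttention, which exposes options with the same names and semantics, to show what each flag does to the attended key/value positions.

import torch
import torch.nn as nn

# Illustrative only: torch.nn.MultiheadAttention accepts the same
# add_bias_kv / add_zero_attn options that this commit exposes for the
# masked LM encoder.
embed_dim, num_heads, seq_len, batch = 16, 4, 5, 2

attn = nn.MultiheadAttention(
    embed_dim,
    num_heads,
    add_bias_kv=True,    # --bias-kv: append a learnable bias to the keys/values
    add_zero_attn=True,  # --zero-attn: append an all-zero key/value position
)

x = torch.randn(seq_len, batch, embed_dim)  # (time, batch, channel)
out, weights = attn(x, x, x)

print(out.shape)      # torch.Size([5, 2, 16]): output length unchanged
print(weights.shape)  # torch.Size([2, 5, 7]): two extra key positions to attend to
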
@@ -59,9 +59,10 @@ class MaskedLMModel(BaseFairseqModel):
                             help='num encoder layers')
         parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                             help='num encoder attention heads')
-        parser.add_argument('--no-bias-kv', action='store_true',
-                            help='if set, pads attn with zero instead of'
-                                 ' adding a learnable bias kv')
+        parser.add_argument('--bias-kv', action='store_true',
+                            help='if set, adding a learnable bias kv')
+        parser.add_argument('--zero-attn', action='store_true',
+                            help='if set, pads attn with zero')
         # Arguments related to input and output embeddings
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
@@ -151,6 +152,8 @@ class MaskedLMEncoder(FairseqEncoder):
             use_gelu=args.gelu,
             apply_bert_init=args.apply_bert_init,
             learned_pos_embedding=args.encoder_learned_pos,
+            add_bias_kv=args.bias_kv,
+            add_zero_attn=args.zero_attn,
         )
         self.share_input_output_embed = args.share_encoder_input_output_embed
@@ -247,7 +250,8 @@ def base_architecture(args):
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', False)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
     args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False)
@@ -280,7 +284,8 @@ def bert_base_architecture(args):
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 12)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 3072)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', True)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.sent_loss = getattr(args, 'sent_loss', True)
     args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
@@ -318,7 +323,8 @@ def xlm_architecture(args):
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
-    args.no_bias_kv = getattr(args, 'no_bias_kv', True)
+    args.bias_kv = getattr(args, 'bias_kv', False)
+    args.zero_attn = getattr(args, 'zero_attn', False)
     args.sent_loss = getattr(args, 'sent_loss', False)
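The argparse and base_architecture() changes above follow the usual fairseq pattern: store_true flags default to False, and the architecture functions backfill any attribute the caller did not set via getattr. A small self-contained sketch of that pattern (a hypothetical standalone script, not code from the repository):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--bias-kv', action='store_true',
                    help='if set, add a learnable bias to the key/value')
parser.add_argument('--zero-attn', action='store_true',
                    help='if set, pad attention with an all-zero position')

# e.g. a user passes only --bias-kv on the command line
args = parser.parse_args(['--bias-kv'])

def base_architecture(args):
    # mirrors the getattr-based defaulting in the diff above
    args.bias_kv = getattr(args, 'bias_kv', False)
    args.zero_attn = getattr(args, 'zero_attn', False)

base_architecture(args)
print(args.bias_kv, args.zero_attn)  # True False
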
@@ -82,6 +82,8 @@ class TransformerSentenceEncoder(nn.Module):
         use_gelu: bool = True,
         apply_bert_init: bool = False,
         learned_pos_embedding: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
     ) -> None:
         super().__init__()
@@ -128,6 +130,8 @@ class TransformerSentenceEncoder(nn.Module):
                     encoder_normalize_before=encoder_normalize_before,
                     use_bert_layer_norm=use_bert_layer_norm,
                     use_gelu=use_gelu,
+                    add_bias_kv=add_bias_kv,
+                    add_zero_attn=add_zero_attn,
                 )
                 for _ in range(num_encoder_layers)
             ]
@@ -37,6 +37,8 @@ class TransformerSentenceEncoderLayer(nn.Module):
         encoder_normalize_before: bool = False,
         use_bert_layer_norm: bool = False,
         use_gelu: bool = True,
+        add_bias_kv: bool = False,
+        add_zero_attn: bool = False,
     ) -> None:
         super().__init__()
@@ -49,7 +51,11 @@ class TransformerSentenceEncoderLayer(nn.Module):
         # Initialize blocks
         self.activation_fn = gelu if use_gelu else F.relu
         self.self_attn = MultiheadAttention(
-            self.embedding_dim, num_attention_heads, dropout=attention_dropout
+            self.embedding_dim,
+            num_attention_heads,
+            dropout=attention_dropout,
+            add_bias_kv=add_bias_kv,
+            add_zero_attn=add_zero_attn,
         )
         # layer norm associated with the self attention layer
@@ -386,7 +386,6 @@ def train_masked_language_model(data_dir, arch):
            # dropout, attention args
            "--dropout",
            "0.1",
-           "--no-bias-kv",
            "--attention-dropout",
            "0.1",
            # MLM args