Commit 39264559 authored by Peng-Jen Chen, committed by Facebook Github Bot

Make learned positional embedding optional

Summary:
- Add a binary flag for learned positional embeddings to the masked LM model.
- Add a base arch config for the masked LM model that sets all the binary parameters to False. Otherwise some of the binary flag parameters would always be overridden by the config in `xlm_architecture` (e.g. `encoder_learned_pos`).

Reviewed By: liezl200

Differential Revision: D15054487

fbshipit-source-id: d78827f352b9160a89c9dc4f45b9fce15a2f234d
parent 34726d56
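For context on the second summary point: the architecture functions below layer their defaults with `getattr`, so a named architecture (e.g. `bert_base` or `xlm_base`) applies its own overrides first and then calls `base_architecture`, which only fills in attributes that are still unset; a value already present on `args` always wins over the `getattr` fallback. A minimal standalone sketch of that layering (illustrative only, not fairseq code; the two attributes are just examples):

from argparse import Namespace

def base_architecture(args):
    # Fill a default only if the attribute was never set.
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
    args.gelu = getattr(args, 'gelu', False)

def bert_base(args):
    # Architecture-specific overrides come first ...
    args.gelu = getattr(args, 'gelu', True)
    # ... then the base config supplies everything still missing.
    base_architecture(args)

args = Namespace()               # pretend no command-line overrides were given
bert_base(args)
print(args.gelu)                 # True  (set by bert_base)
print(args.encoder_learned_pos)  # False (fallback from base_architecture)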
@@ -40,12 +40,12 @@ class MaskedLMModel(BaseFairseqModel):
     def add_args(parser):
         """Add model-specific arguments to the parser."""
         # Arguments related to dropout
-        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+        parser.add_argument('--dropout', type=float, metavar='D',
                             help='dropout probability')
-        parser.add_argument('--attention-dropout', default=0.1, type=float,
+        parser.add_argument('--attention-dropout', type=float,
                             metavar='D', help='dropout probability for'
                             ' attention weights')
-        parser.add_argument('--act-dropout', default=0.1, type=float,
+        parser.add_argument('--act-dropout', type=float,
                             metavar='D', help='dropout probability after'
                             ' activation in FFN')
@@ -66,17 +66,18 @@ class MaskedLMModel(BaseFairseqModel):
         parser.add_argument('--share-encoder-input-output-embed',
                             action='store_true', help='share encoder input'
                             ' and output embeddings')
+        parser.add_argument('--encoder-learned-pos', action='store_true',
+                            help='use learned positional embeddings in the encoder')
         parser.add_argument('--no-token-positional-embeddings',
                             action='store_true',
                             help='if set, disables positional embeddings'
                             ' (outside self attention)')
-        parser.add_argument('--num-segment', type=int, metavar='N', default=2,
+        parser.add_argument('--num-segment', type=int, metavar='N',
                             help='num segment in the input')

         # Arguments related to sentence level prediction
         parser.add_argument('--sentence-class-num', type=int, metavar='N',
-                            default=2, help='number of classes for sentence'
-                            ' task')
+                            help='number of classes for sentence task')
         parser.add_argument('--sent-loss', action='store_true', help='if set,'
                             ' calculate sentence level predictions')
@@ -93,7 +94,7 @@ class MaskedLMModel(BaseFairseqModel):
                             help='apply layernorm before each encoder block')
         parser.add_argument('--gelu', action='store_true',
                             help='Use gelu activation function in encoder'
-                            ' Layer')
+                            ' layer')

     def forward(self, tokens, segment_labels):
         return self.encoder(tokens, segment_labels)
@@ -131,14 +132,6 @@ class MaskedLMEncoder(FairseqEncoder):
         self.vocab_size = dictionary.__len__()
         self.max_positions = args.max_positions

-        use_position_embeddings = (
-            not getattr(args, 'no_token_positional_embeddings', False)
-        )
-        encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
-        use_bert_layer_norm = getattr(args, 'bert_layer_norm', False)
-        use_gelu = getattr(args, 'gelu', False)
-        apply_bert_init = getattr(args, 'apply_bert_init', False)
-
         self.sentence_encoder = TransformerSentenceEncoder(
             padding_idx=self.padding_idx,
             vocab_size=self.vocab_size,
@@ -151,15 +144,15 @@ class MaskedLMEncoder(FairseqEncoder):
             activation_dropout=args.act_dropout,
             max_seq_len=self.max_positions,
             num_segments=args.num_segment,
-            use_position_embeddings=use_position_embeddings,
-            encoder_normalize_before=encoder_normalize_before,
-            use_bert_layer_norm=use_bert_layer_norm,
-            use_gelu=use_gelu,
-            apply_bert_init=apply_bert_init,
+            use_position_embeddings=not args.no_token_positional_embeddings,
+            encoder_normalize_before=args.encoder_normalize_before,
+            use_bert_layer_norm=args.bert_layer_norm,
+            use_gelu=args.gelu,
+            apply_bert_init=args.apply_bert_init,
+            learned_pos_embedding=args.encoder_learned_pos,
         )

-        self.share_input_output_embed = getattr(
-            args, 'share_encoder_input_output_embed', False)
+        self.share_input_output_embed = args.share_encoder_input_output_embed
         self.embed_out = None
         self.sentence_projection_layer = None
         self.sentence_out_dim = args.sentence_class_num
@@ -244,6 +237,34 @@ class MaskedLMEncoder(FairseqEncoder):
         return state_dict


+@register_model_architecture('masked_lm', 'masked_lm')
+def base_architecture(args):
+    args.dropout = getattr(args, 'dropout', 0.1)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    args.act_dropout = getattr(args, 'act_dropout', 0.1)
+    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
+    args.encoder_layers = getattr(args, 'encoder_layers', 6)
+    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
+    args.no_bias_kv = getattr(args, 'no_bias_kv', False)
+    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
+    args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False)
+    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
+    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
+    args.num_segment = getattr(args, 'num_segment', 2)
+    args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
+    args.sent_loss = getattr(args, 'sent_loss', False)
+    args.apply_bert_init = getattr(args, 'apply_bert_init', False)
+    args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
+    args.encoder_normalize_before = getattr(
+        args, 'encoder_normalize_before', False)
+    args.gelu = getattr(args, 'gelu', False)


 @register_model_architecture('masked_lm', 'bert_base')
 def base_bert_architecture(args):
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
@@ -270,6 +291,7 @@ def base_bert_architecture(args):
         args, 'encoder_normalize_before', True)
     args.bert_layer_norm = getattr(args, 'bert_layer_norm', True)
     args.gelu = getattr(args, 'gelu', True)
+    base_architecture(args)


 @register_model_architecture('masked_lm', 'xlm_base')
@@ -295,3 +317,4 @@ def xlm_architecture(args):
     args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
     args.gelu = getattr(args, 'gelu', True)
     args.apply_bert_init = getattr(args, 'apply_bert_init', True)
+    base_architecture(args)
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from fairseq import options, utils
 from fairseq.modules import (
     AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
-    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
+    MultiheadAttention, PositionalEmbedding, SinusoidalPositionalEmbedding,
 )
 from . import (
@@ -804,20 +804,6 @@ def Linear(in_features, out_features, bias=True):
     return m


-def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
-    if learned:
-        m = LearnedPositionalEmbedding(
-            num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
-        )
-        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
-        nn.init.constant_(m.weight[padding_idx], 0)
-    else:
-        m = SinusoidalPositionalEmbedding(
-            embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
-        )
-    return m
-

 @register_model_architecture('transformer_lm', 'transformer_lm')
 def base_lm_architecture(args):
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
@@ -23,6 +23,7 @@ from .linearized_convolution import LinearizedConvolution
 from .logsumexp_moe import LogSumExpMoE
 from .mean_pool_gating_network import MeanPoolGatingNetwork
 from .multihead_attention import MultiheadAttention
+from .positional_embedding import PositionalEmbedding
 from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
 from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
@@ -49,6 +50,7 @@ __all__ = [
     'LogSumExpMoE',
     'MeanPoolGatingNetwork',
     'MultiheadAttention',
+    'PositionalEmbedding',
     'ScalarBias',
     'SinusoidalPositionalEmbedding',
     'TransformerSentenceEncoderLayer',
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
if learned:
m = LearnedPositionalEmbedding(
num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
)
return m
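A quick usage sketch of the factory above (assuming the `fairseq.modules.PositionalEmbedding` export added in this commit; the sizes, padding index, and dummy tokens are made up, and forward signatures can differ slightly between fairseq versions): both variants take token ids and return position embeddings of the same shape, so callers can switch between them with the `learned` flag alone.

import torch
from fairseq.modules import PositionalEmbedding

max_positions, embed_dim, padding_idx = 512, 768, 1  # illustrative values

learned = PositionalEmbedding(max_positions, embed_dim, padding_idx, learned=True)
fixed = PositionalEmbedding(max_positions, embed_dim, padding_idx, learned=False)

tokens = torch.randint(2, 100, (4, 16))  # dummy (batch, seq_len) token ids, no padding
print(learned(tokens).shape)  # expected: torch.Size([4, 16, 768])
print(fixed(tokens).shape)    # expected: torch.Size([4, 16, 768])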
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 from typing import Tuple

 from fairseq.modules import (
-    MultiheadAttention, LearnedPositionalEmbedding, TransformerSentenceEncoderLayer
+    MultiheadAttention, PositionalEmbedding, TransformerSentenceEncoderLayer
 )
@@ -39,19 +39,6 @@ def init_bert_params(module):
         module.in_proj_weight.data.normal_(mean=0.0, std=0.02)


-def PositionalEmbedding(
-    num_embeddings: int,
-    embedding_dim: int,
-    padding_idx: int,
-)-> nn.Embedding:
-    m = LearnedPositionalEmbedding(
-        num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
-    )
-    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
-    nn.init.constant_(m.weight[padding_idx], 0)
-    return m
-

 class TransformerSentenceEncoder(nn.Module):
     """
     Implementation for a Bi-directional Transformer based Sentence Encoder used
@@ -94,6 +81,7 @@ class TransformerSentenceEncoder(nn.Module):
         use_bert_layer_norm: bool = False,
         use_gelu: bool = True,
         apply_bert_init: bool = False,
+        learned_pos_embedding: bool = True,
     ) -> None:

         super().__init__()
@@ -105,6 +93,7 @@ class TransformerSentenceEncoder(nn.Module):
         self.num_segments = num_segments
         self.use_position_embeddings = use_position_embeddings
         self.apply_bert_init = apply_bert_init
+        self.learned_pos_embedding = learned_pos_embedding

         self.embed_tokens = nn.Embedding(
             self.vocab_size, self.embedding_dim, self.padding_idx
@@ -121,6 +110,7 @@ class TransformerSentenceEncoder(nn.Module):
                 self.max_seq_len,
                 self.embedding_dim,
                 self.padding_idx,
+                self.learned_pos_embedding,
             )
             if self.use_position_embeddings
             else None
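As a rough illustration of what the new `learned_pos_embedding` switch changes in practice (same assumed factory as above, hypothetical sizes): the learned variant owns a trainable embedding matrix that shows up in checkpoints, while the sinusoidal variant computes its table on the fly and has no trainable parameters.

from fairseq.modules import PositionalEmbedding

learned = PositionalEmbedding(512, 768, padding_idx=1, learned=True)
fixed = PositionalEmbedding(512, 768, padding_idx=1, learned=False)

# Learned: an nn.Embedding of shape ((512 + padding_idx + 1), 768) is trained.
print(sum(p.numel() for p in learned.parameters()))  # 394752
# Sinusoidal: no trainable parameters at all.
print(sum(p.numel() for p in fixed.parameters()))    # 0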