Commit 39264559 authored by Peng-Jen Chen, committed by Facebook Github Bot

Make learned positional embedding optional

Summary:
- Add a learned positional embedding binary flag to the masked LM model.
- Add a base arch config for the masked LM model that sets all of the binary parameters to False. Otherwise, some of the binary flag parameters would always be overridden by the config in `xlm_architecture` (e.g. encoder_learned_pos). A short sketch of the getattr-default pattern follows the commit metadata below.

Reviewed By: liezl200

Differential Revision: D15054487

fbshipit-source-id: d78827f352b9160a89c9dc4f45b9fce15a2f234d
parent 34726d56
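A minimal, self-contained sketch (not part of the commit) of the getattr-default pattern that the new base_architecture relies on: each architecture function fills in an attribute only if the parsed args object does not already carry it, so a derived architecture can layer its own defaults on top of base_architecture without clobbering values the user set explicitly. SimpleNamespace stands in for argparse's parsed namespace, and the bert_base defaults shown are illustrative assumptions, not the exact values in the file.

from types import SimpleNamespace

def base_architecture(args):
    # Only fill in a default when the flag was not already set.
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
    args.gelu = getattr(args, 'gelu', False)

def base_bert_architecture(args):
    # Illustrative defaults: the derived arch applies its own values first,
    # then base_architecture fills in whatever is still missing.
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', True)
    args.gelu = getattr(args, 'gelu', True)
    base_architecture(args)

args = SimpleNamespace()                 # nothing set on the command line
base_bert_architecture(args)
print(args.encoder_learned_pos)          # True: the arch default applies

args = SimpleNamespace(encoder_learned_pos=False)  # user disabled it explicitly
base_bert_architecture(args)
print(args.encoder_learned_pos)          # False: the user's value is preserved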
@@ -40,12 +40,12 @@ class MaskedLMModel(BaseFairseqModel):
def add_args(parser):
"""Add model-specific arguments to the parser."""
# Arguments related to dropout
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0.1, type=float,
parser.add_argument('--attention-dropout', type=float,
metavar='D', help='dropout probability for'
' attention weights')
parser.add_argument('--act-dropout', default=0.1, type=float,
parser.add_argument('--act-dropout', type=float,
metavar='D', help='dropout probability after'
' activation in FFN')
@@ -66,17 +66,18 @@ class MaskedLMModel(BaseFairseqModel):
parser.add_argument('--share-encoder-input-output-embed',
action='store_true', help='share encoder input'
' and output embeddings')
parser.add_argument('--encoder-learned-pos', action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--no-token-positional-embeddings',
action='store_true',
help='if set, disables positional embeddings'
' (outside self attention)')
parser.add_argument('--num-segment', type=int, metavar='N', default=2,
parser.add_argument('--num-segment', type=int, metavar='N',
help='num segment in the input')
# Arguments related to sentence level prediction
parser.add_argument('--sentence-class-num', type=int, metavar='N',
default=2, help='number of classes for sentence'
' task')
help='number of classes for sentence task')
parser.add_argument('--sent-loss', action='store_true', help='if set,'
' calculate sentence level predictions')
@@ -93,7 +94,7 @@ class MaskedLMModel(BaseFairseqModel):
help='apply layernorm before each encoder block')
parser.add_argument('--gelu', action='store_true',
help='Use gelu activation function in encoder'
' Layer')
' layer')
def forward(self, tokens, segment_labels):
return self.encoder(tokens, segment_labels)
@@ -131,14 +132,6 @@ class MaskedLMEncoder(FairseqEncoder):
self.vocab_size = dictionary.__len__()
self.max_positions = args.max_positions
use_position_embeddings = (
not getattr(args, 'no_token_positional_embeddings', False)
)
encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
use_bert_layer_norm = getattr(args, 'bert_layer_norm', False)
use_gelu = getattr(args, 'gelu', False)
apply_bert_init = getattr(args, 'apply_bert_init', False)
self.sentence_encoder = TransformerSentenceEncoder(
padding_idx=self.padding_idx,
vocab_size=self.vocab_size,
@@ -151,15 +144,15 @@ class MaskedLMEncoder(FairseqEncoder):
activation_dropout=args.act_dropout,
max_seq_len=self.max_positions,
num_segments=args.num_segment,
use_position_embeddings=use_position_embeddings,
encoder_normalize_before=encoder_normalize_before,
use_bert_layer_norm=use_bert_layer_norm,
use_gelu=use_gelu,
apply_bert_init=apply_bert_init,
use_position_embeddings=not args.no_token_positional_embeddings,
encoder_normalize_before=args.encoder_normalize_before,
use_bert_layer_norm=args.bert_layer_norm,
use_gelu=args.gelu,
apply_bert_init=args.apply_bert_init,
learned_pos_embedding=args.encoder_learned_pos,
)
self.share_input_output_embed = getattr(
args, 'share_encoder_input_output_embed', False)
self.share_input_output_embed = args.share_encoder_input_output_embed
self.embed_out = None
self.sentence_projection_layer = None
self.sentence_out_dim = args.sentence_class_num
@@ -244,6 +237,34 @@ class MaskedLMEncoder(FairseqEncoder):
return state_dict
@register_model_architecture('masked_lm', 'masked_lm')
def base_architecture(args):
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.act_dropout = getattr(args, 'act_dropout', 0.1)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.no_bias_kv = getattr(args, 'no_bias_kv', False)
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.share_encoder_input_output_embed = getattr(args, 'share_encoder_input_output_embed', False)
args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
args.num_segment = getattr(args, 'num_segment', 2)
args.sentence_class_num = getattr(args, 'sentence_class_num', 2)
args.sent_loss = getattr(args, 'sent_loss', False)
args.apply_bert_init = getattr(args, 'apply_bert_init', False)
args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
args.encoder_normalize_before = getattr(
args, 'encoder_normalize_before', False)
args.gelu = getattr(args, 'gelu', False)
@register_model_architecture('masked_lm', 'bert_base')
def base_bert_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 768)
@@ -270,6 +291,7 @@ def base_bert_architecture(args):
args, 'encoder_normalize_before', True)
args.bert_layer_norm = getattr(args, 'bert_layer_norm', True)
args.gelu = getattr(args, 'gelu', True)
base_architecture(args)
@register_model_architecture('masked_lm', 'xlm_base')
@@ -295,3 +317,4 @@ def xlm_architecture(args):
args.bert_layer_norm = getattr(args, 'bert_layer_norm', False)
args.gelu = getattr(args, 'gelu', True)
args.apply_bert_init = getattr(args, 'apply_bert_init', True)
base_architecture(args)
@@ -14,7 +14,7 @@ import torch.nn.functional as F
from fairseq import options, utils
from fairseq.modules import (
AdaptiveInput, AdaptiveSoftmax, CharacterTokenEmbedder, LayerNorm,
LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding,
MultiheadAttention, PositionalEmbedding, SinusoidalPositionalEmbedding,
)
from . import (
@@ -804,20 +804,6 @@ def Linear(in_features, out_features, bias=True):
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
if learned:
m = LearnedPositionalEmbedding(
num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
)
return m
@register_model_architecture('transformer_lm', 'transformer_lm')
def base_lm_architecture(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
@@ -23,6 +23,7 @@ from .linearized_convolution import LinearizedConvolution
from .logsumexp_moe import LogSumExpMoE
from .mean_pool_gating_network import MeanPoolGatingNetwork
from .multihead_attention import MultiheadAttention
from .positional_embedding import PositionalEmbedding
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
from .transformer_sentence_encoder_layer import TransformerSentenceEncoderLayer
@@ -49,6 +50,7 @@ __all__ = [
'LogSumExpMoE',
'MeanPoolGatingNetwork',
'MultiheadAttention',
'PositionalEmbedding',
'ScalarBias',
'SinusoidalPositionalEmbedding',
'TransformerSentenceEncoderLayer',
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from .learned_positional_embedding import LearnedPositionalEmbedding
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, learned=False):
if learned:
m = LearnedPositionalEmbedding(
num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(
embedding_dim, padding_idx, init_size=num_embeddings + padding_idx + 1,
)
return m
@@ -11,7 +11,7 @@ import torch.nn.functional as F
from typing import Tuple
from fairseq.modules import (
MultiheadAttention, LearnedPositionalEmbedding, TransformerSentenceEncoderLayer
MultiheadAttention, PositionalEmbedding, TransformerSentenceEncoderLayer
)
@@ -39,19 +39,6 @@ def init_bert_params(module):
module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
def PositionalEmbedding(
num_embeddings: int,
embedding_dim: int,
padding_idx: int,
)-> nn.Embedding:
m = LearnedPositionalEmbedding(
num_embeddings + padding_idx + 1, embedding_dim, padding_idx,
)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
class TransformerSentenceEncoder(nn.Module):
"""
Implementation for a Bi-directional Transformer based Sentence Encoder used
@@ -94,6 +81,7 @@ class TransformerSentenceEncoder(nn.Module):
use_bert_layer_norm: bool = False,
use_gelu: bool = True,
apply_bert_init: bool = False,
learned_pos_embedding: bool = True,
) -> None:
super().__init__()
@@ -105,6 +93,7 @@ class TransformerSentenceEncoder(nn.Module):
self.num_segments = num_segments
self.use_position_embeddings = use_position_embeddings
self.apply_bert_init = apply_bert_init
self.learned_pos_embedding = learned_pos_embedding
self.embed_tokens = nn.Embedding(
self.vocab_size, self.embedding_dim, self.padding_idx
@@ -121,6 +110,7 @@ class TransformerSentenceEncoder(nn.Module):
self.max_seq_len,
self.embedding_dim,
self.padding_idx,
self.learned_pos_embedding,
)
if self.use_position_embeddings
else None