Commit 9a88b71d authored by Stephen Roller, committed by Myle Ott

Support pretrained embeddings for Transformer.

Also show a nicer error message.
parent 5bf07724
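
As a rough usage sketch, the two new flags are passed at training time. The data directory and embedding file paths below are hypothetical, and the invocation assumes the repository's train.py entry point:

python train.py data-bin/my-corpus \
    --arch transformer \
    --encoder-embed-path embeddings/encoder.vec \
    --decoder-embed-path embeddings/decoder.vec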
@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.modules import (
    LearnedPositionalEmbedding, MultiheadAttention,
    SinusoidalPositionalEmbedding,
@@ -36,6 +38,8 @@ class TransformerModel(FairseqModel):
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
@@ -48,6 +52,8 @@ class TransformerModel(FairseqModel):
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
@@ -69,12 +75,20 @@ class TransformerModel(FairseqModel):
    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)
        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        def build_embedding(dictionary, embed_dim):
        def build_embedding(dictionary, embed_dim, path=None):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            return Embedding(num_embeddings, embed_dim, padding_idx)
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb
        if args.share_all_embeddings:
            if src_dict != tgt_dict:
@@ -82,12 +96,21 @@ class TransformerModel(FairseqModel):
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
            decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )
        encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
        decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
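
For readers who want to see what the new path handling boils down to, here is a minimal, self-contained sketch: parse a whitespace-separated text file of word vectors and copy the matching rows into an nn.Embedding, leaving words not present in the file at their random initialization. The file format, helper names, and toy vocabulary below are assumptions made for this illustration, not the exact behaviour of fairseq's utils.parse_embedding and utils.load_embedding.

import torch
import torch.nn as nn

def parse_text_embedding(path):
    """Read 'word v1 v2 ... vd' lines into a dict of word -> 1-D tensor (assumed format)."""
    embed_dict = {}
    with open(path) as f:
        for line in f:
            pieces = line.rstrip().split(' ')
            embed_dict[pieces[0]] = torch.tensor([float(x) for x in pieces[1:]])
    return embed_dict

def load_into_embedding(embed_dict, vocab, emb):
    """Copy pretrained vectors into emb for every vocab word that appears in embed_dict."""
    for idx, word in enumerate(vocab):
        if word in embed_dict:
            emb.weight.data[idx] = embed_dict[word]
    return emb

# toy usage (vocabulary and vectors are hypothetical)
vocab = ['<pad>', 'hello', 'world']
emb = nn.Embedding(len(vocab), 4, padding_idx=0)
embed_dict = {'hello': torch.ones(4)}   # stand-in for parse_text_embedding('encoder.vec')
load_into_embedding(embed_dict, vocab, emb)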
@@ -391,10 +414,12 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, le
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
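
The getattr defaulting above is also what makes the new base_architecture(args) call in build_model safe for argument namespaces created before these flags existed; a small sketch of the pattern (the namespace is hypothetical):

import argparse

# an older args namespace that predates --encoder-embed-path (hypothetical)
args = argparse.Namespace(encoder_embed_dim=512)

# base_architecture-style defaulting: fill the attribute only if it is missing
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)

print(args.encoder_embed_path)  # None, so no pretrained file is loaded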
@@ -24,7 +24,7 @@ class MultiheadAttention(nn.Module):
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        self._mask = None
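
The nicer assertion message fires whenever embed_dim is not a multiple of num_heads, because integer division silently rounds head_dim down; a small check illustrating why the product test catches it:

def heads_divide_evenly(embed_dim, num_heads):
    head_dim = embed_dim // num_heads
    return head_dim * num_heads == embed_dim

print(heads_divide_evenly(512, 8))  # True: head_dim = 64 and 64 * 8 == 512
print(heads_divide_evenly(500, 8))  # False: head_dim rounds down to 62 and 62 * 8 == 496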