"...modules/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "dbf06b504b525c7f6680c5709b63df6413616d2e"
Commit 9a88b71d authored by Stephen Roller, committed by Myle Ott

Support pretrained embeddings for Transformer.

Also show a nicer error message.
parent 5bf07724
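For reference, the new --encoder-embed-path / --decoder-embed-path options point at a plain-text embedding file that utils.parse_embedding reads. The sketch below writes a toy file in the word2vec-style text format this is assumed to be (a "num_words dim" header line, then one "word v1 v2 ... vD" row per token); the file name, dimension, and values are illustrative only.

# Sketch: write a toy pre-trained embedding file in the text format assumed
# to be read by utils.parse_embedding (header line, then word + vector rows).
import torch

vectors = {
    'hello': torch.tensor([0.1, 0.2, 0.3, 0.4]),
    'world': torch.tensor([0.5, 0.6, 0.7, 0.8]),
}

with open('toy.vec', 'w') as f:                      # 'toy.vec' is an example path
    f.write('{} {}\n'.format(len(vectors), 4))       # header: vocab size, dimension
    for word, vec in vectors.items():
        f.write(word + ' ' + ' '.join('{:.6f}'.format(x) for x in vec.tolist()) + '\n')

# The resulting file can then be supplied at training time with the flags
# introduced in this commit, e.g. --encoder-embed-path toy.vec.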
@@ -11,6 +11,8 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from fairseq import utils
 from fairseq.modules import (
     LearnedPositionalEmbedding, MultiheadAttention,
     SinusoidalPositionalEmbedding,
@@ -36,6 +38,8 @@ class TransformerModel(FairseqModel):
                             help='dropout probability for attention weights')
         parser.add_argument('--relu-dropout', type=float, metavar='D',
                             help='dropout probability after ReLU in FFN')
+        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
+                            help='path to pre-trained encoder embedding')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
         parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
@@ -48,6 +52,8 @@ class TransformerModel(FairseqModel):
                             help='apply layernorm before each encoder block')
         parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
                             help='use learned positional embeddings in the encoder')
+        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
+                            help='path to pre-trained decoder embedding')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
         parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
@@ -69,12 +75,20 @@ class TransformerModel(FairseqModel):
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
+        # make sure that all args are properly defaulted (in case there are any new ones)
+        base_architecture(args)
+
         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

-        def build_embedding(dictionary, embed_dim):
+        def build_embedding(dictionary, embed_dim, path=None):
             num_embeddings = len(dictionary)
             padding_idx = dictionary.pad()
-            return Embedding(num_embeddings, embed_dim, padding_idx)
+            emb = Embedding(num_embeddings, embed_dim, padding_idx)
+            # if provided, load from preloaded dictionaries
+            if path:
+                embed_dict = utils.parse_embedding(path)
+                utils.load_embedding(embed_dict, dictionary, emb)
+            return emb

         if args.share_all_embeddings:
             if src_dict != tgt_dict:
@@ -82,12 +96,21 @@ class TransformerModel(FairseqModel):
             if args.encoder_embed_dim != args.decoder_embed_dim:
                 raise RuntimeError(
                     '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
-            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
+            if args.decoder_embed_path and (
+                    args.decoder_embed_path != args.encoder_embed_path):
+                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
+            encoder_embed_tokens = build_embedding(
+                src_dict, args.encoder_embed_dim, args.encoder_embed_path
+            )
             decoder_embed_tokens = encoder_embed_tokens
             args.share_decoder_input_output_embed = True
         else:
-            encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
-            decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
+            encoder_embed_tokens = build_embedding(
+                src_dict, args.encoder_embed_dim, args.encoder_embed_path
+            )
+            decoder_embed_tokens = build_embedding(
+                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
+            )

         encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
         decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
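To make the new loading branch of build_embedding concrete, here is a minimal standalone sketch that exercises the same utils.parse_embedding / utils.load_embedding calls against a toy dictionary. Only those two calls are taken from the diff; fairseq.data.Dictionary, the plain nn.Embedding in place of the model's Embedding helper, the 'toy.vec' file from the sketch above, and the 4-dimensional vectors are assumptions for illustration.

# Minimal sketch of what the new `path` branch of build_embedding does,
# assuming fairseq is installed and 'toy.vec' exists (see the earlier sketch).
import torch.nn as nn
from fairseq import utils
from fairseq.data import Dictionary   # assumed location of the Dictionary class

dictionary = Dictionary()
for word in ['hello', 'world']:
    dictionary.add_symbol(word)

num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
emb = nn.Embedding(num_embeddings, 4, padding_idx=padding_idx)

# Parse the text file into {word: tensor}, then copy the vectors for words
# that actually appear in the dictionary into the freshly initialized embedding.
embed_dict = utils.parse_embedding('toy.vec')
utils.load_embedding(embed_dict, dictionary, emb)

print(emb.weight[dictionary.index('hello')])  # this row now matches the file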
@@ -391,10 +414,12 @@ def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, le

 @register_model_architecture('transformer', 'transformer')
 def base_architecture(args):
+    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
     args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
     args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
     args.encoder_layers = getattr(args, 'encoder_layers', 6)
     args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
+    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
     args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
     args.decoder_layers = getattr(args, 'decoder_layers', 6)
......
@@ -24,7 +24,7 @@ class MultiheadAttention(nn.Module):
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
-        assert self.head_dim * num_heads == self.embed_dim
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
         self.scaling = self.head_dim**-0.5
         self._mask = None
......
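The MultiheadAttention hunk only attaches a message to the existing divisibility assert, so a bad head-count configuration now fails with an explanation rather than a bare AssertionError. A standalone illustration of the same check, with values chosen purely for the example:

# embed_dim = 512 is not divisible by num_heads = 7, so the assert fires
# with the new, more helpful message.
embed_dim, num_heads = 512, 7
head_dim = embed_dim // num_heads
assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
# -> AssertionError: embed_dim must be divisible by num_heads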