Commit e40363d7 authored by Sai, committed by Myle Ott

Add pretrained embedding support (#151)

parent 48c4c6d3
@@ -30,10 +30,14 @@ class FConvModel(FairseqModel):
                             help='dropout probability')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
+        parser.add_argument('--encoder-embed-path', default=None, type=str, metavar='STR',
+                            help='path to pre-trained encoder embedding')
         parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
                             help='encoder layers [(dim, kernel_size), ...]')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
+        parser.add_argument('--decoder-embed-path', default=None, type=str, metavar='STR',
+                            help='path to pre-trained decoder embedding')
         parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
                             help='decoder layers [(dim, kernel_size), ...]')
         parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
@@ -53,9 +57,21 @@ class FConvModel(FairseqModel):
         args.max_target_positions = args.max_positions
         if not hasattr(args, 'share_input_output_embed'):
             args.share_input_output_embed = False
+        encoder_embed_dict = None
+        if args.encoder_embed_path:
+            encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
+            utils.print_embed_overlap(encoder_embed_dict, src_dict)
+        decoder_embed_dict = None
+        if args.decoder_embed_path:
+            decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
+            utils.print_embed_overlap(decoder_embed_dict, dst_dict)
         encoder = FConvEncoder(
             src_dict,
             embed_dim=args.encoder_embed_dim,
+            embed_dict=encoder_embed_dict,
             convolutions=eval(args.encoder_layers),
             dropout=args.dropout,
             max_positions=args.max_source_positions,
@@ -63,6 +79,7 @@ class FConvModel(FairseqModel):
         decoder = FConvDecoder(
             dst_dict,
             embed_dim=args.decoder_embed_dim,
+            embed_dict=decoder_embed_dict,
             convolutions=eval(args.decoder_layers),
             out_embed_dim=args.decoder_out_embed_dim,
             attention=eval(args.decoder_attention),
@@ -75,8 +92,8 @@ class FConvModel(FairseqModel):
 class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""
-    def __init__(self, dictionary, embed_dim=512, max_positions=1024,
-                 convolutions=((512, 3),) * 20, dropout=0.1):
+    def __init__(self, dictionary, embed_dim=512, embed_dict=None,
+                 max_positions=1024, convolutions=((512, 3),) * 20, dropout=0.1):
         super().__init__(dictionary)
         self.dropout = dropout
         self.num_attention_layers = None
@@ -84,6 +101,9 @@ class FConvEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
         self.embed_positions = PositionalEmbedding(
             max_positions,
             embed_dim,
@@ -197,7 +217,8 @@ class AttentionLayer(nn.Module):
 class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""
-    def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
+    def __init__(self, dictionary, embed_dim=512,
+                 embed_dict=None, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
                  attention=True, dropout=0.1, share_embed=False):
         super().__init__(dictionary)
@@ -215,6 +236,9 @@ class FConvDecoder(FairseqIncrementalDecoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
         self.embed_positions = PositionalEmbedding(
             max_positions,
             embed_dim,
...
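Both FConvEncoder and FConvDecoder gain the same optional embed_dict argument: when it is set, the freshly initialized embedding table has its matching rows overwritten by utils.load_embedding, and when it is left at None nothing changes. Below is a minimal sketch of that hook, assuming fairseq at this commit is importable; ToyEncoder and the plain-list vocabulary are illustrative stand-ins, not fairseq classes.

```python
# Sketch of the embed_dict hook added to the FConv encoder/decoder constructors.
# ToyEncoder and the list-based vocab are hypothetical; utils.load_embedding is
# the helper added to fairseq's utils in this commit.
import torch
import torch.nn as nn
from fairseq import utils

class ToyEncoder(nn.Module):
    def __init__(self, vocab, embed_dim=8, embed_dict=None, padding_idx=0):
        super().__init__()
        self.embed_tokens = nn.Embedding(len(vocab), embed_dim, padding_idx)
        if embed_dict:
            # Overwrite only the rows whose tokens appear in the pretrained file;
            # every other row keeps its random initialization.
            self.embed_tokens = utils.load_embedding(embed_dict, vocab, self.embed_tokens)

vocab = ['<pad>', 'the', 'at', 'dog']               # stand-in for a fairseq Dictionary
pretrained = {'the': torch.randn(8), 'at': torch.randn(8)}
enc = ToyEncoder(vocab, embed_dict=pretrained)      # embed_dict=None keeps old behavior
assert torch.equal(enc.embed_tokens.weight.data[1], pretrained['the'])
```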
@@ -28,10 +28,14 @@ class LSTMModel(FairseqModel):
                             help='dropout probability')
         parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                             help='encoder embedding dimension')
+        parser.add_argument('--encoder-embed-path', default=None, type=str, metavar='STR',
+                            help='path to pre-trained encoder embedding')
         parser.add_argument('--encoder-layers', type=int, metavar='N',
                             help='number of encoder layers')
         parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                             help='decoder embedding dimension')
+        parser.add_argument('--decoder-embed-path', default=None, type=str, metavar='STR',
+                            help='path to pre-trained decoder embedding')
         parser.add_argument('--decoder-layers', type=int, metavar='N',
                             help='number of decoder layers')
         parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
@@ -52,9 +56,21 @@ class LSTMModel(FairseqModel):
     @classmethod
     def build_model(cls, args, src_dict, dst_dict):
         """Build a new model instance."""
+        encoder_embed_dict = None
+        if args.encoder_embed_path:
+            encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
+            utils.print_embed_overlap(encoder_embed_dict, src_dict)
+        decoder_embed_dict = None
+        if args.decoder_embed_path:
+            decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
+            utils.print_embed_overlap(decoder_embed_dict, dst_dict)
         encoder = LSTMEncoder(
             src_dict,
             embed_dim=args.encoder_embed_dim,
+            embed_dict=encoder_embed_dict,
             num_layers=args.encoder_layers,
             dropout_in=args.encoder_dropout_in,
             dropout_out=args.encoder_dropout_out,
@@ -63,6 +79,7 @@ class LSTMModel(FairseqModel):
             dst_dict,
             encoder_embed_dim=args.encoder_embed_dim,
             embed_dim=args.decoder_embed_dim,
+            embed_dict=decoder_embed_dict,
             out_embed_dim=args.decoder_out_embed_dim,
             num_layers=args.decoder_layers,
             attention=bool(eval(args.decoder_attention)),
@@ -74,8 +91,8 @@ class LSTMModel(FairseqModel):
 class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
-    def __init__(self, dictionary, embed_dim=512, num_layers=1, dropout_in=0.1,
-                 dropout_out=0.1):
+    def __init__(self, dictionary, embed_dim=512, embed_dict=None,
+                 num_layers=1, dropout_in=0.1, dropout_out=0.1):
         super().__init__(dictionary)
         self.num_layers = num_layers
         self.dropout_in = dropout_in
@@ -84,6 +101,9 @@ class LSTMEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(
+                embed_dict, self.dictionary, self.embed_tokens)
         self.lstm = LSTM(
             input_size=embed_dim,
@@ -163,7 +183,8 @@ class AttentionLayer(nn.Module):
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
-    def __init__(self, dictionary, encoder_embed_dim=512, embed_dim=512,
+    def __init__(self, dictionary, encoder_embed_dim=512,
+                 embed_dim=512, embed_dict=None,
                  out_embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1, attention=True):
         super().__init__(dictionary)
@@ -173,6 +194,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(
+                embed_dict, self.dictionary, self.embed_tokens)
         self.layers = nn.ModuleList([
             LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
...
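The LSTM model gets the same wiring as the convolutional one: build_model parses each optional embedding path once, prints the vocabulary overlap, and passes the result down as embed_dict. The sketch below is a condensed view of that shared flow; make_encoder_decoder, Encoder, and Decoder are placeholders for the two build_model bodies above, not fairseq APIs.

```python
# Condensed view of the build_model flow shared by the FConv and LSTM models in
# this commit. `args`, `src_dict`, `dst_dict`, `Encoder`, and `Decoder` are
# placeholders; parse_embedding/print_embed_overlap are the new utils helpers.
from fairseq import utils

def make_encoder_decoder(args, src_dict, dst_dict, Encoder, Decoder):
    encoder_embed_dict = None
    if args.encoder_embed_path:                      # set via --encoder-embed-path
        encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
        utils.print_embed_overlap(encoder_embed_dict, src_dict)

    decoder_embed_dict = None
    if args.decoder_embed_path:                      # set via --decoder-embed-path
        decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
        utils.print_embed_overlap(decoder_embed_dict, dst_dict)

    encoder = Encoder(src_dict, embed_dim=args.encoder_embed_dim,
                      embed_dict=encoder_embed_dict)
    decoder = Decoder(dst_dict, embed_dim=args.decoder_embed_dim,
                      embed_dict=decoder_embed_dict)
    return encoder, decoder
```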
@@ -248,6 +248,38 @@ def load_align_dict(replace_unk):
     return align_dict
+
+def print_embed_overlap(embed_dict, vocab_dict):
+    embed_keys = set(embed_dict.keys())
+    vocab_keys = set(vocab_dict.symbols)
+    overlap = len(embed_keys & vocab_keys)
+    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
+
+
+def parse_embedding(embed_path):
+    """Parse embedding text file into a dictionary of word and embedding tensors.
+
+    The first line can have vocabulary size and dimension. The following lines
+    should contain word and embedding separated by spaces.
+
+    Example:
+        2 5
+        the -0.0230 -0.0264 0.0287 0.0171 0.1403
+        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
+    """
+    embed_dict = dict()
+    with open(embed_path) as f_embed:
+        _ = next(f_embed)  # skip header
+        for line in f_embed:
+            pieces = line.strip().split()
+            embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
+    return embed_dict
+
+
+def load_embedding(embed_dict, vocab, embedding):
+    for idx in range(len(vocab)):
+        token = vocab[idx]
+        if token in embed_dict:
+            embedding.weight.data[idx] = embed_dict[token]
+    return embedding
+
+
 def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
     from fairseq import tokenizer
     # Tokens are strings here
...
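The three helpers added here are small enough to exercise end to end. The sketch below writes the embedding file from parse_embedding's docstring to a temporary path and runs the parse/overlap/load steps; the helpers are re-declared locally so the snippet runs without fairseq installed, and it assumes, as the committed code does, that the first line is always a header and gets skipped.

```python
# Self-contained walk-through of the utils helpers added in this commit,
# re-declared locally so the snippet runs without fairseq. The embedding file
# contents come straight from parse_embedding's docstring.
import tempfile
import torch
import torch.nn as nn

def parse_embedding(embed_path):
    embed_dict = dict()
    with open(embed_path) as f_embed:
        next(f_embed)  # the header line ("2 5") is always skipped, as in the commit
        for line in f_embed:
            pieces = line.strip().split()
            embed_dict[pieces[0]] = torch.Tensor([float(w) for w in pieces[1:]])
    return embed_dict

def load_embedding(embed_dict, vocab, embedding):
    for idx in range(len(vocab)):
        if vocab[idx] in embed_dict:
            embedding.weight.data[idx] = embed_dict[vocab[idx]]
    return embedding

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('2 5\n'
            'the -0.0230 -0.0264 0.0287 0.0171 0.1403\n'
            'at -0.0395 -0.1286 0.0275 0.0254 -0.0932\n')
    path = f.name

embed_dict = parse_embedding(path)       # {'the': tensor([...]), 'at': tensor([...])}
vocab = ['<pad>', 'the', 'dog']          # toy stand-in for a fairseq Dictionary
overlap = len(set(embed_dict) & set(vocab))
print('| Found {}/{} types in embedding file.'.format(overlap, len(vocab)))  # 1/3
emb = load_embedding(embed_dict, vocab, nn.Embedding(len(vocab), 5, padding_idx=0))
assert torch.equal(emb.weight.data[1], embed_dict['the'])
```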