"src/diffusers/models/consistency_decoder_vae.py" did not exist on "ad8f985e81d0e0bac0375cd81b909987e8b2a7f9"
Commit 6f96ad78 authored by Sai, committed by Myle Ott

Add pretrained embedding support

parent 8300a521
......@@ -60,10 +60,6 @@ class FConvModel(FairseqModel):
         args.max_target_positions = args.max_positions
         if not hasattr(args, 'share_input_output_embed'):
             args.share_input_output_embed = False
-        if not hasattr(args, 'encoder_embed_path'):
-            args.encoder_embed_path = None
-        if not hasattr(args, 'decoder_embed_path'):
-            args.decoder_embed_path = None
         encoder_embed_dict = None
         if args.encoder_embed_path:
......@@ -108,6 +104,9 @@ class FConvEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
         self.embed_positions = PositionalEmbedding(
             max_positions,
             embed_dim,
......@@ -161,7 +160,7 @@ class FConvEncoder(FairseqEncoder):
             if conv.kernel_size[0] % 2 == 1:
                 # padding is implicit in the conv
                 x = conv(x)
-            else:
+            else:
                 padding_l = (conv.kernel_size[0] - 1) // 2
                 padding_r = conv.kernel_size[0] // 2
                 x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
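For reference, the even-kernel branch above keeps the output length equal to the input length by padding the time axis asymmetrically. A minimal standalone sketch of that arithmetic (toy shapes, not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(10, 2, 4)  # fconv uses a (time, batch, channels) layout here
kernel_size = 4            # even kernel, so padding cannot be implicit in the conv

padding_l = (kernel_size - 1) // 2  # 1
padding_r = kernel_size // 2        # 2
# F.pad consumes the pad tuple from the last dimension backwards, so the
# trailing (padding_l, padding_r) pair applies to the first (time) dimension.
x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
print(x.shape)  # torch.Size([13, 2, 4]); a kernel-4 conv then yields 13 - 3 = 10 steps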
......
......@@ -30,8 +30,6 @@ class LSTMModel(FairseqModel):
                             help='encoder embedding dimension')
         parser.add_argument('--encoder-embed-path', default=None, type=str, metavar='STR',
                             help='path to pre-trained encoder embedding')
-        parser.add_argument('--encoder-hidden-size', type=int, metavar='N',
-                            help='encoder hidden size')
         parser.add_argument('--encoder-layers', type=int, metavar='N',
                             help='number of encoder layers')
         parser.add_argument('--encoder-bidirectional', action='store_true',
......@@ -40,8 +38,6 @@ class LSTMModel(FairseqModel):
                             help='decoder embedding dimension')
         parser.add_argument('--decoder-embed-path', default=None, type=str, metavar='STR',
                             help='path to pre-trained decoder embedding')
-        parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
-                            help='decoder hidden size')
         parser.add_argument('--decoder-layers', type=int, metavar='N',
                             help='number of decoder layers')
         parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
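Together, the --encoder-embed-path and --decoder-embed-path flags point the model at word-vector text files that are parsed at build time. A hypothetical invocation (script name, dataset, and file names are illustrative, not from this commit):

python train.py data-bin/iwslt14.de-en --arch lstm \
  --encoder-embed-dim 300 --encoder-embed-path glove.src.300d.txt \
  --decoder-embed-dim 300 --decoder-embed-path glove.tgt.300d.txt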
......@@ -65,38 +61,21 @@ class LSTMModel(FairseqModel):
-        base_architecture(args)
         """Build a new model instance."""
-        if not hasattr(args, 'encoder_embed_path'):
-            args.encoder_embed_path = None
-        if not hasattr(args, 'decoder_embed_path'):
-            args.decoder_embed_path = None
-        if not hasattr(args, 'encoder_hidden_size'):
-            args.encoder_hidden_size = args.encoder_embed_dim
-        if not hasattr(args, 'decoder_hidden_size'):
-            args.decoder_hidden_size = args.decoder_embed_dim
-        if not hasattr(args, 'encoder_bidirectional'):
-            args.encoder_bidirectional = False
-        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
-            num_embeddings = len(dictionary)
-            padding_idx = dictionary.pad()
-            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-            embed_dict = utils.parse_embedding(embed_path)
-            utils.print_embed_overlap(embed_dict, dictionary)
-            return utils.load_embedding(embed_dict, dictionary, embed_tokens)
-        pretrained_encoder_embed = None
+        encoder_embed_dict = None
         if args.encoder_embed_path:
-            pretrained_encoder_embed = load_pretrained_embedding_from_file(
-                args.encoder_embed_path, src_dict, args.encoder_embed_dim)
-        pretrained_decoder_embed = None
+            encoder_embed_dict = utils.parse_embedding(args.encoder_embed_path)
+            utils.print_embed_overlap(encoder_embed_dict, src_dict)
+        decoder_embed_dict = None
         if args.decoder_embed_path:
-            pretrained_decoder_embed = load_pretrained_embedding_from_file(
-                args.decoder_embed_path, dst_dict, args.decoder_embed_dim)
+            decoder_embed_dict = utils.parse_embedding(args.decoder_embed_path)
+            utils.print_embed_overlap(decoder_embed_dict, dst_dict)
         encoder = LSTMEncoder(
             dictionary=src_dict,
             embed_dim=args.encoder_embed_dim,
-            hidden_size=args.encoder_hidden_size,
+            embed_dict=encoder_embed_dict,
             num_layers=args.encoder_layers,
             dropout_in=args.encoder_dropout_in,
             dropout_out=args.encoder_dropout_out,
......@@ -110,7 +89,7 @@ class LSTMModel(FairseqModel):
         decoder = LSTMDecoder(
             dictionary=dst_dict,
             embed_dim=args.decoder_embed_dim,
-            hidden_size=args.decoder_hidden_size,
+            embed_dict=decoder_embed_dict,
             out_embed_dim=args.decoder_out_embed_dim,
             num_layers=args.decoder_layers,
             dropout_in=args.decoder_dropout_in,
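Both constructor calls follow the same pattern: parse the embedding file once in build_model, report coverage, then hand the resulting dict to the encoder/decoder, which overwrite rows of a freshly built Embedding. Collapsed into a single helper, the flow looks roughly like this (the helper name is hypothetical, not part of this commit):

def build_pretrained_embedding(embed_path, dictionary, embed_dim):
    # Hypothetical convenience wrapper around the steps shown in build_model.
    embed_tokens = Embedding(len(dictionary), embed_dim, dictionary.pad())
    if embed_path:
        embed_dict = utils.parse_embedding(embed_path)      # token -> 1-D tensor
        utils.print_embed_overlap(embed_dict, dictionary)   # coverage report
        embed_tokens = utils.load_embedding(embed_dict, dictionary, embed_tokens)
    return embed_tokens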
......@@ -125,13 +104,8 @@ class LSTMModel(FairseqModel):
 class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
-    def __init__(
-        self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
-        dropout_in=0.1, dropout_out=0.1, bidirectional=False,
-        left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
-        pretrained_embed=None,
-        padding_value=0.,
-    ):
+    def __init__(self, dictionary, embed_dim=512, embed_dict=None,
+                 num_layers=1, dropout_in=0.1, dropout_out=0.1):
         super().__init__(dictionary)
         self.num_layers = num_layers
         self.dropout_in = dropout_in
......@@ -141,10 +115,10 @@ class LSTMEncoder(FairseqEncoder):
         num_embeddings = len(dictionary)
         self.padding_idx = dictionary.pad()
-        if pretrained_embed is None:
-            self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
-        else:
-            self.embed_tokens = pretrained_embed
+        self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(
+                embed_dict, self.dictionary, self.embed_tokens)
         self.lstm = LSTM(
             input_size=embed_dim,
......@@ -259,12 +233,10 @@ class AttentionLayer(nn.Module):
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
-    def __init__(
-        self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
-        num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
-        encoder_embed_dim=512, encoder_output_units=512,
-        pretrained_embed=None,
-    ):
+    def __init__(self, dictionary, encoder_embed_dim=512,
+                 embed_dim=512, embed_dict=None,
+                 out_embed_dim=512, num_layers=1, dropout_in=0.1,
+                 dropout_out=0.1, attention=True):
         super().__init__(dictionary)
         self.dropout_in = dropout_in
         self.dropout_out = dropout_out
......@@ -272,15 +244,11 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         num_embeddings = len(dictionary)
         padding_idx = dictionary.pad()
-        if pretrained_embed is None:
-            self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
-        else:
-            self.embed_tokens = pretrained_embed
+        self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
+        if embed_dict:
+            self.embed_tokens = utils.load_embedding(
+                embed_dict, self.dictionary, self.embed_tokens)
-        self.encoder_output_units = encoder_output_units
-        assert encoder_output_units == hidden_size, \
-            '{} {}'.format(encoder_output_units, hidden_size)
-        # TODO another Linear layer if not equal
         self.layers = nn.ModuleList([
             LSTMCell(
......
......@@ -258,11 +258,10 @@ def load_align_dict(replace_unk):
 def print_embed_overlap(embed_dict, vocab_dict):
-    embed_keys = set(embed_dict.keys())
-    vocab_keys = set(vocab_dict.symbols)
-    overlap = len(embed_keys & vocab_keys)
-    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
+    embed_keys = set(embed_dict.keys())
+    vocab_keys = set(vocab_dict.symbols)
+    overlap = len(embed_keys & vocab_keys)
+    print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
 def parse_embedding(embed_path):
     """Parse embedding text file into a dictionary of word and embedding tensors.
......@@ -275,15 +274,14 @@ def parse_embedding(embed_path):
     the -0.0230 -0.0264 0.0287 0.0171 0.1403
     at -0.0395 -0.1286 0.0275 0.0254 -0.0932
     """
-    embed_dict = {}
+    embed_dict = dict()
     with open(embed_path) as f_embed:
-        _ = next(f_embed) # skip header
+        _ = next(f_embed) #skip header
         for line in f_embed:
             pieces = line.strip().split()
             embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
     return embed_dict
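Given the format in the docstring (a header line, then one token followed by its weights per line), a small round trip might look like this (toy file, illustrative only):

with open('toy.vec', 'w') as f:
    f.write('2 5\n')  # header line: vocab size and dimension, skipped on read
    f.write('the -0.0230 -0.0264 0.0287 0.0171 0.1403\n')
    f.write('at -0.0395 -0.1286 0.0275 0.0254 -0.0932\n')

embed_dict = parse_embedding('toy.vec')
print(sorted(embed_dict.keys()))  # ['at', 'the']
print(embed_dict['the'].shape)    # torch.Size([5])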
 def load_embedding(embed_dict, vocab, embedding):
     for idx in range(len(vocab)):
         token = vocab[idx]
......@@ -291,7 +289,6 @@ def load_embedding(embed_dict, vocab, embedding):
             embedding.weight.data[idx] = embed_dict[token]
     return embedding
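load_embedding mutates the Embedding in place: each vocabulary index whose token appears in embed_dict has its weight row overwritten, while the remaining rows keep their random initialization. A minimal sketch (the list-based vocab is a toy stand-in for fairseq's Dictionary, which also supports len() and integer indexing):

import torch.nn as nn

vocab = ['<pad>', 'the', 'hello']  # toy stand-in, not a real fairseq Dictionary
embedding = nn.Embedding(len(vocab), 5, padding_idx=0)
embedding = load_embedding(embed_dict, vocab, embedding)
# Row 1 ('the') now holds the pretrained vector; row 2 ('hello') is not in
# embed_dict, so it keeps its random initialization.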
 def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
     from fairseq import tokenizer
     # Tokens are strings here
......