Commit 4973d05a authored by Myle Ott

Flake8

parent e40363d7
@@ -119,9 +119,9 @@ class FConvEncoder(FairseqEncoder):
             self.projections.append(Linear(in_channels, out_channels)
                                     if in_channels != out_channels else None)
             if kernel_size % 2 == 1:
-                padding = kernel_size //2
+                padding = kernel_size // 2
             else:
                 padding = 0
             self.convolutions.append(
                 ConvTBC(in_channels, out_channels * 2, kernel_size,
                         dropout=dropout, padding=padding)
@@ -148,7 +148,7 @@ class FConvEncoder(FairseqEncoder):
             if conv.kernel_size[0] % 2 == 1:
                 # padding is implicit in the conv
                 x = conv(x)
             else:
                 padding_l = (conv.kernel_size[0] - 1) // 2
                 padding_r = conv.kernel_size[0] // 2
                 x = F.pad(x, (0, 0, 0, 0, padding_l, padding_r))
...
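The two hunks above keep the convolution length-preserving: an odd kernel gets symmetric padding of kernel_size // 2 inside the convolution itself, while an even kernel is padded explicitly and asymmetrically on the time dimension (one fewer element on the left) before the convolution. Below is a minimal sketch, not fairseq code, that checks this arithmetic; it uses a plain nn.Conv1d on a (B, C, T) tensor as a stand-in for fairseq's ConvTBC (which is why the F.pad call pads only the last dimension here, instead of the (0, 0, 0, 0, padding_l, padding_r) form used above for a time-first tensor), and the sizes T, B, C are hypothetical.

# Sketch: verify that the odd/even padding choice preserves the time length T.
import torch
import torch.nn.functional as F
from torch import nn

T, B, C = 7, 2, 4
for kernel_size in (3, 4):                          # one odd, one even kernel
    x = torch.randn(B, C, T)
    if kernel_size % 2 == 1:
        # odd kernel: symmetric padding can be done by the convolution itself
        conv = nn.Conv1d(C, C, kernel_size, padding=kernel_size // 2)
        y = conv(x)
    else:
        # even kernel: pad asymmetrically, one fewer element on the left
        padding_l = (kernel_size - 1) // 2
        padding_r = kernel_size // 2
        conv = nn.Conv1d(C, C, kernel_size, padding=0)
        y = conv(F.pad(x, (padding_l, padding_r)))  # pad the time (last) dim
    assert y.shape[-1] == T, (kernel_size, y.shape)
print("output length matches input length for odd and even kernels")

In both branches the total padding is kernel_size - 1, so with stride 1 the output length is T + (kernel_size - 1) - kernel_size + 1 = T.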
@@ -92,7 +92,7 @@ class LSTMModel(FairseqModel):
 class LSTMEncoder(FairseqEncoder):
     """LSTM encoder."""
     def __init__(self, dictionary, embed_dim=512, embed_dict=None,
                  num_layers=1, dropout_in=0.1, dropout_out=0.1):
         super().__init__(dictionary)
         self.num_layers = num_layers
         self.dropout_in = dropout_in
@@ -102,8 +102,7 @@ class LSTMEncoder(FairseqEncoder):
         self.padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
         if embed_dict:
-            self.embed_tokens = utils.load_embedding(
-                embed_dict, self.dictionary, self.embed_tokens)
+            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
         self.lstm = LSTM(
             input_size=embed_dim,
@@ -183,7 +182,7 @@ class AttentionLayer(nn.Module):
 class LSTMDecoder(FairseqIncrementalDecoder):
     """LSTM decoder."""
     def __init__(self, dictionary, encoder_embed_dim=512,
                  embed_dim=512, embed_dict=None,
                  out_embed_dim=512, num_layers=1, dropout_in=0.1,
                  dropout_out=0.1, attention=True):
@@ -195,9 +194,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         padding_idx = dictionary.pad()
         self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
         if embed_dict:
-            self.embed_tokens = utils.load_embedding(
-                embed_dict, self.dictionary, self.embed_tokens)
+            self.embed_tokens = utils.load_embedding(embed_dict, self.dictionary, self.embed_tokens)
         self.layers = nn.ModuleList([
             LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
...
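The ModuleList at the end of the decoder hunk builds a stack of LSTMCells in which only layer 0 has input size encoder_embed_dim + embed_dim: at each time step the previous target-token embedding is concatenated with a context vector of encoder dimension before entering the bottom layer, and every later layer consumes the hidden state of the layer below. The sketch below is an assumed illustration of stepping such a stack for one time step, not the fairseq decoder itself; the batch size, the zero initial states, and the random "context" tensor are hypothetical stand-ins (in the real decoder the context comes from attention over encoder outputs and dropout is applied between layers).

# Sketch: one decoding step through a stack of LSTMCells shaped like the above.
import torch
from torch import nn

encoder_embed_dim, embed_dim, num_layers, batch = 512, 512, 2, 3

layers = nn.ModuleList([
    nn.LSTMCell(encoder_embed_dim + embed_dim if layer == 0 else embed_dim, embed_dim)
    for layer in range(num_layers)
])

token_emb = torch.randn(batch, embed_dim)        # embedding of previous target token
context = torch.randn(batch, encoder_embed_dim)  # hypothetical encoder/attention context
hiddens = [torch.zeros(batch, embed_dim) for _ in range(num_layers)]
cells = [torch.zeros(batch, embed_dim) for _ in range(num_layers)]

inp = torch.cat([token_emb, context], dim=1)     # layer 0 sees both, hence the larger input size
for i, rnn in enumerate(layers):
    hiddens[i], cells[i] = rnn(inp, (hiddens[i], cells[i]))
    inp = hiddens[i]                             # feed this layer's output to the layer above
print(inp.shape)                                 # torch.Size([3, 512])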
@@ -249,10 +249,11 @@ def load_align_dict(replace_unk):
 def print_embed_overlap(embed_dict, vocab_dict):
     embed_keys = set(embed_dict.keys())
     vocab_keys = set(vocab_dict.symbols)
     overlap = len(embed_keys & vocab_keys)
     print("| Found {}/{} types in embedding file.".format(overlap, len(vocab_dict)))
+
 def parse_embedding(embed_path):
     """Parse embedding text file into a dictionary of word and embedding tensors.
@@ -267,12 +268,13 @@ def parse_embedding(embed_path):
     """
     embed_dict = dict()
     with open(embed_path) as f_embed:
-        _ = next(f_embed) #skip header
+        _ = next(f_embed)  # skip header
         for line in f_embed:
             pieces = line.strip().split()
             embed_dict[pieces[0]] = torch.Tensor([float(weight) for weight in pieces[1:]])
     return embed_dict
+
 def load_embedding(embed_dict, vocab, embedding):
     for idx in range(len(vocab)):
         token = vocab[idx]
@@ -280,6 +282,7 @@ def load_embedding(embed_dict, vocab, embedding):
             embedding.weight.data[idx] = embed_dict[token]
     return embedding
+
 def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
     from fairseq import tokenizer
     # Tokens are strings here
...
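parse_embedding skips one header line and then reads one "<word> <v1> ... <vd>" line per word (the word2vec-style text format), while load_embedding overwrites the rows of an nn.Embedding whose vocabulary tokens appear in that dictionary. The self-contained sketch below mirrors that logic without calling fairseq; the tiny embedding text, the 4-dimensional vectors, and the vocab list are all hypothetical examples.

# Sketch: the file format parse_embedding expects and what load_embedding does with it.
import torch
from torch import nn

embed_text = """2 4
hello 0.1 0.2 0.3 0.4
world 0.5 0.6 0.7 0.8
"""

# parse: skip the header line, map word -> tensor of floats
lines = embed_text.strip().split("\n")[1:]
embed_dict = {
    pieces[0]: torch.Tensor([float(w) for w in pieces[1:]])
    for pieces in (line.split() for line in lines)
}

# load: copy pretrained vectors into the embedding rows of known words
vocab = ["<pad>", "hello", "world", "unseen"]        # hypothetical vocabulary
embedding = nn.Embedding(len(vocab), 4, padding_idx=0)
for idx, token in enumerate(vocab):
    if token in embed_dict:                          # words missing from the file keep their init
        embedding.weight.data[idx] = embed_dict[token]

print(embedding.weight.data[1])  # tensor([0.1000, 0.2000, 0.3000, 0.4000])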