Commit 90b2d8de authored by Zrachel, committed by Myle Ott

fix bug in lstm model (#68)

parent 884e3046
@@ -122,6 +122,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             for layer in range(num_layers)
         ])
         self.attention = AttentionLayer(encoder_embed_dim, embed_dim)
+        if embed_dim != out_embed_dim:
+            self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, input_tokens, encoder_out):
@@ -197,6 +199,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         attn_scores = attn_scores.transpose(0, 2)
 
         # project back to size of vocabulary
+        if hasattr(self, 'additional_fc'):
+            x = self.additional_fc(x)
+            x = F.dropout(x, p=self.dropout_out, training=self.training)
         x = self.fc_out(x)
 
         return x, attn_scores
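For context on the fix: when the decoder's hidden size (`embed_dim`) differs from the output embedding size (`out_embed_dim`), `fc_out` would otherwise receive activations of the wrong width. The patch creates `additional_fc` to bridge the two sizes and applies it, with dropout, just before the vocabulary projection. Below is a minimal standalone sketch of the patched output path, not fairseq's actual module: the name `OutputProjection` and the dimensions are made up for illustration, and plain `nn.Linear` stands in for fairseq's `Linear` helper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OutputProjection(nn.Module):
    """Sketch of the decoder output path after this commit (illustrative only)."""

    def __init__(self, embed_dim, out_embed_dim, num_embeddings, dropout_out=0.1):
        super().__init__()
        self.dropout_out = dropout_out
        if embed_dim != out_embed_dim:
            # the layer this commit adds: bridge embed_dim -> out_embed_dim
            self.additional_fc = nn.Linear(embed_dim, out_embed_dim)
        # final projection to vocabulary size
        self.fc_out = nn.Linear(out_embed_dim, num_embeddings)

    def forward(self, x):
        # mirror the hasattr() guard in the patched forward()
        if hasattr(self, 'additional_fc'):
            x = self.additional_fc(x)
            x = F.dropout(x, p=self.dropout_out, training=self.training)
        return self.fc_out(x)


# toy usage: hidden size 512, output embedding size 256, vocab of 1000
proj = OutputProjection(embed_dim=512, out_embed_dim=256, num_embeddings=1000)
logits = proj(torch.randn(8, 512))  # -> shape (8, 1000)
```

Note that the `hasattr()` guard means the extra layer exists only when the sizes differ, presumably so that configurations with matching sizes carry no extra parameters.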