"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "f0ac7aaafcbafbecffd4f7c5a34213f9f9528db0"
Commit 90b2d8de, authored by Zrachel and committed by Myle Ott
Browse files

fix bug in lstm model (#68)

parent 884e3046
@@ -122,6 +122,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
             for layer in range(num_layers)
         ])
         self.attention = AttentionLayer(encoder_embed_dim, embed_dim)
+        if embed_dim != out_embed_dim:
+            self.additional_fc = Linear(embed_dim, out_embed_dim)
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, input_tokens, encoder_out):
@@ -197,6 +199,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         attn_scores = attn_scores.transpose(0, 2)
 
         # project back to size of vocabulary
+        if hasattr(self, 'additional_fc'):
+            x = self.additional_fc(x)
+            x = F.dropout(x, p=self.dropout_out, training=self.training)
         x = self.fc_out(x)
 
         return x, attn_scores
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment