"vscode:/vscode.git/clone" did not exist on "a0193fd53987d384de61aa3e9a6a34c75caf3fe2"
Commit 7b086021 authored by Myle Ott

Make LSTM backwards compatible and fix incremental generation

parent 63dc27e8
@@ -94,7 +94,7 @@ class AttentionLayer(nn.Module):
         # compute attention
         attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
-        attn_scores = F.softmax(attn_scores.t(), dim=1).t()  # srclen x bsz
+        attn_scores = F.softmax(attn_scores.t()).t()  # srclen x bsz
         # sum weighted sources
         x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
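
For context on the hunk above: the transpose-softmax-transpose pattern normalizes the scores over the source length (dim 0 of a srclen x bsz tensor) rather than over the batch. Older PyTorch's F.softmax took no dim argument and defaulted to dim 1 for 2-D inputs, so dropping the explicit dim=1, presumably for compatibility with those older releases, leaves the result unchanged. A minimal sketch of the equivalence on a modern PyTorch (the shapes here are illustrative, not from the commit):

import torch
import torch.nn.functional as F

srclen, bsz = 5, 3
attn_scores = torch.randn(srclen, bsz)        # unnormalized attention scores
attn = F.softmax(attn_scores.t(), dim=1).t()  # normalize over srclen, shape preserved
print(attn.sum(dim=0))                        # each batch column sums to 1.0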
@@ -125,6 +125,11 @@ class LSTMDecoder(FairseqIncrementalDecoder):
         self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
 
     def forward(self, input_tokens, encoder_out):
+        if self._is_incremental_eval:
+            input_tokens = input_tokens[:, -1:]
+        return self._forward(input_tokens, encoder_out)
+
+    def _forward(self, input_tokens, encoder_out):
         bsz, seqlen = input_tokens.size()
 
         # get outputs from encoder
...
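
The second hunk routes incremental evaluation through a slice that keeps only the newest token, so each generation step re-runs the LSTM on a single position instead of the whole prefix. A hypothetical driver loop showing how a caller might exercise this path; the greedy loop, the bos/eos handling, and the direct toggle of _is_incremental_eval are illustrative assumptions, not fairseq's actual generator API:

import torch

def greedy_decode(decoder, encoder_out, bos, eos, max_len=50):
    # Sketch only: assumes decoder(input_tokens, encoder_out) returns logits
    # of shape (bsz, steps, vocab) and caches recurrent state between calls
    # while _is_incremental_eval is set.
    tokens = torch.full((1, 1), bos, dtype=torch.long)
    decoder._is_incremental_eval = True
    for _ in range(max_len):
        logits = decoder(tokens, encoder_out)  # decoder slices to tokens[:, -1:]
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if next_tok.item() == eos:
            break
    decoder._is_incremental_eval = False
    return tokens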