Unverified Commit af38ed48 authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

Fix bidirectional LSTM concatenation (#249)

parent 53c7d271
...@@ -222,10 +222,7 @@ class LSTMEncoder(FairseqEncoder): ...@@ -222,10 +222,7 @@ class LSTMEncoder(FairseqEncoder):
if self.bidirectional: if self.bidirectional:
def combine_bidir(outs): def combine_bidir(outs):
return torch.cat([ return outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous().view(self.num_layers, bsz, -1)
torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
for i in range(self.num_layers)
], dim=0)
final_hiddens = combine_bidir(final_hiddens) final_hiddens = combine_bidir(final_hiddens)
final_cells = combine_bidir(final_cells) final_cells = combine_bidir(final_cells)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment