Commit 6ec5022e authored by Myle Ott

Move reorder_encoder_out to FairseqEncoder and fix non-incremental decoding

parent e9967cd3
@@ -26,6 +26,12 @@ class CompositeEncoder(FairseqEncoder):
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        for key in self.encoders:
+            encoder_out[key] = self.encoders[key].reorder_encoder_out(encoder_out[key], new_order)
+        return encoder_out
+
    def max_positions(self):
        return min([self.encoders[key].max_positions() for key in self.encoders])
...
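The composite encoder's output is a dict keyed by sub-encoder name, and the new method simply delegates each entry to the sub-encoder that produced it. A minimal sketch of what such a reorder does to one entry, with made-up values (the `'src'` key and the tensor contents are illustrative, not part of this commit):

```python
import torch

# Hypothetical output of one sub-encoder for a batch of 2 sentences (B x T).
encoder_out = {'src': torch.tensor([[1, 2, 3],
                                    [4, 5, 6]])}
# new_order selects (and may repeat) batch rows, e.g. after beams are reshuffled.
new_order = torch.tensor([1, 0, 0])
reordered = {k: v.index_select(0, new_order) for k, v in encoder_out.items()}
# reordered['src'] == [[4, 5, 6], [1, 2, 3], [1, 2, 3]]
```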
@@ -18,6 +18,10 @@ class FairseqEncoder(nn.Module):
    def forward(self, src_tokens, src_lengths):
        raise NotImplementedError

+    def reorder_encoder_out(self, encoder_out, new_order):
+        """Reorder encoder output according to new_order."""
+        raise NotImplementedError
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        raise NotImplementedError
...
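With the hook now living on the encoder base class, every concrete encoder is expected to implement it: given whatever its `forward()` returned and a `LongTensor` of batch indices, return the same structure with the batch dimension reordered. A minimal sketch of a conforming subclass (the class name, batch-first layout, and import path are assumptions for illustration):

```python
import torch
from fairseq.models import FairseqEncoder  # import path assumed

class ToyEncoder(FairseqEncoder):
    """Illustrative encoder whose output is a single batch-first tensor."""

    def forward(self, src_tokens, src_lengths):
        # pretend the encoder states are just the tokens themselves (B x T)
        return src_tokens

    def reorder_encoder_out(self, encoder_out, new_order):
        # gather rows so hypothesis i keeps reading the encoder states it belongs to
        return encoder_out.index_select(0, new_order)

    def max_positions(self):
        return 1024
```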
@@ -32,9 +32,6 @@ class FairseqIncrementalDecoder(FairseqDecoder):
        )
        self.apply(apply_reorder_incremental_state)

-    def reorder_encoder_out(self, encoder_out, new_order):
-        return encoder_out
-
    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) != beam_size:
...
@@ -268,6 +268,17 @@ class FConvEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = (
+                encoder_out_dict['encoder_out'][0].index_select(0, new_order),
+                encoder_out_dict['encoder_out'][1].index_select(0, new_order),
+            )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

@@ -496,12 +507,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
            encoder_out = tuple(eo.index_select(0, new_order) for eo in encoder_out)
            utils.set_incremental_state(self, incremental_state, 'encoder_out', encoder_out)

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions() if self.embed_positions is not None else float('inf')
...
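For the convolutional model, both elements of the `(x, y)` pair and the padding mask (B x T, per the comment above) keep the batch in dimension 0, so every gather happens along dim 0. A tiny illustration with fabricated shapes (the B x T x C layout is an assumption for the sketch; only the batch-first convention is taken from the diff):

```python
import torch

encoder_out_dict = {
    'encoder_out': (torch.randn(2, 5, 8), torch.randn(2, 5, 8)),  # assumed B x T x C
    'encoder_padding_mask': torch.zeros(2, 5, dtype=torch.bool),  # B x T
}
new_order = torch.tensor([1, 1, 0])
reordered = tuple(t.index_select(0, new_order) for t in encoder_out_dict['encoder_out'])
mask = encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
print(reordered[0].shape, mask.shape)  # torch.Size([3, 5, 8]) torch.Size([3, 5])
```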
@@ -226,6 +226,19 @@ class FConvEncoder(FairseqEncoder):
            'encoder_out': (x, y),
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder_out']
+        )
+        if 'pretrained' in encoder_out_dict:
+            encoder_out_dict['pretrained']['encoder_out'] = tuple(
+                eo.index_select(0, new_order)
+                for eo in encoder_out_dict['pretrained']['encoder_out']
+            )
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()
@@ -409,30 +422,12 @@ class FConvDecoder(FairseqDecoder):
        else:
            return x, avg_attn_scores

-    def reorder_incremental_state(self, incremental_state, new_order):
-        """Reorder buffered internal state (for incremental generation)."""
-        super().reorder_incremental_state(incremental_state, new_order)
-
-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder']['encoder_out'] = tuple(
-            eo.index_select(0, new_order) for eo in encoder_out_dict['encoder']['encoder_out']
-        )
-        if 'pretrained' in encoder_out_dict:
-            encoder_out_dict['pretrained']['encoder']['encoder_out'] = tuple(
-                eo.index_select(0, new_order)
-                for eo in encoder_out_dict['pretrained']['encoder']['encoder_out']
-            )
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()

    def _split_encoder_out(self, encoder_out):
-        """Split and transpose encoder outputs.
-        """
+        """Split and transpose encoder outputs."""
        # transpose only once to speed up attention layers
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(0, 1).contiguous()
...
@@ -197,6 +197,16 @@ class LSTMEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        encoder_out_dict['encoder_out'] = tuple(
+            eo.index_select(1, new_order)
+            for eo in encoder_out_dict['encoder_out']
+        )
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number

@@ -366,16 +376,6 @@ class LSTMDecoder(FairseqIncrementalDecoder):
        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        encoder_out_dict['encoder_out'] = tuple(
-            eo.index_select(1, new_order)
-            for eo in encoder_out_dict['encoder_out']
-        )
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number
...
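Unlike the convolutional encoder, the LSTM encoder keeps the batch in dimension 1 of every tensor it returns, which is why each `index_select` above uses dim 1 rather than dim 0. A small sketch of the same reorder on a time-major tensor (shapes are illustrative only):

```python
import torch

x = torch.randn(5, 2, 8)          # assumed T x B x C encoder output
new_order = torch.tensor([1, 0, 0])
x = x.index_select(1, new_order)  # batch lives in dim 1 here
print(x.shape)                    # torch.Size([5, 3, 8])
```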
@@ -164,6 +164,15 @@ class TransformerEncoder(FairseqEncoder):
            'encoder_padding_mask': encoder_padding_mask,  # B x T
        }

+    def reorder_encoder_out(self, encoder_out_dict, new_order):
+        if encoder_out_dict['encoder_out'] is not None:
+            encoder_out_dict['encoder_out'] = \
+                encoder_out_dict['encoder_out'].index_select(1, new_order)
+        if encoder_out_dict['encoder_padding_mask'] is not None:
+            encoder_out_dict['encoder_padding_mask'] = \
+                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
+        return encoder_out_dict
+
    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return self.embed_positions.max_positions()

@@ -245,12 +254,6 @@ class TransformerDecoder(FairseqIncrementalDecoder):
        return x, attn

-    def reorder_encoder_out(self, encoder_out_dict, new_order):
-        if encoder_out_dict['encoder_padding_mask'] is not None:
-            encoder_out_dict['encoder_padding_mask'] = \
-                encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
-        return encoder_out_dict
-
    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return self.embed_positions.max_positions()
...
@@ -268,7 +268,7 @@ class SequenceGenerator(object):
                for i, model in enumerate(self.models):
                    if isinstance(model.decoder, FairseqIncrementalDecoder):
                        model.decoder.reorder_incremental_state(incremental_states[model], reorder_state)
-                    encoder_outs[i] = model.decoder.reorder_encoder_out(encoder_outs[i], reorder_state)
+                    encoder_outs[i] = model.encoder.reorder_encoder_out(encoder_outs[i], reorder_state)

            probs, avg_attn_scores = self._decode(
                tokens[:, :step + 1], encoder_outs, incremental_states)
...
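This one-line change is the actual fix for non-incremental decoding: the base FairseqDecoder never defined `reorder_encoder_out`, so routing the call through `model.encoder`, which now always provides it, removes the dependency on each decoder class. A hedged sketch of what `reorder_state` looks like here during beam search (the index values are made up; the real bookkeeping happens elsewhere in SequenceGenerator):

```python
import torch

# With batch_size=2 and beam_size=2 the flattened batch holds 4 hypotheses.
# After re-ranking, each surviving hypothesis points at the row whose
# encoder states it should keep reading.
reorder_state = torch.tensor([1, 0, 3, 3])
encoder_out = torch.arange(4.).unsqueeze(1).repeat(1, 8)  # toy 4 x 8 "encoder states"
print(encoder_out.index_select(0, reorder_state))
```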
@@ -108,6 +108,9 @@ class TestEncoder(FairseqEncoder):
    def forward(self, src_tokens, src_lengths):
        return src_tokens

+    def reorder_encoder_out(self, encoder_out, new_order):
+        return encoder_out.index_select(0, new_order)
+

class TestIncrementalDecoder(FairseqIncrementalDecoder):
    def __init__(self, args, dictionary):
...