Commit a09fe803 authored by Myle Ott
Browse files

Fix BeamableMM

parent 9f7c3ec6
...@@ -110,7 +110,9 @@ class FairseqIncrementalDecoder(FairseqDecoder): ...@@ -110,7 +110,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def set_beam_size(self, beam_size): def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children.""" """Sets the beam size in the decoder and all children."""
def apply_set_beam_size(module): if getattr(self, '_beam_size', -1) != beam_size:
if module != self and hasattr(module, 'set_beam_size'): def apply_set_beam_size(module):
module.set_beam_size(beam_size) if module != self and hasattr(module, 'set_beam_size'):
self.apply(apply_set_beam_size) module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
self._beam_size = beam_size
...@@ -62,6 +62,11 @@ class FairseqModel(nn.Module): ...@@ -62,6 +62,11 @@ class FairseqModel(nn.Module):
return return
self.apply(apply_remove_weight_norm) self.apply(apply_remove_weight_norm)
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
def train(mode): def train(mode):
if mode: if mode:
raise RuntimeError('cannot train after make_generation_fast') raise RuntimeError('cannot train after make_generation_fast')
...@@ -69,8 +74,3 @@ class FairseqModel(nn.Module): ...@@ -69,8 +74,3 @@ class FairseqModel(nn.Module):
# this model should no longer be used for training # this model should no longer be used for training
self.eval() self.eval()
self.train = train self.train = train
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
...@@ -145,7 +145,8 @@ class AttentionLayer(nn.Module): ...@@ -145,7 +145,8 @@ class AttentionLayer(nn.Module):
def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs): def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
"""Replace torch.bmm with BeamableMM.""" """Replace torch.bmm with BeamableMM."""
if beamable_mm_beam_size is not None: if beamable_mm_beam_size is not None:
self.bmm = BeamableMM(beamable_mm_beam_size) del self.bmm
self.add_module('bmm', BeamableMM(beamable_mm_beam_size))
class FConvDecoder(FairseqIncrementalDecoder): class FConvDecoder(FairseqIncrementalDecoder):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment