"src/vscode:/vscode.git/clone" did not exist on "8c6b47cfdea1962e23d3407f034b3b00dda8f2d6"
Commit a09fe803 authored by Myle Ott

Fix BeamableMM

parent 9f7c3ec6
@@ -110,7 +110,9 @@ class FairseqIncrementalDecoder(FairseqDecoder):
    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) != beam_size:
            def apply_set_beam_size(module):
                if module != self and hasattr(module, 'set_beam_size'):
                    module.set_beam_size(beam_size)
            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
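For context, nn.Module.apply(fn) visits every submodule recursively and then the module itself, which is why the callback guards with module != self. Below is a minimal, self-contained sketch of the same pattern using hypothetical Parent/Child modules, not the fairseq classes themselves:

    import torch.nn as nn

    class Child(nn.Module):
        def set_beam_size(self, beam_size):
            # Hypothetical child-side handling of the beam size.
            self._beam_size = beam_size

    class Parent(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = Child()

        def set_beam_size(self, beam_size):
            # Only re-apply when the beam size actually changes.
            if getattr(self, '_beam_size', -1) != beam_size:
                def apply_set_beam_size(module):
                    # apply() also calls fn on `self`; skip it to avoid recursion.
                    if module != self and hasattr(module, 'set_beam_size'):
                        module.set_beam_size(beam_size)
                self.apply(apply_set_beam_size)
                self._beam_size = beam_size

    p = Parent()
    p.set_beam_size(5)
    assert p.child._beam_size == 5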
@@ -62,6 +62,11 @@ class FairseqModel(nn.Module):
                return
        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module):
            if module != self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)
        self.apply(apply_make_generation_fast_)

        def train(mode):
            if mode:
                raise RuntimeError('cannot train after make_generation_fast')
@@ -69,8 +74,3 @@ class FairseqModel(nn.Module):
        # this model should no longer be used for training
        self.eval()
        self.train = train

        def apply_make_generation_fast_(module):
            if module != self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)
        self.apply(apply_make_generation_fast_)
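Taken together, the two FairseqModel hunks move the apply_make_generation_fast_ block so that children run their make_generation_fast_ hooks (which may register new submodules such as BeamableMM) before the model is switched to eval mode and its train method is overridden. A plausible motivation, sketched below with hypothetical modules rather than the fairseq API, is that a submodule added after eval() still carries the default training=True flag:

    import torch.nn as nn

    class Inner(nn.Module):
        def make_generation_fast_(self, **kwargs):
            # Registers a new child, as AttentionLayer does with BeamableMM.
            self.add_module('fast_op', nn.Identity())

    class Outer(nn.Module):
        def __init__(self):
            super().__init__()
            self.inner = Inner()

    m = Outer()
    m.eval()
    m.inner.make_generation_fast_()
    print(m.inner.fast_op.training)   # True: the new child missed the eval() call

    m2 = Outer()
    m2.inner.make_generation_fast_()  # register the child first...
    m2.eval()                         # ...then eval() reaches it
    print(m2.inner.fast_op.training)  # False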
@@ -145,7 +145,8 @@ class AttentionLayer(nn.Module):
    def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
        """Replace torch.bmm with BeamableMM."""
        if beamable_mm_beam_size is not None:
            self.bmm = BeamableMM(beamable_mm_beam_size)
            del self.bmm
            self.add_module('bmm', BeamableMM(beamable_mm_beam_size))


class FConvDecoder(FairseqIncrementalDecoder):
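The AttentionLayer hunk is the actual BeamableMM fix: the plain assignment self.bmm = BeamableMM(...) is replaced by an explicit del followed by add_module. A likely reason is that bmm was originally stored as an ordinary attribute (a reference to torch.bmm rather than an nn.Module), and in the PyTorch versions of that era assigning a Module over such an attribute did not reliably replace it or register it as a child; deleting the attribute and calling add_module makes the registration explicit so that eval(), cuda() and state_dict() reach the new module. A minimal sketch with a hypothetical layer, using nn.Identity as a stand-in for BeamableMM:

    import torch
    import torch.nn as nn

    class TinyAttention(nn.Module):
        def __init__(self):
            super().__init__()
            self.bmm = torch.bmm  # plain attribute, not a registered submodule

        def make_generation_fast_(self, new_bmm):
            del self.bmm                     # drop the plain attribute
            self.add_module('bmm', new_bmm)  # register the replacement as a child

    layer = TinyAttention()
    layer.make_generation_fast_(nn.Identity())  # stand-in for BeamableMM
    print(list(layer.named_children()))         # [('bmm', Identity())]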