Commit 03c4a716 authored by Myle Ott, committed by GitHub

Fix generation when vocabulary is small relative to beam size (fixes #7)

parent 2d3161da
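In brief: beam search flattens the scores into a bsz x (beam_size * vocab) matrix and asks topk for 2 * beam_size candidates per sentence. When the vocabulary is small relative to the beam, that request can exceed the number of selectable tokens, and pad can sneak into the top-k. The diff below guards this in three places: it clamps beam_size to len(dict) - 1, masks pad scores to -inf, and clamps the topk request. A self-contained sketch of how the three guards interact (toy sizes and illustrative names, not the fairseq API):

```python
import math
import torch

# Toy setup: a tiny vocabulary and a comparatively large requested beam.
vocab_size, requested_beam, bsz, pad = 8, 10, 2, 1

# the max beam size is the dictionary size - 1, since we never select pad
beam_size = min(requested_beam, vocab_size - 1)   # 7, not 10
cand_size = 2 * beam_size                         # candidates per sentence

# cumulative scores for every (hypothesis, next-token) pair
probs = torch.randn(bsz * beam_size, vocab_size)
probs[:, pad] = -math.inf                         # never select pad

flat = probs.view(bsz, -1)                        # bsz x (beam * vocab)
k = min(cand_size, flat.size(1) - 1)              # -1 so we never select pad
cand_scores, cand_indices = flat.topk(k)

# recover which beam and which token each flat index refers to
cand_beams = cand_indices // vocab_size
cand_tokens = cand_indices % vocab_size
```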
@@ -28,7 +28,7 @@ class FConvModel(nn.Module):
         decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
         return decoder_out.view(-1, decoder_out.size(-1))
 
-    def make_generation_fast_(self, beam_size, use_beamable_mm=False):
+    def make_generation_fast_(self, use_beamable_mm=False):
         """Optimize model for faster generation.
 
         Optimizations include:
@@ -54,7 +54,7 @@ class FConvModel(nn.Module):
         # use BeamableMM in attention layers
         if use_beamable_mm:
-            self.decoder._use_beamable_mm(beam_size)
+            self.decoder._use_beamable_mm()
 
         def train(mode):
             if mode:
@@ -243,14 +243,14 @@ class Decoder(nn.Module):
             context += conv.kernel_size[0] - 1
         return context
 
-    def incremental_inference(self):
+    def incremental_inference(self, beam_size=None):
         """Context manager for incremental inference.
 
         This provides an optimized forward pass for incremental inference
         (i.e., it predicts one time step at a time). If the input order changes
         between time steps, call model.decoder.reorder_incremental_state to
         update the relevant buffers. To generate a fresh sequence, first call
-        model.decoder.clear_incremental_state.
+        model.decoder.start_fresh_sequence.
 
         Usage:
         ```
@@ -263,18 +263,19 @@ class Decoder(nn.Module):
         """
         class IncrementalInference(object):
-            def __init__(self, decoder):
+            def __init__(self, decoder, beam_size):
                 self.decoder = decoder
+                self.beam_size = beam_size
 
             def __enter__(self):
-                self.decoder._start_incremental_inference()
+                self.decoder._start_incremental_inference(self.beam_size)
 
             def __exit__(self, *args):
                 self.decoder._stop_incremental_inference()
 
-        return IncrementalInference(self)
+        return IncrementalInference(self, beam_size)
 
-    def _start_incremental_inference(self):
+    def _start_incremental_inference(self, beam_size):
         assert not self._is_inference_incremental, \
             'already performing incremental inference'
         self._is_inference_incremental = True
@@ -287,7 +288,7 @@ class Decoder(nn.Module):
         self.forward = self._incremental_forward
 
         # start a fresh sequence
-        self.clear_incremental_state()
+        self.start_fresh_sequence(beam_size)
 
     def _stop_incremental_inference(self):
         # restore original forward and convolution layers
@@ -348,17 +349,21 @@ class Decoder(nn.Module):
         return x, avg_attn_scores
 
-    def clear_incremental_state(self):
+    def start_fresh_sequence(self, beam_size=None):
         """Clear all state used for incremental generation.
 
         **For incremental inference only**
 
         This should be called before generating a fresh sequence.
+        beam_size is required if using BeamableMM.
         """
         if self._is_inference_incremental:
             self.prev_state = None
             for conv in self.convolutions:
                 conv.clear_buffer()
+            for attn in self.attention:
+                if isinstance(attn.bmm, BeamableMM):
+                    attn.bmm.set_beam_size(beam_size)
 
     def reorder_incremental_state(self, new_order):
         """Reorder buffered internal state (for incremental generation).
@@ -373,9 +378,9 @@ class Decoder(nn.Module):
         for conv in self.convolutions:
             conv.reorder_buffer(new_order)
 
-    def _use_beamable_mm(self, beam_size):
+    def _use_beamable_mm(self):
         """Replace torch.bmm with BeamableMM in attention layers."""
-        beamable_mm = BeamableMM(beam_size)
+        beamable_mm = BeamableMM()
         for attn in self.attention:
             attn.bmm = beamable_mm
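Taken together, these changes move the beam size from model-optimization time to inference time. A minimal sketch of the resulting call pattern (model construction, batching, and the decode loop body are elided):

```python
# beam_size is no longer passed to make_generation_fast_; it flows in
# through incremental_inference / start_fresh_sequence instead.
model.make_generation_fast_(use_beamable_mm=True)
model.eval()

with model.decoder.incremental_inference(beam_size=5):
    for batch in batches:  # `batches` is a placeholder iterable
        # start_fresh_sequence propagates the beam size into BeamableMM
        model.decoder.start_fresh_sequence(beam_size=5)
        # ... decode the batch one time step at a time ...
```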
@@ -6,9 +6,9 @@
 # can be found in the PATENTS file in the same directory.
 #
 
-from .beamable_mm import *
-from .linearized_convolution import *
+from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
+from .linearized_convolution import LinearizedConvolution
 
 __all__ = [
     'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
@@ -18,16 +18,16 @@ class BeamableMM(nn.Module):
     inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
     with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
     """
-    def __init__(self, beam_size):
+    def __init__(self):
         super(BeamableMM, self).__init__()
-        self.beam_size = beam_size
+        self.beam_size = None
 
     def forward(self, input1, input2):
         if (
             not self.training and           # test mode
-            self.beam_size > 0 and          # beam size is set
+            self.beam_size is not None and  # beam size is set
             input1.dim() == 3 and           # only support batched input
             input1.size(1) == 1             # single time step update
         ):
             bsz, beam = input1.size(0), self.beam_size
@@ -45,3 +45,6 @@ class BeamableMM(nn.Module):
             return output.view(bsz, 1, -1)
         else:
             return input1.bmm(input2)
+
+    def set_beam_size(self, beam_size):
+        self.beam_size = beam_size
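For context, BeamableMM exploits the fact that during beam search every hypothesis of a sentence attends over the same encoder states, so the batched matrix multiply can be folded from bsz*beam small products into bsz/beam larger ones. A self-contained sketch of the equivalence (toy sizes; this illustrates the reshape trick, it is not the module itself):

```python
import torch

bsz_times_beam, beam, nhu, sz2 = 6, 3, 4, 5
sent = bsz_times_beam // beam  # number of distinct sentences

input1 = torch.randn(bsz_times_beam, 1, nhu)  # one query per hypothesis
keys = torch.randn(sent, nhu, sz2)            # encoder states, shared per sentence
input2 = keys.repeat_interleave(beam, dim=0)  # what the naive path sees

naive = input1.bmm(input2)                    # bsz*beam small bmms

# beamable path: fold the beam into the query batch and bmm once per sentence
folded = input1.view(sent, beam, nhu).bmm(keys)
assert torch.allclose(naive.view(sent, beam, sz2), folded)
```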
@@ -87,13 +87,16 @@ class SequenceGenerator(object):
     def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
         bsz = src_tokens.size(0)
-        beam_size = beam_size if beam_size is not None else self.beam_size
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
 
+        # the max beam size is the dictionary size - 1, since we never select pad
+        beam_size = beam_size if beam_size is not None else self.beam_size
+        beam_size = min(beam_size, len(self.dict) - 1)
+
         encoder_outs = []
         for model in self.models:
             model.eval()
-            model.decoder.clear_incremental_state()  # start a fresh sequence
+            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence
 
             # compute the encoder output and expand to beam size
             encoder_out = model.encoder(src_tokens, src_positions)
@@ -172,7 +175,7 @@ class SequenceGenerator(object):
                 sents_seen.add(sent)
 
                 def get_hypo():
-                    hypo = tokens[idx, 1:step+2].clone()
+                    hypo = tokens[idx, 1:step+2].clone()  # skip the first index, which is EOS
                     hypo[step] = self.eos
                     alignment = align[idx, 1:step+2].clone()
                     return {
@@ -219,6 +222,7 @@ class SequenceGenerator(object):
             else:
                 # make probs contain cumulative scores for each hypothesis
                 probs.add_(scores.view(-1, 1))
+            probs[:, self.pad] = -math.inf  # never select pad
 
             # record alignment to source tokens, based on attention
             _ignore_scores = buffer('_ignore_scores', type_of=scores)
@@ -229,7 +233,9 @@ class SequenceGenerator(object):
             cand_scores = buffer('cand_scores', type_of=scores)
             cand_indices = buffer('cand_indices')
             cand_beams = buffer('cand_beams')
-            probs.view(bsz, -1).topk(cand_size, out=(cand_scores, cand_indices))
+            probs.view(bsz, -1).topk(
+                min(cand_size, probs.view(bsz, -1).size(1) - 1),  # -1 so we never select pad
+                out=(cand_scores, cand_indices))
             torch.div(cand_indices, self.vocab_size, out=cand_beams)
             cand_indices.fmod_(self.vocab_size)
@@ -256,7 +262,7 @@ class SequenceGenerator(object):
             # and values < cand_size indicate candidate active hypos.
             # After, the min values per row are the top candidate active hypos
             active_mask = buffer('active_mask')
-            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets,
+            torch.add((eos_mask*cand_size).type_as(cand_offsets), cand_offsets[:eos_mask.size(1)],
                       out=active_mask)
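The cand_offsets slice in the last hunk handles a case the clamp introduces: topk may now return fewer than cand_size candidates, so the precomputed offsets buffer must be narrowed to match the width of eos_mask before the broadcast add. A toy illustration of the shape fix (my reading of the change, with made-up sizes):

```python
import torch

cand_size, returned = 20, 14                 # the clamp reduced 20 -> 14 candidates
cand_offsets = torch.arange(cand_size)       # buffer sized for the common case
eos_mask = torch.zeros(2, returned, dtype=torch.long)

# slicing to eos_mask.size(1) keeps the shapes broadcast-compatible
active_mask = eos_mask * cand_size + cand_offsets[:eos_mask.size(1)]
```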
@@ -47,7 +47,7 @@ def main():
     # Optimize model for generation
     for model in models:
-        model.make_generation_fast_(args.beam, not args.no_beamable_mm)
+        model.make_generation_fast_(not args.no_beamable_mm)
 
     # Initialize generator
     translator = SequenceGenerator(models, dataset.dst_dict, beam_size=args.beam,