Commit 9438019f authored by Myle Ott, committed by Sergey Edunov

Refactor incremental generation to be more explicit and less magical (#222)

parent e7094b14
......@@ -13,100 +13,21 @@ class FairseqIncrementalDecoder(FairseqDecoder):
def __init__(self, dictionary):
super().__init__(dictionary)
self._is_incremental_eval = False
self._incremental_state = {}
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
raise NotImplementedError
else:
raise NotImplementedError
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
raise NotImplementedError
def incremental_inference(self):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call reorder_incremental_state to update the
relevant buffers. To generate a fresh sequence, first call
clear_incremental_state.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out, _ = model.decoder(tokens[:, :step], encoder_out)
probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
```
"""
class IncrementalInference(object):
def __init__(self, decoder):
self.decoder = decoder
def __enter__(self):
self.decoder.incremental_eval(True)
def __exit__(self, *args):
self.decoder.incremental_eval(False)
return IncrementalInference(self)
def incremental_eval(self, mode=True):
"""Sets the decoder and all children in incremental evaluation mode."""
assert self._is_incremental_eval != mode, \
'incremental_eval already set to mode {}'.format(mode)
self._is_incremental_eval = mode
if mode:
self.clear_incremental_state()
def apply_incremental_eval(module):
if module != self and hasattr(module, 'incremental_eval'):
module.incremental_eval(mode)
self.apply(apply_incremental_eval)
def get_incremental_state(self, key):
"""Return cached state or None if not in incremental inference mode."""
if self._is_incremental_eval and key in self._incremental_state:
return self._incremental_state[key]
return None
def set_incremental_state(self, key, value):
"""Cache state needed for incremental inference mode."""
if self._is_incremental_eval:
self._incremental_state[key] = value
return value
def clear_incremental_state(self):
"""Clear all state used for incremental generation.
**For incremental inference only**
This should be called before generating a fresh sequence.
beam_size is required if using BeamableMM.
"""
if self._is_incremental_eval:
del self._incremental_state
self._incremental_state = {}
def apply_clear_incremental_state(module):
if module != self and hasattr(module, 'clear_incremental_state'):
module.clear_incremental_state()
self.apply(apply_clear_incremental_state)
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation).
**For incremental inference only**
def reorder_incremental_state(self, incremental_state, new_order):
"""Reorder incremental state.
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
order changes between time steps based on the selection of beams.
"""
if self._is_incremental_eval:
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(new_order)
self.apply(apply_reorder_incremental_state)
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(incremental_state, new_order)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
......
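The removed `incremental_inference()` context manager above is replaced by an explicit `incremental_state` argument that the caller owns. A minimal sketch of the new calling convention, assuming `model`, `tokens`, `encoder_out`, and `maxlen` are defined as in the old docstring example:
```
incremental_state = {}  # pass None instead to run the full, non-incremental forward
for step in range(maxlen):
    out, _ = model.decoder(tokens[:, :step + 1], encoder_out, incremental_state)
    probs = model.get_normalized_probs(out[:, -1, :], log_probs=False)
    # if a search procedure reorders hypotheses between steps, propagate it:
    # model.decoder.reorder_incremental_state(incremental_state, new_order)
```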
......@@ -10,6 +10,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
......@@ -229,19 +230,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, prev_output_tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out)
encoder_a, encoder_b = self._split_encoder_out(encoder_out, incremental_state)
# embed positions
positions = self.embed_positions(prev_output_tokens)
if self._is_incremental_eval:
# keep only the last token for incremental forward pass
prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens and positions
x = self.embed_tokens(prev_output_tokens) + positions
# embed tokens and combine with positional embeddings
x = self._embed_tokens(prev_output_tokens, incremental_state)
x += self.embed_positions(prev_output_tokens, incremental_state)
x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x
......@@ -249,7 +244,7 @@ class FConvDecoder(FairseqIncrementalDecoder):
x = self.fc1(x)
# B x T x C -> T x B x C
x = self._transpose_unless_incremental_eval(x)
x = self._transpose_if_training(x, incremental_state)
# temporal convolutions
avg_attn_scores = None
......@@ -258,13 +253,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
residual = x if proj is None else proj(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = conv(x)
x = conv.remove_future_timesteps(x)
x = conv(x, incremental_state)
if incremental_state is None:
x = conv.remove_future_timesteps(x)
x = F.glu(x, dim=2)
# attention
if attention is not None:
x = self._transpose_unless_incremental_eval(x)
x = self._transpose_if_training(x, incremental_state)
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
attn_scores = attn_scores / num_attn_layers
......@@ -273,13 +269,13 @@ class FConvDecoder(FairseqIncrementalDecoder):
else:
avg_attn_scores.add_(attn_scores)
x = self._transpose_unless_incremental_eval(x)
x = self._transpose_if_training(x, incremental_state)
# residual
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = self._transpose_unless_incremental_eval(x)
x = self._transpose_if_training(x, incremental_state)
# project back to size of vocabulary
x = self.fc2(x)
......@@ -288,10 +284,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
return x, avg_attn_scores
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
......@@ -306,13 +298,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
state_dict['decoder.version'] = torch.Tensor([1])
return state_dict
def _split_encoder_out(self, encoder_out):
def _embed_tokens(self, tokens, incremental_state):
if incremental_state is not None:
# keep only the last token for incremental forward pass
tokens = tokens[:, -1:]
return self.embed_tokens(tokens)
def _split_encoder_out(self, encoder_out, incremental_state):
"""Split and transpose encoder outputs.
This is cached when doing incremental inference.
"""
cached_result = self.get_incremental_state('encoder_out')
if cached_result:
cached_result = utils.get_incremental_state(self, incremental_state, 'encoder_out')
if cached_result is not None:
return cached_result
# transpose only once to speed up attention layers
......@@ -320,12 +318,14 @@ class FConvDecoder(FairseqIncrementalDecoder):
encoder_a = encoder_a.transpose(1, 2).contiguous()
result = (encoder_a, encoder_b)
return self.set_incremental_state('encoder_out', result)
if incremental_state is not None:
utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
return result
def _transpose_unless_incremental_eval(self, x):
if self._is_incremental_eval:
return x
return x.transpose(0, 1)
def _transpose_if_training(self, x, incremental_state):
if incremental_state is None:
x = x.transpose(0, 1)
return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
......
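FConvDecoder now keys its cached, transposed encoder output on the caller-supplied state dict via the `utils` helpers rather than on module attributes. A self-contained sketch of that cache-or-compute pattern, with a hypothetical `ExampleCachingModule` standing in for the decoder:
```
from fairseq import utils


class ExampleCachingModule(object):
    """Illustrates the pattern used by FConvDecoder._split_encoder_out."""

    def cached_transform(self, encoder_out, incremental_state):
        cached = utils.get_incremental_state(self, incremental_state, 'encoder_out')
        if cached is not None:
            # the encoder output is fixed during decoding, so the transposed
            # copy is computed once and reused at every time step
            return cached
        result = encoder_out.transpose(1, 2).contiguous()
        if incremental_state is not None:
            utils.set_incremental_state(self, incremental_state, 'encoder_out', result)
        return result
```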
......@@ -183,12 +183,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.additional_fc = Linear(embed_dim, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
return self._forward(prev_output_tokens, encoder_out)
def _forward(self, prev_output_tokens, encoder_out):
bsz, seqlen = prev_output_tokens.size()
# get outputs from encoder
......@@ -204,15 +201,15 @@ class LSTMDecoder(FairseqIncrementalDecoder):
x = x.transpose(0, 1)
# initialize previous states (or get from cache during incremental generation)
prev_hiddens = self.get_incremental_state('prev_hiddens')
if not prev_hiddens:
# first time step, initialize previous states
prev_hiddens, prev_cells = self._init_prev_states(encoder_out)
input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is not None:
prev_hiddens, prev_cells, input_feed = cached_state
else:
# previous states are cached
prev_cells = self.get_incremental_state('prev_cells')
input_feed = self.get_incremental_state('input_feed')
_, encoder_hiddens, encoder_cells = encoder_out
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
input_feed = Variable(x.data.new(bsz, embed_dim).zero_())
attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
outs = []
......@@ -242,9 +239,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
outs.append(out)
# cache previous states (no-op except during incremental generation)
self.set_incremental_state('prev_hiddens', prev_hiddens)
self.set_incremental_state('prev_cells', prev_cells)
self.set_incremental_state('input_feed', input_feed)
utils.set_incremental_state(
self, incremental_state, 'cached_state', (prev_hiddens, prev_cells, input_feed))
# collect outputs across time steps
x = torch.cat(outs, dim=0).view(seqlen, bsz, embed_dim)
......@@ -263,34 +259,25 @@ class LSTMDecoder(FairseqIncrementalDecoder):
return x, attn_scores
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation)."""
super().reorder_incremental_state(new_order)
new_order = Variable(new_order)
def reorder_incremental_state(self, incremental_state, new_order):
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is None:
return
def reorder_state(key):
old = self.get_incremental_state(key)
if isinstance(old, list):
new = [old_i.index_select(0, new_order) for old_i in old]
else:
new = old.index_select(0, new_order)
self.set_incremental_state(key, new)
def reorder_state(state):
if isinstance(state, list):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
reorder_state('prev_hiddens')
reorder_state('prev_cells')
reorder_state('input_feed')
if not isinstance(new_order, Variable):
new_order = Variable(new_order)
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5) # an arbitrary large number
def _init_prev_states(self, encoder_out):
_, encoder_hiddens, encoder_cells = encoder_out
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
return prev_hiddens, prev_cells
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
......
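The LSTM decoder now stores `(prev_hiddens, prev_cells, input_feed)` under a single `'cached_state'` key and permutes it along the batch dimension when the search reorders hypotheses. A small runnable sketch of that `index_select`-based reordering on dummy tensors (shapes are illustrative only):
```
import torch


def reorder_state(state, new_order):
    # recursively handle lists (one entry per decoder layer) of batch-first tensors
    if isinstance(state, list):
        return [reorder_state(s, new_order) for s in state]
    return state.index_select(0, new_order)


prev_hiddens = [torch.zeros(3, 4), torch.zeros(3, 4)]  # 2 layers, beam of 3, dim 4
new_order = torch.LongTensor([2, 0, 1])                # permutation picked by beam search
prev_hiddens = reorder_state(prev_hiddens, new_order)
```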
......@@ -20,14 +20,10 @@ class LearnedPositionalEmbedding(nn.Embedding):
def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.left_pad = left_pad
self._is_incremental_eval = False
def incremental_eval(self, mode=True):
self._is_incremental_eval = mode
def forward(self, input):
def forward(self, input, incremental_state=None):
"""Input is expected to be of size [bsz x seqlen]."""
if self._is_incremental_eval:
if incremental_state is not None:
# positions is the same for every token when decoding a single step
positions = Variable(
input.data.new(1, 1).fill_(self.padding_idx + input.size(1)))
......
......@@ -22,35 +22,20 @@ class LinearizedConvolution(ConvTBC):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self._is_incremental_eval = False
self._linearized_weight = None
self.register_backward_hook(self._clear_linearized_weight)
def remove_future_timesteps(self, x):
"""Remove future time steps created by padding."""
if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
x = x[:-self.padding[0], :, :]
return x
def incremental_eval(self, mode=True):
self._is_incremental_eval = mode
if mode:
self.clear_incremental_state()
def forward(self, input):
if self._is_incremental_eval:
return self.incremental_forward(input)
else:
def forward(self, input, incremental_state=None):
"""
Input: Time x Batch x Channel.
Args:
incremental_state: Used to buffer signal; if not None, then input is
expected to contain a single frame. If the input order changes
between time steps, call reorder_incremental_state.
"""
if incremental_state is None:
return super().forward(input)
def incremental_forward(self, input):
"""Forward convolution one time step at a time.
This function maintains an internal state to buffer signal and accepts
a single frame as input. If the input order changes between time steps,
call reorder_incremental_state. To apply to fresh inputs, call
clear_incremental_state.
"""
# reshape weight
weight = self._get_linearized_weight()
kw = self.kernel_size[0]
......@@ -58,25 +43,37 @@ class LinearizedConvolution(ConvTBC):
bsz = input.size(0) # input: bsz x len x dim
if kw > 1:
input = input.data
if self.input_buffer is None:
self.input_buffer = input.new(bsz, kw, input.size(2))
self.input_buffer.zero_()
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is None:
input_buffer = input.new(bsz, kw, input.size(2)).zero_()
self._set_input_buffer(incremental_state, input_buffer)
else:
# shift buffer
self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone()
input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
# append next input
self.input_buffer[:, -1, :] = input[:, -1, :]
input = utils.volatile_variable(self.input_buffer)
input_buffer[:, -1, :] = input[:, -1, :]
input = utils.volatile_variable(input_buffer)
with utils.maybe_no_grad():
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
def clear_incremental_state(self):
self.input_buffer = None
def remove_future_timesteps(self, x):
"""Remove future time steps created by padding."""
if self.kernel_size[0] > 1 and self.padding[0] > 0:
x = x[:-self.padding[0], :, :]
return x
def reorder_incremental_state(self, incremental_state, new_order):
input_buffer = self._get_input_buffer(incremental_state)
if input_buffer is not None:
input_buffer = input_buffer.index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
def _get_input_buffer(self, incremental_state):
return utils.get_incremental_state(self, incremental_state, 'input_buffer')
def reorder_incremental_state(self, new_order):
if self.input_buffer is not None:
self.input_buffer = self.input_buffer.index_select(0, new_order)
def _set_input_buffer(self, incremental_state, new_buffer):
return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)
def _get_linearized_weight(self):
if self._linearized_weight is None:
......
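During incremental decoding the linearized convolution consumes one frame per call and keeps the last `kw` frames in a buffer stored in `incremental_state`. A tiny sketch of the shift-and-append step on a dummy buffer, to make the indexing explicit (shapes are illustrative):
```
import torch

bsz, kw, dim = 2, 3, 4                     # batch, kernel width, channels
input_buffer = torch.zeros(bsz, kw, dim)   # rolling window of past frames
new_frame = torch.ones(bsz, 1, dim)        # the single frame for this step

# drop the oldest frame by shifting everything left by one position...
input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
# ...then append the newest frame at the end of the window
input_buffer[:, -1, :] = new_frame[:, -1, :]
```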
......@@ -5,7 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from contextlib import ExitStack
import math
import torch
......@@ -87,12 +86,8 @@ class SequenceGenerator(object):
def generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
"""Generate a batch of translations."""
with ExitStack() as stack:
for model in self.models:
if isinstance(model.decoder, FairseqIncrementalDecoder):
stack.enter_context(model.decoder.incremental_inference())
with utils.maybe_no_grad():
return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
with utils.maybe_no_grad():
return self._generate(src_tokens, src_lengths, beam_size, maxlen, prefix_tokens)
def _generate(self, src_tokens, src_lengths, beam_size=None, maxlen=None, prefix_tokens=None):
bsz, srclen = src_tokens.size()
......@@ -103,11 +98,14 @@ class SequenceGenerator(object):
beam_size = min(beam_size, self.vocab_size - 1)
encoder_outs = []
incremental_states = {}
for model in self.models:
if not self.retain_dropout:
model.eval()
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.set_beam_size(beam_size)
incremental_states[model] = {}
else:
incremental_states[model] = None
# compute the encoder output for each beam
encoder_out = model.encoder(
......@@ -245,9 +243,11 @@ class SequenceGenerator(object):
if reorder_state is not None:
for model in self.models:
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(reorder_state)
model.decoder.reorder_incremental_state(
incremental_states[model], reorder_state)
probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
probs, avg_attn_scores = self._decode(
tokens[:, :step+1], encoder_outs, incremental_states)
if step == 0:
# at the first step all hypotheses are equally likely, so use
# only the first beam
......@@ -287,7 +287,6 @@ class SequenceGenerator(object):
)
torch.div(cand_indices, self.vocab_size, out=cand_beams)
cand_indices.fmod_(self.vocab_size)
else:
# finalize all active hypotheses once we hit maxlen
# pick the hypothesis with the highest prob of EOS right now
......@@ -403,7 +402,7 @@ class SequenceGenerator(object):
return finalized
def _decode(self, tokens, encoder_outs):
def _decode(self, tokens, encoder_outs, incremental_states):
# wrap in Variable
tokens = utils.volatile_variable(tokens)
......@@ -411,7 +410,7 @@ class SequenceGenerator(object):
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
with utils.maybe_no_grad():
decoder_out, attn = model.decoder(tokens, encoder_out)
decoder_out, attn = model.decoder(tokens, encoder_out, incremental_states[model])
probs = model.get_normalized_probs(decoder_out[:, -1, :], log_probs=False).data
if avg_probs is None:
avg_probs = probs
......
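Pulling the scattered hunks above together: SequenceGenerator now allocates one state dict per model in the ensemble (or None for non-incremental decoders) and threads it through reordering and decoding, instead of toggling incremental mode on the decoders. A condensed sketch, assuming `models`, `reorder_state`, and the per-step decode loop from `_generate`:
```
incremental_states = {}
for model in models:
    if isinstance(model.decoder, FairseqIncrementalDecoder):
        incremental_states[model] = {}    # this decoder caches state explicitly
    else:
        incremental_states[model] = None  # this decoder always runs the full forward

# inside the per-step loop, before decoding:
if reorder_state is not None:
    for model in models:
        if isinstance(model.decoder, FairseqIncrementalDecoder):
            model.decoder.reorder_incremental_state(
                incremental_states[model], reorder_state)
```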
......@@ -5,6 +5,7 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from collections import defaultdict
import contextlib
import logging
import os
......@@ -198,6 +199,36 @@ def make_variable(sample, volatile=False, cuda=False):
return _make_variable(sample)
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
def _get_full_incremental_state_key(module_instance, key):
module_name = module_instance.__class__.__name__
# assign a unique ID to each module instance, so that incremental state is
# not shared across module instances
if not hasattr(module_instance, '_fairseq_instance_id'):
INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
module_instance._fairseq_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
return '{}.{}.{}'.format(module_name, module_instance._fairseq_instance_id, key)
def get_incremental_state(module, incremental_state, key):
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(module, incremental_state, key, value):
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
incremental_state[full_key] = value
def load_align_dict(replace_unk):
if replace_unk is None:
align_dict = None
......
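The helpers namespace each cached value by module class name plus a per-instance counter, so two instances of the same module type can share one state dict without colliding. A quick usage sketch with a hypothetical `Dummy` class:
```
from fairseq import utils


class Dummy(object):
    pass


state = {}
a, b = Dummy(), Dummy()
utils.set_incremental_state(a, state, 'buf', 1)
utils.set_incremental_state(b, state, 'buf', 2)
assert utils.get_incremental_state(a, state, 'buf') == 1
assert utils.get_incremental_state(b, state, 'buf') == 2
# keys look like 'Dummy.<n>.buf', with a distinct <n> per instance
```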
......@@ -9,7 +9,7 @@
import torch
from torch.autograd import Variable
from fairseq import data, dictionary
from fairseq import data, dictionary, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
......@@ -96,24 +96,21 @@ class TestIncrementalDecoder(FairseqIncrementalDecoder):
args.max_decoder_positions = getattr(args, 'max_decoder_positions', 100)
self.args = args
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
return self._forward(prev_output_tokens, encoder_out)
def _forward(self, prev_output_tokens, encoder_out):
bbsz = prev_output_tokens.size(0)
vocab = len(self.dictionary)
src_len = encoder_out.size(1)
tgt_len = prev_output_tokens.size(1)
# determine number of steps
if self._is_incremental_eval:
if incremental_state is not None:
# cache step number
step = self.get_incremental_state('step')
step = utils.get_incremental_state(self, incremental_state, 'step')
if step is None:
step = 0
self.set_incremental_state('step', step + 1)
utils.set_incremental_state(self, incremental_state, 'step', step + 1)
steps = [step]
else:
steps = list(range(tgt_len))
......
......@@ -6,8 +6,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
from fairseq import options
from distributed_train import main as distributed_main
......