Commit 6e4b7e22 authored by Myle Ott

Refactor model definitions

* Move some functionality out of FConvModel into FairseqModel base class
* Move incremental decoding functionality into FairseqIncrementalDecoder module
* Refactor positional embeddings to be more specific to FConvModel
parent 820f796f
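As a quick orientation, here is a minimal, hypothetical sketch of the encoder/decoder/model split these changes introduce. The classes below are toy stand-ins (ToyEncoder, ToyDecoder, ToyModel are made-up names, not fairseq code); only ToyModel.forward deliberately mirrors the new FairseqModel.forward shown later in the diff.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable


class ToyEncoder(nn.Module):
    """Stand-in for a FairseqEncoder subclass: maps source tokens to states."""
    def __init__(self, vocab, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, src_tokens):
        return self.embed(src_tokens)  # B x T_src x dim

    def max_positions(self):
        return 1024


class ToyDecoder(nn.Module):
    """Stand-in for a FairseqDecoder subclass: target prefix + encoder states -> logits."""
    def __init__(self, vocab, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, vocab)

    def forward(self, input_tokens, encoder_out):
        # a real decoder would attend over encoder_out; here we just add its mean
        x = self.embed(input_tokens) + encoder_out.mean(dim=1, keepdim=True)
        return self.out(x), None  # (B x T_tgt x vocab, attention scores)

    def max_positions(self):
        return 1024


class ToyModel(nn.Module):
    """Mirrors the new FairseqModel.forward: run encoder, then decoder, flatten logits."""
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_tokens, input_tokens):
        encoder_out = self.encoder(src_tokens)
        decoder_out, _ = self.decoder(input_tokens, encoder_out)
        return decoder_out.view(-1, decoder_out.size(-1))


model = ToyModel(ToyEncoder(vocab=10), ToyDecoder(vocab=10))
logits = model(Variable(torch.LongTensor([[1, 2, 3]])),
               Variable(torch.LongTensor([[2, 4]])))
print(logits.size())  # 2 x 10: flattened target positions x vocabulary
```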
@@ -7,7 +7,6 @@
 #
 from .cross_entropy import CrossEntropyCriterion
-from .fairseq_criterion import FairseqCriterion
 from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion

 __all__ = [
......
@@ -106,12 +106,14 @@ class LanguageDatasets(object):
                 max_positions=max_positions,
                 sort_by_source_size=sort_by_source_size)
         elif split.startswith('valid'):
-            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
-                                                 max_positions=max_positions,
-                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+            batch_sampler = list(batches_by_size(
+                dataset.src, batch_size, max_tokens, dst=dataset.dst,
+                max_positions=max_positions,
+                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         else:
-            batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions,
-                                                 ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
+            batch_sampler = list(batches_by_size(
+                dataset.src, batch_size, max_tokens, max_positions=max_positions,
+                ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
         return torch.utils.data.DataLoader(
             dataset,
@@ -137,6 +139,11 @@ def skip_group_enumerator(it, ngpus, offset=0):

 class LanguagePairDataset(object):
+
+    # padding constants
+    LEFT_PAD_SOURCE = False
+    LEFT_PAD_TARGET = True
+
     def __init__(self, src, dst, pad_idx, eos_idx):
         self.src = src
         self.dst = dst
@@ -166,17 +173,13 @@ class LanguagePairDataset(object):
             return LanguagePairDataset.collate_tokens(
                 [s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)

-        def merge_positions(key, left_pad):
-            return LanguagePairDataset.collate_positions([s[key] for s in samples], pad_idx, left_pad)
-
         ntokens = sum(len(s['target']) for s in samples)
         return {
             'id': torch.LongTensor([s['id'].item() for s in samples]),
-            'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True),
-            'input_positions': merge_positions('target', left_pad=True),
-            'target': merge('target', left_pad=True),
-            'src_tokens': merge('source', left_pad=False),
-            'src_positions': merge_positions('source', left_pad=False),
+            'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
+                                  move_eos_to_beginning=True),
+            'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
+            'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
             'ntokens': ntokens,
         }
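A toy illustration of the two padding conventions the constants above encode (this is a simplified stand-in, not the library's collate_tokens helper): source batches are right-padded (LEFT_PAD_SOURCE = False) while target batches are left-padded (LEFT_PAD_TARGET = True).

```python
import torch

pad = 1
seqs = [torch.LongTensor([4, 5, 6]), torch.LongTensor([7, 8])]

def pad_batch(seqs, pad_idx, left_pad):
    size = max(len(s) for s in seqs)
    out = seqs[0].new(len(seqs), size).fill_(pad_idx)
    for i, s in enumerate(seqs):
        if left_pad:
            out[i, size - len(s):] = s   # padding goes on the left
        else:
            out[i, :len(s)] = s          # padding goes on the right
    return out

print(pad_batch(seqs, pad, left_pad=False))  # [[4, 5, 6], [7, 8, 1]]
print(pad_batch(seqs, pad, left_pad=True))   # [[4, 5, 6], [1, 7, 8]]
```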
@@ -201,18 +204,6 @@ class LanguagePairDataset(object):
                 copy_tensor(v, res[i][:len(v)])
             return res

-    @staticmethod
-    def collate_positions(values, pad_idx, left_pad):
-        start = pad_idx + 1
-        size = max(v.size(0) for v in values)
-        res = values[0].new(len(values), size).fill_(pad_idx)
-        for i, v in enumerate(values):
-            if left_pad:
-                torch.arange(start, start + len(v), out=res[i][size-len(v):])
-            else:
-                torch.arange(start, start + len(v), out=res[i][:len(v)])
-        return res
-

 def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
                     max_positions=1024, ignore_invalid_inputs=False):
@@ -243,15 +234,12 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
     cur_max_size = 0
     ignored = []
     for idx in indices:
-        # - 2 here stems from make_positions() where we offset positions
-        # by padding_value + 1
         if src.sizes[idx] < 2 or \
                 (False if dst is None else dst.sizes[idx] < 2) or \
-                sizes[idx] > max_positions - 2:
+                sizes[idx] > max_positions:
             if ignore_invalid_inputs:
                 ignored.append(idx)
                 continue
             raise Exception("Unable to handle input id {} of "
                             "size {} / {}.".format(idx, src.sizes[idx],
                                                    "none" if dst is None else dst.sizes[idx]))
@@ -290,11 +278,9 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
     sample_len = 0
     ignored = []
     for idx in indices:
-        # - 2 here stems from make_positions() where we offset positions
-        # by padding_value + 1
         if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
-                src.sizes[idx] > max_positions - 2 or \
-                dst.sizes[idx] > max_positions - 2:
+                src.sizes[idx] > max_positions or \
+                dst.sizes[idx] > max_positions:
             ignored.append(idx)
             continue
         sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
......
@@ -17,6 +17,7 @@ class Dictionary(object):
         self.symbols = []
         self.count = []
         self.indices = {}
+        # dictionary indexing starts at 1 for consistency with Lua
         self.add_symbol('<Lua heritage>')
         self.pad_index = self.add_symbol(pad)
         self.eos_index = self.add_symbol(eos)
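A quick check of that convention (a hedged sketch: it assumes this commit's fairseq package is importable from fairseq.dictionary and that the constructor's default special symbols are used):

```python
from fairseq.dictionary import Dictionary

d = Dictionary()
# index 0 is taken by the '<Lua heritage>' placeholder, so the first real
# special symbol, pad, lands at index 1 and eos at index 2
print(d.pad(), d.eos())  # expected: 1 2
```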
......
@@ -6,6 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 #
+from .fairseq_decoder import FairseqDecoder
+from .fairseq_encoder import FairseqEncoder
+from .fairseq_incremental_decoder import FairseqIncrementalDecoder
+from .fairseq_model import FairseqModel
+
 from . import fconv
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch.nn as nn


class FairseqDecoder(nn.Module):
    """Base class for decoders."""

    def __init__(self):
        super().__init__()

    def max_positions(self):
        """Maximum input length supported by the decoder."""
        raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch.nn as nn


class FairseqEncoder(nn.Module):
    """Base class for encoders."""

    def __init__(self):
        super().__init__()

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch.nn as nn

from . import FairseqDecoder


class FairseqIncrementalDecoder(FairseqDecoder):
    """Base class for incremental decoders."""

    def __init__(self):
        super().__init__()
        self._is_incremental_eval = False
        self._incremental_state = {}

    def forward(self, tokens, encoder_out):
        raise NotImplementedError

    def incremental_forward(self, tokens, encoder_out):
        """Forward pass for one time step."""
        # keep only the last token for incremental forward pass
        return self.forward(tokens[:, -1:], encoder_out)

    def incremental_inference(self):
        """Context manager for incremental inference.

        This provides an optimized forward pass for incremental inference
        (i.e., it predicts one time step at a time). If the input order changes
        between time steps, call reorder_incremental_state to update the
        relevant buffers. To generate a fresh sequence, first call
        clear_incremental_state.

        Usage:
        ```
        with model.decoder.incremental_inference():
            for step in range(maxlen):
                out, _ = model.decoder.incremental_forward(
                    tokens[:, :step], encoder_out)
                probs = torch.nn.functional.log_softmax(out[:, -1, :])
        ```
        """
        class IncrementalInference(object):

            def __init__(self, decoder):
                self.decoder = decoder

            def __enter__(self):
                self.decoder.incremental_eval(True)

            def __exit__(self, *args):
                self.decoder.incremental_eval(False)

        return IncrementalInference(self)

    def incremental_eval(self, mode=True):
        """Sets the decoder and all children in incremental evaluation mode."""
        assert self._is_incremental_eval != mode, \
            'incremental_eval already set to mode {}'.format(mode)
        self._is_incremental_eval = mode
        if mode:
            self.clear_incremental_state()

        def apply_incremental_eval(module):
            if module != self and hasattr(module, 'incremental_eval'):
                module.incremental_eval(mode)
        self.apply(apply_incremental_eval)

    def get_incremental_state(self, key):
        """Return cached state or None if not in incremental inference mode."""
        if self._is_incremental_eval and key in self._incremental_state:
            return self._incremental_state[key]
        return None

    def set_incremental_state(self, key, value):
        """Cache state needed for incremental inference mode."""
        if self._is_incremental_eval:
            self._incremental_state[key] = value
        return value
    def clear_incremental_state(self):
        """Clear all state used for incremental generation.

        **For incremental inference only**

        This should be called before generating a fresh sequence.
        """
        if self._is_incremental_eval:
            self._incremental_state = {}

            def apply_clear_incremental_state(module):
                if module != self and hasattr(module, 'clear_incremental_state'):
                    module.clear_incremental_state()
            self.apply(apply_clear_incremental_state)

    def reorder_incremental_state(self, new_order):
        """Reorder buffered internal state (for incremental generation).

        **For incremental inference only**

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the choice of beams.
        """
        if self._is_incremental_eval:
            def apply_reorder_incremental_state(module):
                if module != self and hasattr(module, 'reorder_incremental_state'):
                    module.reorder_incremental_state(new_order)
            self.apply(apply_reorder_incremental_state)

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        def apply_set_beam_size(module):
            if module != self and hasattr(module, 'set_beam_size'):
                module.set_beam_size(beam_size)
        self.apply(apply_set_beam_size)
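To make the caching contract concrete, here is a self-contained toy replica of the get/set_incremental_state pattern above (ToyIncrementalDecoder and its _split_encoder_out are stand-ins, not fairseq code): an expensive, order-independent transform of the encoder output is computed once per sequence and reused at every decoding step while incremental evaluation is enabled.

```python
import torch
import torch.nn as nn


class ToyIncrementalDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self._is_incremental_eval = False
        self._incremental_state = {}

    def get_incremental_state(self, key):
        if self._is_incremental_eval and key in self._incremental_state:
            return self._incremental_state[key]
        return None

    def set_incremental_state(self, key, value):
        if self._is_incremental_eval:
            self._incremental_state[key] = value
        return value

    def _split_encoder_out(self, encoder_out):
        cached = self.get_incremental_state('encoder_out')
        if cached is not None:
            return cached
        encoder_a, encoder_b = encoder_out
        encoder_a = encoder_a.transpose(1, 2).contiguous()  # done once per sequence
        return self.set_incremental_state('encoder_out', (encoder_a, encoder_b))


decoder = ToyIncrementalDecoder()
decoder._is_incremental_eval = True
enc = (torch.randn(2, 7, 4), torch.randn(2, 7, 4))
a1, _ = decoder._split_encoder_out(enc)
a2, _ = decoder._split_encoder_out(enc)
print(a1 is a2)  # True: the transposed tensor is reused across decoding steps
```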
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch.nn as nn

from . import FairseqDecoder, FairseqEncoder


class FairseqModel(nn.Module):
    """Base class for encoder-decoder models."""

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        assert isinstance(self.encoder, FairseqEncoder)
        assert isinstance(self.decoder, FairseqDecoder)

        self.src_dict = encoder.dictionary
        self.dst_dict = decoder.dictionary
        assert self.src_dict.pad() == self.dst_dict.pad()
        assert self.src_dict.eos() == self.dst_dict.eos()
        assert self.src_dict.unk() == self.dst_dict.unk()

        self._is_generation_fast = False

    def forward(self, src_tokens, input_tokens):
        encoder_out = self.encoder(src_tokens)
        decoder_out, _ = self.decoder(input_tokens, encoder_out)
        return decoder_out.view(-1, decoder_out.size(-1))

    def max_encoder_positions(self):
        """Maximum input length supported by the encoder."""
        return self.encoder.max_positions()

    def max_decoder_positions(self):
        """Maximum output length supported by the decoder."""
        return self.decoder.max_positions()

    def make_generation_fast_(self, **kwargs):
        """Optimize model for faster generation."""
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(apply_remove_weight_norm)

        def train(mode):
            if mode:
                raise RuntimeError('cannot train after make_generation_fast')

        # this model should no longer be used for training
        self.eval()
        self.train = train

        def apply_make_generation_fast_(module):
            if module != self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)
        self.apply(apply_make_generation_fast_)
@@ -8,74 +8,42 @@
 import math

 import torch
+from torch.autograd import Variable
 import torch.nn as nn
 import torch.nn.functional as F

-from fairseq.modules import BeamableMM, LinearizedConvolution
+from fairseq.data import LanguagePairDataset
+from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
+from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel


-class FConvModel(nn.Module):
-    def __init__(self, encoder, decoder):
-        super(FConvModel, self).__init__()
-        self.encoder = encoder
-        self.decoder = decoder
-        self.src_dict = encoder.dictionary
-        self.dst_dict = decoder.dictionary
-        assert self.src_dict.pad() == self.dst_dict.pad()
-        assert self.src_dict.eos() == self.dst_dict.eos()
-        assert self.src_dict.unk() == self.dst_dict.unk()
-        self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
-        self._is_generation_fast = False
+def make_positions(tokens, padding_idx, left_pad, offset=0):
+    seqlen = tokens.size(1)
+    if not hasattr(make_positions, 'range'):
+        make_positions.range = tokens.new()
+    if make_positions.range.numel() < offset + seqlen:
+        # offset positions by the padding index
+        torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
+                     out=make_positions.range)
+    mask = tokens.ne(padding_idx)
+    positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
+    if left_pad:
+        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
+    return tokens.clone().masked_scatter_(mask, positions[mask])

-    def forward(self, src_tokens, src_positions, input_tokens, input_positions):
-        encoder_out = self.encoder(src_tokens, src_positions)
-        decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
-        return decoder_out.view(-1, decoder_out.size(-1))
-
-    def make_generation_fast_(self, use_beamable_mm=False):
-        """Optimize model for faster generation.
-
-        Optimizations include:
-            - remove WeightNorm
-            - (optionally) use BeamableMM in attention layers
-
-        The optimized model should not be used again for training.
-
-        Note: this can be combined with incremental inference in the Decoder for
-        even faster generation.
-        """
-        if self._is_generation_fast:
-            return  # only apply once
-        self._is_generation_fast = True
-
-        # remove weight norm from all modules in the network
-        def remove_weight_norm(m):
-            try:
-                nn.utils.remove_weight_norm(m)
-            except ValueError:  # this module didn't have weight norm
-                return
-        self.apply(remove_weight_norm)
-
-        # use BeamableMM in attention layers
-        if use_beamable_mm:
-            self.decoder._use_beamable_mm()
-
-        def train(mode):
-            if mode:
-                raise RuntimeError('cannot train after make_generation_fast')
-
-        # this model should no longer be used for training
-        self.eval()
-        self.train = train
+
+class FConvModel(FairseqModel):
+    def __init__(self, encoder, decoder):
+        super().__init__(encoder, decoder)
+        self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)


-class Encoder(nn.Module):
+class FConvEncoder(FairseqEncoder):
     """Convolutional encoder"""
     def __init__(self, dictionary, embed_dim=512, max_positions=1024,
                  convolutions=((512, 3),) * 20, dropout=0.1):
-        super(Encoder, self).__init__()
+        super().__init__()
         self.dictionary = dictionary
         self.dropout = dropout
         self.num_attention_layers = None
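A hedged worked example of the make_positions helper added in the hunk above (it assumes this commit's fairseq package is importable as fairseq.models.fconv; pad index 1 is assumed, so real positions start at 2 and padded slots keep the pad index):

```python
import torch
from fairseq.models.fconv import make_positions

pad = 1
right_padded = torch.LongTensor([[5, 6, 7], [8, 9, pad]])   # source-style padding
left_padded = torch.LongTensor([[5, 6, 7], [pad, 8, 9]])    # target-style padding

print(make_positions(right_padded, pad, left_pad=False))
# contains [[2, 3, 4], [2, 3, 1]] -- positions count up from pad+1, pad slot untouched
print(make_positions(left_padded, pad, left_pad=True))
# contains [[2, 3, 4], [1, 2, 3]] -- positions are shifted so the first real token is 2
```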
@@ -99,9 +67,12 @@ class Encoder(nn.Module):
             in_channels = out_channels
         self.fc2 = Linear(in_channels, embed_dim)

-    def forward(self, tokens, positions):
+    def forward(self, src_tokens):
+        positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
+                                            left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
+
         # embed tokens and positions
-        x = self.embed_tokens(tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
         x = F.dropout(x, p=self.dropout, training=self.training)
         input_embedding = x
@@ -126,17 +97,21 @@ class Encoder(nn.Module):
         x = self.fc2(x)

         # scale gradients (this only affects backward, not forward)
-        x = grad_multiply(x, 1.0 / (2.0 * self.num_attention_layers))
+        x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))

         # add output to input embedding for attention
         y = (x + input_embedding) * math.sqrt(0.5)

         return x, y

+    def max_positions(self):
+        """Maximum input length supported by the encoder."""
+        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+

 class AttentionLayer(nn.Module):
     def __init__(self, conv_channels, embed_dim, bmm=None):
-        super(AttentionLayer, self).__init__()
+        super().__init__()

         # projects from output of convolution to embedding dimension
         self.in_projection = Linear(conv_channels, embed_dim)
         # projects from embedding dimension to convolution size
@@ -167,13 +142,18 @@ class AttentionLayer(nn.Module):
         x = (self.out_projection(x) + residual) * math.sqrt(0.5)
         return x, attn_scores

+    def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
+        """Replace torch.bmm with BeamableMM."""
+        if beamable_mm_beam_size is not None:
+            self.bmm = BeamableMM(beamable_mm_beam_size)
+

-class Decoder(nn.Module):
+class FConvDecoder(FairseqIncrementalDecoder):
     """Convolutional decoder"""
     def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
                  max_positions=1024, convolutions=((512, 3),) * 20,
                  attention=True, dropout=0.1):
-        super(Decoder, self).__init__()
+        super().__init__()
         self.dictionary = dictionary
         self.dropout = dropout
@@ -204,25 +184,38 @@ class Decoder(nn.Module):
         self.fc2 = Linear(in_channels, out_embed_dim)
         self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)

-        self._is_inference_incremental = False
-
-    def forward(self, tokens, positions, encoder_out):
+    def forward(self, input_tokens, encoder_out):
+        positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
+                                            left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
+        return self._forward(input_tokens, positions, encoder_out)
+
+    def incremental_forward(self, input_tokens, encoder_out):
+        """Forward pass for one time step."""
+        # positions is the same for every token when decoding a single step
+        positions = Variable(input_tokens.data.new(1, 1).fill_(
+            self.dictionary.pad() + input_tokens.size(1)))
+
+        # keep only the last token for incremental forward pass
+        return self._forward(input_tokens[:, -1:], positions, encoder_out)
+
+    def _forward(self, input_tokens, positions, encoder_out):
+        # split and transpose encoder outputs
+        encoder_a, encoder_b = self._split_encoder_out(encoder_out)
+
         # embed tokens and positions
-        x = self.embed_tokens(tokens) + self.embed_positions(positions)
+        x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
         x = F.dropout(x, p=self.dropout, training=self.training)
         target_embedding = x

         # project to size of convolution
         x = self.fc1(x)

-        # transpose only once to speed up attention layers
-        encoder_a, encoder_b = encoder_out
-        encoder_a = encoder_a.transpose(1, 2).contiguous()
-
         # B x T x C -> T x B x C
-        x = x.transpose(0, 1)
+        x = self._transpose_unless_incremental_eval(x)

         # temporal convolutions
+        avg_attn_scores = None
+        num_attn_layers = len(self.attention)
         for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
             residual = x if proj is None else proj(x)
@@ -233,172 +226,54 @@ class Decoder(nn.Module):
             # attention
             if attention is not None:
-                x = x.transpose(1, 0)
-                x, _ = attention(x, target_embedding, (encoder_a, encoder_b))
-                x = x.transpose(1, 0)
-
-            # residual
-            x = (x + residual) * math.sqrt(0.5)
-
-        # T x B x C -> B x T x C
-        x = x.transpose(1, 0)
-
-        # project back to size of vocabulary
-        x = self.fc2(x)
-        x = F.dropout(x, p=self.dropout, training=self.training)
-        x = self.fc3(x)
-
-        return x
-
-    def context_size(self):
-        """Maximum number of input elements each output element depends on"""
-        context = 1
-        for conv in self.convolutions:
-            context += conv.kernel_size[0] - 1
-        return context
-
-    def max_positions(self):
-        """Returns maximum size of positions embeddings supported by this decoder"""
-        return self.embed_positions.num_embeddings
-
-    def incremental_inference(self, beam_size=None):
-        """Context manager for incremental inference.
-
-        This provides an optimized forward pass for incremental inference
-        (i.e., it predicts one time step at a time). If the input order changes
-        between time steps, call model.decoder.reorder_incremental_state to
-        update the relevant buffers. To generate a fresh sequence, first call
-        model.decoder.start_fresh_sequence.
-
-        Usage:
-        ```
-        with model.decoder.incremental_inference():
-            for step in range(maxlen):
-                out = model.decoder(tokens[:, :step], positions[:, :step],
-                                    encoder_out)
-                probs = F.log_softmax(out[:, -1, :])
-        ```
-        """
-        class IncrementalInference(object):
-            def __init__(self, decoder, beam_size):
-                self.decoder = decoder
-                self.beam_size = beam_size
-
-            def __enter__(self):
-                self.decoder._start_incremental_inference(self.beam_size)
-
-            def __exit__(self, *args):
-                self.decoder._stop_incremental_inference()
-
-        return IncrementalInference(self, beam_size)
-
-    def _start_incremental_inference(self, beam_size):
-        assert not self._is_inference_incremental, \
-            'already performing incremental inference'
-        self._is_inference_incremental = True
-
-        # save original forward
-        self._orig_forward = self.forward
-
-        # switch to incremental forward
-        self.forward = self._incremental_forward
-
-        # start a fresh sequence
-        self.start_fresh_sequence(beam_size)
-
-    def _stop_incremental_inference(self):
-        # restore original forward
-        self.forward = self._orig_forward
-
-        self._is_inference_incremental = False
-
-    def _incremental_forward(self, tokens, positions, encoder_out):
-        assert self._is_inference_incremental
-
-        # setup initial state
-        if self.prev_state is None:
-            # transpose encoder output once to speed up attention layers
-            encoder_a, encoder_b = encoder_out
-            encoder_a = encoder_a.transpose(1, 2).contiguous()
-            self.prev_state = {
-                'encoder_out': (encoder_a, encoder_b),
-            }
-
-        # load previous state
-        encoder_a, encoder_b = self.prev_state['encoder_out']
-
-        # keep only the last token for incremental forward pass
-        tokens = tokens[:, -1:]
-        positions = positions[:, -1:]
-
-        # embed tokens and positions
-        x = self.embed_tokens(tokens) + self.embed_positions(positions)
-        target_embedding = x
-
-        # project to size of convolution
-        x = self.fc1(x)
-
-        # temporal convolutions
-        avg_attn_scores = None
-        num_attn_layers = len(self.attention)
-        for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
-            residual = x if proj is None else proj(x)
-
-            x = conv.incremental_forward(x)
-            x = F.glu(x)
-
-            # attention
-            if attention is not None:
+                x = self._transpose_unless_incremental_eval(x)
                 x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
                 attn_scores = attn_scores / num_attn_layers
                 if avg_attn_scores is None:
                     avg_attn_scores = attn_scores
                 else:
-                    avg_attn_scores += attn_scores
+                    avg_attn_scores.add_(attn_scores)
+                x = self._transpose_unless_incremental_eval(x)

             # residual
             x = (x + residual) * math.sqrt(0.5)

+        # T x B x C -> B x T x C
+        x = self._transpose_unless_incremental_eval(x)
+
         # project back to size of vocabulary
         x = self.fc2(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.fc3(x)

         return x, avg_attn_scores

-    def start_fresh_sequence(self, beam_size=None):
-        """Clear all state used for incremental generation.
-
-        **For incremental inference only**
-
-        This should be called before generating a fresh sequence.
-        beam_size is required if using BeamableMM.
-        """
-        if self._is_inference_incremental:
-            self.prev_state = None
-            for conv in self.convolutions:
-                conv.clear_buffer()
-            for attn in self.attention:
-                if isinstance(attn.bmm, BeamableMM):
-                    attn.bmm.set_beam_size(beam_size)
-
-    def reorder_incremental_state(self, new_order):
-        """Reorder buffered internal state (for incremental generation).
-
-        **For incremental inference only**
-
-        This should be called when the order of the input has changed from the
-        previous time step. A typical use case is beam search, where the input
-        order changes between time steps based on the choice of beams.
-        """
-        if self._is_inference_incremental:
-            for conv in self.convolutions:
-                conv.reorder_buffer(new_order)
-
-    def _use_beamable_mm(self):
-        """Replace torch.bmm with BeamableMM in attention layers."""
-        beamable_mm = BeamableMM()
-        for attn in self.attention:
-            attn.bmm = beamable_mm
+    def max_positions(self):
+        """Maximum output length supported by the decoder."""
+        return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
+
+    def _split_encoder_out(self, encoder_out):
+        """Split and transpose encoder outputs.
+
+        This is cached when doing incremental inference.
+        """
+        cached_result = self.get_incremental_state('encoder_out')
+        if cached_result:
+            return cached_result
+
+        # transpose only once to speed up attention layers
+        encoder_a, encoder_b = encoder_out
+        encoder_a = encoder_a.transpose(1, 2).contiguous()
+        result = (encoder_a, encoder_b)
+
+        return self.set_incremental_state('encoder_out', result)
+
+    def _transpose_unless_incremental_eval(self, x):
+        if self._is_incremental_eval:
+            return x
+        return x.transpose(0, 1)


 def Embedding(num_embeddings, embedding_dim, padding_idx):
@@ -434,23 +309,6 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
     return nn.utils.weight_norm(m, dim=2)


-def grad_multiply(x, scale):
-    return GradMultiply.apply(x, scale)
-
-
-class GradMultiply(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, scale):
-        ctx.scale = scale
-        res = x.new(x)
-        ctx.mark_shared_storage((x, res))
-        return res
-
-    @staticmethod
-    def backward(ctx, grad):
-        return grad * ctx.scale, None
-
-
 def get_archs():
     return [
         'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
@@ -518,14 +376,14 @@ def parse_arch(args):

 def build_model(args, src_dict, dst_dict):
-    encoder = Encoder(
+    encoder = FConvEncoder(
         src_dict,
         embed_dim=args.encoder_embed_dim,
         convolutions=eval(args.encoder_layers),
         dropout=args.dropout,
         max_positions=args.max_positions,
     )
-    decoder = Decoder(
+    decoder = FConvDecoder(
         dst_dict,
         embed_dim=args.decoder_embed_dim,
         convolutions=eval(args.decoder_layers),
......
@@ -8,8 +8,12 @@
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
+from .grad_multiply import GradMultiply
 from .linearized_convolution import LinearizedConvolution

 __all__ = [
-    'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
+    'BeamableMM',
+    'ConvTBC',
+    'GradMultiply',
+    'LinearizedConvolution',
 ]
@@ -18,9 +18,9 @@ class BeamableMM(nn.Module):
     inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
     with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
     """
-    def __init__(self):
+    def __init__(self, beam_size=None):
         super(BeamableMM, self).__init__()
-        self.beam_size = None
+        self.beam_size = beam_size

     def forward(self, input1, input2):
         if (
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#

import torch


class GradMultiply(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, scale):
        ctx.scale = scale
        res = x.new(x)
        ctx.mark_shared_storage((x, res))
        return res

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None
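An illustrative check of GradMultiply's contract, which FConvEncoder uses above to scale encoder gradients by 1/(2*num_attention_layers) without changing the forward value. This sketch assumes the 2017-era PyTorch this code targets (Variable-style autograd; mark_shared_storage was later removed), so treat it as a description of behaviour rather than a portable test:

```python
import torch
from torch.autograd import Variable
from fairseq.modules import GradMultiply

x = Variable(torch.ones(3), requires_grad=True)
y = GradMultiply.apply(x, 0.5)
print(y.data.tolist())       # [1.0, 1.0, 1.0] -- forward pass is the identity
y.sum().backward()
print(x.grad.data.tolist())  # [0.5, 0.5, 0.5] -- gradients are scaled in backward
```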
@@ -14,33 +14,44 @@ from .conv_tbc import ConvTBC

 class LinearizedConvolution(ConvTBC):
     """An optimized version of nn.Conv1d.

-    This module replaces convolutions with linear layers as appropriate
-    and supports optimizations for incremental inference.
+    At training time, this module uses ConvTBC, which is an optimized version
+    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
+    one time step at a time) by replacing the convolutions with linear layers.
     """
     def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
         super().__init__(in_channels, out_channels, kernel_size, **kwargs)
-        self.clear_buffer()
+        self._is_incremental_eval = False
         self._linearized_weight = None
         self.register_backward_hook(self._clear_linearized_weight)

     def remove_future_timesteps(self, x):
         """Remove future time steps created by padding."""
-        if self.kernel_size[0] > 1 and self.padding[0] > 0:
+        if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
             x = x[:-self.padding[0], :, :]
         return x

+    def incremental_eval(self, mode=True):
+        self._is_incremental_eval = mode
+        if mode:
+            self.clear_incremental_state()
+
+    def forward(self, input):
+        if self._is_incremental_eval:
+            return self.incremental_forward(input)
+        else:
+            return super().forward(input)
+
     def incremental_forward(self, input):
         """Forward convolution one time step at a time.

-        This function maintains an internal state to buffer signal and
-        accepts a single frame as input. If the input order changes
-        between time steps, call reorder_buffer. To apply to fresh
-        inputs, call clear_buffer.
+        This function maintains an internal state to buffer signal and accepts
+        a single frame as input. If the input order changes between time steps,
+        call reorder_incremental_state. To apply to fresh inputs, call
+        clear_incremental_state.
         """
-        if self.training:
-            raise RuntimeError('LinearizedConvolution only supports inference')
+        if self.training or not self._is_incremental_eval:
+            raise RuntimeError('incremental_forward only supports incremental evaluation')

         # run forward pre hooks (e.g., weight norm)
         for hook in self._forward_pre_hooks.values():
@@ -65,10 +76,10 @@ class LinearizedConvolution(ConvTBC):
         output = F.linear(input.view(bsz, -1), weight, self.bias)
         return output.view(bsz, 1, -1)

-    def clear_buffer(self):
+    def clear_incremental_state(self):
         self.input_buffer = None

-    def reorder_buffer(self, new_order):
+    def reorder_incremental_state(self, new_order):
         if self.input_buffer is not None:
             self.input_buffer = self.input_buffer.index_select(0, new_order)
......
@@ -13,6 +13,7 @@ import torch.nn.functional as F
 from torch.autograd import Variable

 from fairseq import utils
+from fairseq.models import FairseqIncrementalDecoder


 class SequenceGenerator(object):
@@ -36,9 +37,7 @@ class SequenceGenerator(object):
         self.vocab_size = len(models[0].dst_dict)
         self.beam_size = beam_size
         self.minlen = minlen
-        self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
-        self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
-        self.decoder_context = models[0].decoder.context_size()
+        self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models])
         self.stop_early = stop_early
         self.normalize_scores = normalize_scores
         self.len_penalty = len_penalty
@@ -46,10 +45,9 @@ class SequenceGenerator(object):
     def cuda(self):
         for model in self.models:
             model.cuda()
-        self.positions = self.positions.cuda()
         return self

-    def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
+    def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                              cuda_device=None, timer=None):
         """Iterate over a batched dataset and yield individual translations.
@@ -63,13 +61,16 @@ class SequenceGenerator(object):
         def lstrip_pad(tensor):
             return tensor[tensor.eq(self.pad).sum():]

+        if maxlen_b is None:
+            maxlen_b = self.maxlen
+
         for sample in data_itr:
             s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
             input = s['net_input']
             srclen = input['src_tokens'].size(1)
             if timer is not None:
                 timer.start()
-            hypos = self.generate(input['src_tokens'], input['src_positions'],
+            hypos = self.generate(input['src_tokens'], beam_size=beam_size,
                                   maxlen=int(maxlen_a*srclen + maxlen_b))
             if timer is not None:
                 timer.stop(s['ntokens'])
@@ -79,14 +80,15 @@ class SequenceGenerator(object):
                 ref = lstrip_pad(s['target'].data[i, :])
                 yield id, src, ref, hypos[i]

-    def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
+    def generate(self, src_tokens, beam_size=None, maxlen=None):
         """Generate a batch of translations."""
         with ExitStack() as stack:
             for model in self.models:
-                stack.enter_context(model.decoder.incremental_inference())
-            return self._generate(src_tokens, src_positions, beam_size, maxlen)
+                if isinstance(model.decoder, FairseqIncrementalDecoder):
+                    stack.enter_context(model.decoder.incremental_inference())
+            return self._generate(src_tokens, beam_size, maxlen)

-    def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
+    def _generate(self, src_tokens, beam_size=None, maxlen=None):
         bsz = src_tokens.size(0)
         maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
@@ -97,10 +99,11 @@ class SequenceGenerator(object):
         encoder_outs = []
         for model in self.models:
             model.eval()
-            model.decoder.start_fresh_sequence(beam_size)  # start a fresh sequence
+            if isinstance(model.decoder, FairseqIncrementalDecoder):
+                model.decoder.set_beam_size(beam_size)

             # compute the encoder output and expand to beam size
-            encoder_out = model.encoder(src_tokens, src_positions)
+            encoder_out = model.encoder(src_tokens)
             encoder_out = self._expand_encoder_out(encoder_out, beam_size)
             encoder_outs.append(encoder_out)
@@ -215,7 +218,8 @@ class SequenceGenerator(object):
             # reorder decoder internal states based on the prev choice of beams
             if reorder_state is not None:
                 for model in self.models:
-                    model.decoder.reorder_incremental_state(reorder_state)
+                    if isinstance(model.decoder, FairseqIncrementalDecoder):
+                        model.decoder.reorder_incremental_state(reorder_state)

             probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
             if step == 0:
@@ -315,19 +319,16 @@ class SequenceGenerator(object):
         return finalized

     def _decode(self, tokens, encoder_outs):
-        length = tokens.size(1)
-
-        # repeat the first length positions to fill batch
-        positions = self.positions[:length].view(1, length)
-
-        # wrap in Variables
+        # wrap in Variable
         tokens = Variable(tokens, volatile=True)
-        positions = Variable(positions, volatile=True)

         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
-            decoder_out, attn = model.decoder(tokens, positions, encoder_out)
+            if isinstance(model.decoder, FairseqIncrementalDecoder):
+                decoder_out, attn = model.decoder.incremental_forward(tokens, encoder_out)
+            else:
+                decoder_out, attn = model.decoder.forward(tokens, encoder_out)
             probs = F.softmax(decoder_out[:, -1, :]).data
             attn = attn[:, -1, :].data
             if avg_probs is None or avg_attn is None:
......
@@ -151,6 +151,6 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
         'target': make_variable(sample['target']),
         'net_input': {
             key: make_variable(sample[key])
-            for key in ['src_tokens', 'src_positions', 'input_tokens', 'input_positions']
+            for key in ['src_tokens', 'input_tokens']
         },
     }
@@ -52,19 +52,15 @@ def main():
     if not args.interactive:
         print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

-    # Max positions is the model property but it is needed in data reader to be able to
-    # ignore too long sentences
-    args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
-
     # Optimize ensemble for generation
     for model in models:
-        model.make_generation_fast_(not args.no_beamable_mm)
+        model.make_generation_fast_(
+            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

     # Initialize generator
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
-        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
-    )
+        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen)
     if use_cuda:
         translator.cuda()
@@ -112,12 +108,9 @@ def main():
     if args.interactive:
         for line in sys.stdin:
             tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
-            start = dataset.src_dict.pad() + 1
-            positions = torch.arange(start, start + len(tokens)).type_as(tokens)
             if use_cuda:
-                positions = positions.cuda()
                 tokens = tokens.cuda()
-            translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
+            translations = translator.generate(Variable(tokens.view(1, -1)))
             hypos = translations[0]
             display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
@@ -132,8 +125,9 @@ def main():
     # Generate and compute BLEU score
     scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
+    max_positions = min(model.max_encoder_positions() for model in models)
     itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
-                             max_positions=args.max_positions,
+                             max_positions=max_positions,
                              skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
     num_sentences = 0
     with progress_bar(itr, smoothing=0, leave=False) as t:
......
@@ -10,7 +10,7 @@ import torch
 import unittest

 from fairseq.modules import ConvTBC
 import torch.nn as nn
-from torch.autograd import Variable, gradcheck
+from torch.autograd import Variable


 class TestConvTBC(unittest.TestCase):
@@ -31,7 +31,7 @@ class TestConvTBC(unittest.TestCase):
         output1d = conv1d(input1d)

         self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)

         grad_tbc = torch.randn(output_tbc.size())
         grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
......