Commit 6e4b7e22 authored by Myle Ott

Refactor model definitions

* Move some functionality out of FConvModel into FairseqModel base class
* Move incremental decoding functionality into FairseqIncrementalDecoder module
* Refactor positional embeddings to be more specific to FConvModel
parent 820f796f
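As a rough sketch of the refactored structure described in the commit message (a minimal stand-in, not the fairseq API itself: the Toy* classes and their dimensions are invented for illustration), a model now composes an encoder and a decoder behind a shared base-class forward, mirroring FairseqModel.forward in the diff below:

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):          # stands in for a FairseqEncoder subclass
    def __init__(self, vocab, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)

    def forward(self, src_tokens):
        return self.embed(src_tokens)

class ToyDecoder(nn.Module):          # stands in for a FairseqIncrementalDecoder subclass
    def __init__(self, vocab, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab)

    def forward(self, input_tokens, encoder_out):
        # "attend" trivially by adding a pooled encoder summary
        x = self.embed(input_tokens) + encoder_out.mean(dim=1, keepdim=True)
        return self.proj(x), None

class ToyModel(nn.Module):            # mirrors FairseqModel.forward from this commit
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def forward(self, src_tokens, input_tokens):
        encoder_out = self.encoder(src_tokens)
        decoder_out, _ = self.decoder(input_tokens, encoder_out)
        return decoder_out.view(-1, decoder_out.size(-1))

model = ToyModel(ToyEncoder(10), ToyDecoder(10))
src = torch.LongTensor(2, 5).random_(0, 10)
prev = torch.LongTensor(2, 4).random_(0, 10)
print(model(src, prev).shape)         # torch.Size([8, 10]) == (bsz * tgt_len, vocab)
```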
......@@ -7,7 +7,6 @@
#
from .cross_entropy import CrossEntropyCriterion
from .fairseq_criterion import FairseqCriterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
__all__ = [
......
......@@ -106,12 +106,14 @@ class LanguageDatasets(object):
max_positions=max_positions,
sort_by_source_size=sort_by_source_size)
elif split.startswith('valid'):
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, dst=dataset.dst,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
batch_sampler = list(batches_by_size(
dataset.src, batch_size, max_tokens, dst=dataset.dst,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
else:
batch_sampler = list(batches_by_size(dataset.src, batch_size, max_tokens, max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
batch_sampler = list(batches_by_size(
dataset.src, batch_size, max_tokens, max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test))
return torch.utils.data.DataLoader(
dataset,
......@@ -137,6 +139,11 @@ def skip_group_enumerator(it, ngpus, offset=0):
class LanguagePairDataset(object):
# padding constants
LEFT_PAD_SOURCE = False
LEFT_PAD_TARGET = True
def __init__(self, src, dst, pad_idx, eos_idx):
self.src = src
self.dst = dst
......@@ -166,17 +173,13 @@ class LanguagePairDataset(object):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
def merge_positions(key, left_pad):
return LanguagePairDataset.collate_positions([s[key] for s in samples], pad_idx, left_pad)
ntokens = sum(len(s['target']) for s in samples)
return {
'id': torch.LongTensor([s['id'].item() for s in samples]),
'input_tokens': merge('target', left_pad=True, move_eos_to_beginning=True),
'input_positions': merge_positions('target', left_pad=True),
'target': merge('target', left_pad=True),
'src_tokens': merge('source', left_pad=False),
'src_positions': merge_positions('source', left_pad=False),
'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True),
'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
'ntokens': ntokens,
}
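To make the new LEFT_PAD_* conventions concrete, a small worked example of the batch produced by the collater above (hypothetical values, assuming the fairseq dictionary defaults pad=1 and eos=2):

```python
# Two target samples: [5, 6, 2] and [7, 2]   (2 == eos, 1 == pad)
#
# 'target'       (left_pad=LEFT_PAD_TARGET=True):       [[5, 6, 2],
#                                                        [1, 7, 2]]
# 'input_tokens' (left-padded, eos moved to the front): [[2, 5, 6],
#                                                        [1, 2, 7]]
# 'src_tokens'   (left_pad=LEFT_PAD_SOURCE=False):      right-padded instead.
```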
......@@ -201,18 +204,6 @@ class LanguagePairDataset(object):
copy_tensor(v, res[i][:len(v)])
return res
@staticmethod
def collate_positions(values, pad_idx, left_pad):
start = pad_idx + 1
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
for i, v in enumerate(values):
if left_pad:
torch.arange(start, start + len(v), out=res[i][size-len(v):])
else:
torch.arange(start, start + len(v), out=res[i][:len(v)])
return res
def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
max_positions=1024, ignore_invalid_inputs=False):
......@@ -243,15 +234,12 @@ def batches_by_size(src, batch_size=None, max_tokens=None, dst=None,
cur_max_size = 0
ignored = []
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or \
(False if dst is None else dst.sizes[idx] < 2) or \
sizes[idx] > max_positions - 2:
sizes[idx] > max_positions:
if ignore_invalid_inputs:
ignored.append(idx)
continue
raise Exception("Unable to handle input id {} of "
"size {} / {}.".format(idx, src.sizes[idx],
"none" if dst is None else dst.sizes[idx]))
......@@ -290,11 +278,9 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, epoch=1, sample=0,
sample_len = 0
ignored = []
for idx in indices:
# - 2 here stems from make_positions() where we offset positions
# by padding_value + 1
if src.sizes[idx] < 2 or dst.sizes[idx] < 2 or \
src.sizes[idx] > max_positions - 2 or \
dst.sizes[idx] > max_positions - 2:
src.sizes[idx] > max_positions or \
dst.sizes[idx] > max_positions:
ignored.append(idx)
continue
sample_len = max(sample_len, src.sizes[idx], dst.sizes[idx])
......
......@@ -17,6 +17,7 @@ class Dictionary(object):
self.symbols = []
self.count = []
self.indices = {}
# dictionary indexing starts at 1 for consistency with Lua
self.add_symbol('<Lua heritage>')
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
......
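For reference, the default indices that result from the constructor above (a sketch, assuming the default special symbols; unknown-word handling is outside this hunk):

```python
# add_symbol('<Lua heritage>') occupies index 0, so real symbols start at 1,
# matching the original Lua/Torch fairseq:
#   d = Dictionary()
#   d.pad() == 1
#   d.eos() == 2
```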
......@@ -6,6 +6,11 @@
# can be found in the PATENTS file in the same directory.
#
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import FairseqModel
from . import fconv
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
class FairseqDecoder(nn.Module):
"""Base class for decoders."""
def __init__(self):
super().__init__()
def max_positions(self):
"""Maximum input length supported by the decoder."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
class FairseqEncoder(nn.Module):
"""Base class for encoders."""
def __init__(self):
super().__init__()
def max_positions(self):
"""Maximum input length supported by the encoder."""
raise NotImplementedError
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
from . import FairseqDecoder
class FairseqIncrementalDecoder(FairseqDecoder):
"""Base class for incremental decoders."""
def __init__(self):
super().__init__()
self._is_incremental_eval = False
self._incremental_state = {}
def forward(self, tokens, encoder_out):
raise NotImplementedError
def incremental_forward(self, tokens, encoder_out):
"""Forward pass for one time step."""
# keep only the last token for incremental forward pass
return self.forward(tokens[:, -1:], encoder_out)
def incremental_inference(self):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call reorder_incremental_state to update the
relevant buffers. To generate a fresh sequence, first call
clear_incremental_state.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out, _ = model.decoder.incremental_forward(
tokens[:, :step], encoder_out)
probs = torch.nn.functional.log_softmax(out[:, -1, :])
```
"""
class IncrementalInference(object):
def __init__(self, decoder):
self.decoder = decoder
def __enter__(self):
self.decoder.incremental_eval(True)
def __exit__(self, *args):
self.decoder.incremental_eval(False)
return IncrementalInference(self)
def incremental_eval(self, mode=True):
"""Sets the decoder and all children in incremental evaluation mode."""
assert self._is_incremental_eval != mode, \
'incremental_eval already set to mode {}'.format(mode)
self._is_incremental_eval = mode
if mode:
self.clear_incremental_state()
def apply_incremental_eval(module):
if module != self and hasattr(module, 'incremental_eval'):
module.incremental_eval(mode)
self.apply(apply_incremental_eval)
def get_incremental_state(self, key):
"""Return cached state or None if not in incremental inference mode."""
if self._is_incremental_eval and key in self._incremental_state:
return self._incremental_state[key]
return None
def set_incremental_state(self, key, value):
"""Cache state needed for incremental inference mode."""
if self._is_incremental_eval:
self._incremental_state[key] = value
return value
def clear_incremental_state(self):
"""Clear all state used for incremental generation.
**For incremental inference only**
This should be called before generating a fresh sequence.
"""
if self._is_incremental_eval:
self._incremental_state = {}
def apply_clear_incremental_state(module):
if module != self and hasattr(module, 'clear_incremental_state'):
module.clear_incremental_state()
self.apply(apply_clear_incremental_state)
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation).
**For incremental inference only**
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
"""
if self._is_incremental_eval:
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
module.reorder_incremental_state(new_order)
self.apply(apply_reorder_incremental_state)
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
def apply_set_beam_size(module):
if module != self and hasattr(module, 'set_beam_size'):
module.set_beam_size(beam_size)
self.apply(apply_set_beam_size)
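A hedged sketch of how a concrete decoder is expected to use this incremental-state API (the ToyIncrementalDecoder below is hypothetical; FConvDecoder later in this commit does the equivalent for its transposed encoder output):

```python
import torch
import torch.nn as nn
from fairseq.models import FairseqIncrementalDecoder  # path added by this commit

class ToyIncrementalDecoder(FairseqIncrementalDecoder):
    """Hypothetical decoder that caches a pooled encoder summary between steps."""

    def __init__(self, vocab=10, dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.proj = nn.Linear(dim, vocab)

    def forward(self, tokens, encoder_out):
        enc = self.get_incremental_state('enc_summary')    # None outside incremental eval
        if enc is None:
            enc = encoder_out.mean(dim=1, keepdim=True)
            self.set_incremental_state('enc_summary', enc)  # cached only in incremental eval
        x = self.embed(tokens) + enc
        return self.proj(x), None

decoder = ToyIncrementalDecoder()
encoder_out = torch.randn(1, 5, 8)
tokens = torch.LongTensor(1, 3).random_(0, 10)
with decoder.incremental_inference():
    out, _ = decoder.incremental_forward(tokens, encoder_out)  # feeds only tokens[:, -1:]
```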
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
from . import FairseqDecoder, FairseqEncoder
class FairseqModel(nn.Module):
"""Base class for encoder-decoder models."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
self.src_dict = encoder.dictionary
self.dst_dict = decoder.dictionary
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
self._is_generation_fast = False
def forward(self, src_tokens, input_tokens):
encoder_out = self.encoder(src_tokens)
decoder_out, _ = self.decoder(input_tokens, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def max_encoder_positions(self):
"""Maximum input length supported by the encoder."""
return self.encoder.max_positions()
def max_decoder_positions(self):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def apply_remove_weight_norm(module):
try:
nn.utils.remove_weight_norm(module)
except ValueError: # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
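As a usage sketch (the `model` variable is hypothetical; the argument name is taken from the generate.py change further down in this diff):

```python
# Optimize a trained model once before decoding; afterwards it is inference-only.
model.make_generation_fast_(beamable_mm_beam_size=5)  # removes weight norm, swaps in BeamableMM
# model.train(True) would now raise RuntimeError('cannot train after make_generation_fast')
```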
......@@ -8,74 +8,42 @@
import math
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import BeamableMM, LinearizedConvolution
from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LinearizedConvolution
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
class FConvModel(nn.Module):
def __init__(self, encoder, decoder):
super(FConvModel, self).__init__()
self.encoder = encoder
self.decoder = decoder
self.src_dict = encoder.dictionary
self.dst_dict = decoder.dictionary
assert self.src_dict.pad() == self.dst_dict.pad()
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
self.encoder.num_attention_layers = sum([layer is not None for layer in decoder.attention])
self._is_generation_fast = False
def forward(self, src_tokens, src_positions, input_tokens, input_positions):
encoder_out = self.encoder(src_tokens, src_positions)
decoder_out = self.decoder(input_tokens, input_positions, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
def make_positions(tokens, padding_idx, left_pad, offset=0):
seqlen = tokens.size(1)
if not hasattr(make_positions, 'range'):
make_positions.range = tokens.new()
if make_positions.range.numel() < offset + seqlen:
# offset positions by the padding index
torch.arange(padding_idx + 1, padding_idx + 1 + offset + seqlen,
out=make_positions.range)
mask = tokens.ne(padding_idx)
positions = make_positions.range[offset:offset+seqlen].expand_as(tokens)
if left_pad:
positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
return tokens.clone().masked_scatter_(mask, positions[mask])
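A small worked example of make_positions (assuming it stays importable from fairseq.models.fconv as in this diff, and padding_idx=1 as with the fairseq dictionary):

```python
import torch
from fairseq.models.fconv import make_positions  # assumed import path for this sketch

tokens = torch.LongTensor([[1, 1, 5, 6, 7],    # left-padded row, padding_idx == 1
                           [4, 5, 6, 7, 8]])   # full-length row
print(make_positions(tokens, padding_idx=1, left_pad=True))
# Padding keeps the pad index; real tokens are numbered from padding_idx + 1,
# counted from the right edge of each left-padded row:
#   [[1, 1, 2, 3, 4],
#    [2, 3, 4, 5, 6]]
```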
def make_generation_fast_(self, use_beamable_mm=False):
"""Optimize model for faster generation.
Optimizations include:
- remove WeightNorm
- (optionally) use BeamableMM in attention layers
The optimized model should not be used again for training.
Note: this can be combined with incremental inference in the Decoder for
even faster generation.
"""
if self._is_generation_fast:
return # only apply once
self._is_generation_fast = True
# remove weight norm from all modules in the network
def remove_weight_norm(m):
try:
nn.utils.remove_weight_norm(m)
except ValueError: # this module didn't have weight norm
return
self.apply(remove_weight_norm)
# use BeamableMM in attention layers
if use_beamable_mm:
self.decoder._use_beamable_mm()
def train(mode):
if mode:
raise RuntimeError('cannot train after make_generation_fast')
# this model should no longer be used for training
self.eval()
self.train = train
class FConvModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
class Encoder(nn.Module):
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
def __init__(self, dictionary, embed_dim=512, max_positions=1024,
convolutions=((512, 3),) * 20, dropout=0.1):
super(Encoder, self).__init__()
super().__init__()
self.dictionary = dictionary
self.dropout = dropout
self.num_attention_layers = None
......@@ -99,9 +67,12 @@ class Encoder(nn.Module):
in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, tokens, positions):
def forward(self, src_tokens):
positions = Variable(make_positions(src_tokens.data, self.dictionary.pad(),
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE))
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
x = self.embed_tokens(src_tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training)
input_embedding = x
......@@ -126,17 +97,21 @@ class Encoder(nn.Module):
x = self.fc2(x)
# scale gradients (this only affects backward, not forward)
x = grad_multiply(x, 1.0 / (2.0 * self.num_attention_layers))
x = GradMultiply.apply(x, 1.0 / (2.0 * self.num_attention_layers))
# add output to input embedding for attention
y = (x + input_embedding) * math.sqrt(0.5)
return x, y
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
class AttentionLayer(nn.Module):
def __init__(self, conv_channels, embed_dim, bmm=None):
super(AttentionLayer, self).__init__()
super().__init__()
# projects from output of convolution to embedding dimension
self.in_projection = Linear(conv_channels, embed_dim)
# projects from embedding dimension to convolution size
......@@ -167,13 +142,18 @@ class AttentionLayer(nn.Module):
x = (self.out_projection(x) + residual) * math.sqrt(0.5)
return x, attn_scores
def make_generation_fast_(self, beamable_mm_beam_size=None, **kwargs):
"""Replace torch.bmm with BeamableMM."""
if beamable_mm_beam_size is not None:
self.bmm = BeamableMM(beamable_mm_beam_size)
class Decoder(nn.Module):
class FConvDecoder(FairseqIncrementalDecoder):
"""Convolutional decoder"""
def __init__(self, dictionary, embed_dim=512, out_embed_dim=256,
max_positions=1024, convolutions=((512, 3),) * 20,
attention=True, dropout=0.1):
super(Decoder, self).__init__()
super().__init__()
self.dictionary = dictionary
self.dropout = dropout
......@@ -204,25 +184,38 @@ class Decoder(nn.Module):
self.fc2 = Linear(in_channels, out_embed_dim)
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
self._is_inference_incremental = False
def forward(self, input_tokens, encoder_out):
positions = Variable(make_positions(input_tokens.data, self.dictionary.pad(),
left_pad=LanguagePairDataset.LEFT_PAD_TARGET))
return self._forward(input_tokens, positions, encoder_out)
def incremental_forward(self, input_tokens, encoder_out):
"""Forward pass for one time step."""
# positions is the same for every token when decoding a single step
positions = Variable(input_tokens.data.new(1, 1).fill_(
self.dictionary.pad() + input_tokens.size(1)))
# keep only the last token for incremental forward pass
return self._forward(input_tokens[:, -1:], positions, encoder_out)
def _forward(self, input_tokens, positions, encoder_out):
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out)
def forward(self, tokens, positions, encoder_out):
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
x = self.embed_tokens(input_tokens) + self.embed_positions(positions)
x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x
# project to size of convolution
x = self.fc1(x)
# transpose only once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
# B x T x C -> T x B x C
x = x.transpose(0, 1)
x = self._transpose_unless_incremental_eval(x)
# temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
......@@ -233,172 +226,54 @@ class Decoder(nn.Module):
# attention
if attention is not None:
x = x.transpose(1, 0)
x, _ = attention(x, target_embedding, (encoder_a, encoder_b))
x = x.transpose(1, 0)
# residual
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = x.transpose(1, 0)
# project back to size of vocabulary
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
return x
def context_size(self):
"""Maximum number of input elements each output element depends on"""
context = 1
for conv in self.convolutions:
context += conv.kernel_size[0] - 1
return context
def max_positions(self):
"""Returns maximum size of positions embeddings supported by this decoder"""
return self.embed_positions.num_embeddings
def incremental_inference(self, beam_size=None):
"""Context manager for incremental inference.
This provides an optimized forward pass for incremental inference
(i.e., it predicts one time step at a time). If the input order changes
between time steps, call model.decoder.reorder_incremental_state to
update the relevant buffers. To generate a fresh sequence, first call
model.decoder.start_fresh_sequence.
Usage:
```
with model.decoder.incremental_inference():
for step in range(maxlen):
out = model.decoder(tokens[:, :step], positions[:, :step],
encoder_out)
probs = F.log_softmax(out[:, -1, :])
```
"""
class IncrementalInference(object):
def __init__(self, decoder, beam_size):
self.decoder = decoder
self.beam_size = beam_size
def __enter__(self):
self.decoder._start_incremental_inference(self.beam_size)
x = self._transpose_unless_incremental_eval(x)
def __exit__(self, *args):
self.decoder._stop_incremental_inference()
return IncrementalInference(self, beam_size)
def _start_incremental_inference(self, beam_size):
assert not self._is_inference_incremental, \
'already performing incremental inference'
self._is_inference_incremental = True
# save original forward
self._orig_forward = self.forward
# switch to incremental forward
self.forward = self._incremental_forward
# start a fresh sequence
self.start_fresh_sequence(beam_size)
def _stop_incremental_inference(self):
# restore original forward
self.forward = self._orig_forward
self._is_inference_incremental = False
def _incremental_forward(self, tokens, positions, encoder_out):
assert self._is_inference_incremental
# setup initial state
if self.prev_state is None:
# transpose encoder output once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
self.prev_state = {
'encoder_out': (encoder_a, encoder_b),
}
# load previous state
encoder_a, encoder_b = self.prev_state['encoder_out']
# keep only the last token for incremental forward pass
tokens = tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_tokens(tokens) + self.embed_positions(positions)
target_embedding = x
# project to size of convolution
x = self.fc1(x)
# temporal convolutions
avg_attn_scores = None
num_attn_layers = len(self.attention)
for proj, conv, attention in zip(self.projections, self.convolutions, self.attention):
residual = x if proj is None else proj(x)
x = conv.incremental_forward(x)
x = F.glu(x)
# attention
if attention is not None:
x, attn_scores = attention(x, target_embedding, (encoder_a, encoder_b))
attn_scores = attn_scores / num_attn_layers
if avg_attn_scores is None:
avg_attn_scores = attn_scores
else:
avg_attn_scores += attn_scores
avg_attn_scores.add_(attn_scores)
x = self._transpose_unless_incremental_eval(x)
# residual
x = (x + residual) * math.sqrt(0.5)
# T x B x C -> B x T x C
x = self._transpose_unless_incremental_eval(x)
# project back to size of vocabulary
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = self.fc3(x)
return x, avg_attn_scores
def start_fresh_sequence(self, beam_size=None):
"""Clear all state used for incremental generation.
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.num_embeddings - self.dictionary.pad() - 1
**For incremental inference only**
def _split_encoder_out(self, encoder_out):
"""Split and transpose encoder outputs.
This should be called before generating a fresh sequence.
beam_size is required if using BeamableMM.
This is cached when doing incremental inference.
"""
if self._is_inference_incremental:
self.prev_state = None
for conv in self.convolutions:
conv.clear_buffer()
for attn in self.attention:
if isinstance(attn.bmm, BeamableMM):
attn.bmm.set_beam_size(beam_size)
def reorder_incremental_state(self, new_order):
"""Reorder buffered internal state (for incremental generation).
**For incremental inference only**
This should be called when the order of the input has changed from the
previous time step. A typical use case is beam search, where the input
order changes between time steps based on the choice of beams.
"""
if self._is_inference_incremental:
for conv in self.convolutions:
conv.reorder_buffer(new_order)
cached_result = self.get_incremental_state('encoder_out')
if cached_result:
return cached_result
# transpose only once to speed up attention layers
encoder_a, encoder_b = encoder_out
encoder_a = encoder_a.transpose(1, 2).contiguous()
result = (encoder_a, encoder_b)
return self.set_incremental_state('encoder_out', result)
def _use_beamable_mm(self):
"""Replace torch.bmm with BeamableMM in attention layers."""
beamable_mm = BeamableMM()
for attn in self.attention:
attn.bmm = beamable_mm
def _transpose_unless_incremental_eval(self, x):
if self._is_incremental_eval:
return x
return x.transpose(0, 1)
def Embedding(num_embeddings, embedding_dim, padding_idx):
......@@ -434,23 +309,6 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
return nn.utils.weight_norm(m, dim=2)
def grad_multiply(x, scale):
return GradMultiply.apply(x, scale)
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
def get_archs():
return [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
......@@ -518,14 +376,14 @@ def parse_arch(args):
def build_model(args, src_dict, dst_dict):
encoder = Encoder(
encoder = FConvEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
max_positions=args.max_positions,
)
decoder = Decoder(
decoder = FConvDecoder(
dst_dict,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
......
......@@ -8,8 +8,12 @@
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
from .grad_multiply import GradMultiply
from .linearized_convolution import LinearizedConvolution
__all__ = [
'BeamableMM', 'LinearizedConvolution', 'ConvTBC',
'BeamableMM',
'ConvTBC',
'GradMultiply',
'LinearizedConvolution',
]
......@@ -18,9 +18,9 @@ class BeamableMM(nn.Module):
inference by replacing the inputs {(bsz x 1 x nhu), (bsz x sz2 x nhu)}
with smaller inputs {(bsz/beam x beam x nhu), (bsz/beam x sz2 x nhu)}.
"""
def __init__(self):
def __init__(self, beam_size=None):
super(BeamableMM, self).__init__()
self.beam_size = None
self.beam_size = beam_size
def forward(self, input1, input2):
if (
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch
class GradMultiply(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
def backward(ctx, grad):
return grad * ctx.scale, None
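A quick sketch of what the new GradMultiply module does (identity in the forward pass, scaled gradients in the backward pass), written against the Variable-based autograd API fairseq used at the time:

```python
import torch
from torch.autograd import Variable
from fairseq.modules import GradMultiply  # path added by this commit

x = Variable(torch.ones(3), requires_grad=True)
y = GradMultiply.apply(x, 0.25)  # forward returns x unchanged (shared storage)
y.sum().backward()
print(x.grad)                    # gradients are scaled: [0.25, 0.25, 0.25]
```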
......@@ -14,33 +14,44 @@ from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
"""An optimized version of nn.Conv1d.
This module replaces convolutions with linear layers as appropriate
and supports optimizations for incremental inference.
At training time, this module uses ConvTBC, which is an optimized version
of Conv1d. At inference time, it optimizes incremental generation (i.e.,
one time step at a time) by replacing the convolutions with linear layers.
"""
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.clear_buffer()
self._is_incremental_eval = False
self._linearized_weight = None
self.register_backward_hook(self._clear_linearized_weight)
def remove_future_timesteps(self, x):
"""Remove future time steps created by padding."""
if self.kernel_size[0] > 1 and self.padding[0] > 0:
if not self._is_incremental_eval and self.kernel_size[0] > 1 and self.padding[0] > 0:
x = x[:-self.padding[0], :, :]
return x
def incremental_eval(self, mode=True):
self._is_incremental_eval = mode
if mode:
self.clear_incremental_state()
def forward(self, input):
if self._is_incremental_eval:
return self.incremental_forward(input)
else:
return super().forward(input)
def incremental_forward(self, input):
"""Forward convolution one time step at a time.
This function maintains an internal state to buffer signal and
accepts a single frame as input. If the input order changes
between time steps, call reorder_buffer. To apply to fresh
inputs, call clear_buffer.
This function maintains an internal state to buffer signal and accepts
a single frame as input. If the input order changes between time steps,
call reorder_incremental_state. To apply to fresh inputs, call
clear_incremental_state.
"""
if self.training:
raise RuntimeError('LinearizedConvolution only supports inference')
if self.training or not self._is_incremental_eval:
raise RuntimeError('incremental_forward only supports incremental evaluation')
# run forward pre hooks (e.g., weight norm)
for hook in self._forward_pre_hooks.values():
......@@ -65,10 +76,10 @@ class LinearizedConvolution(ConvTBC):
output = F.linear(input.view(bsz, -1), weight, self.bias)
return output.view(bsz, 1, -1)
def clear_buffer(self):
def clear_incremental_state(self):
self.input_buffer = None
def reorder_buffer(self, new_order):
def reorder_incremental_state(self, new_order):
if self.input_buffer is not None:
self.input_buffer = self.input_buffer.index_select(0, new_order)
......
......@@ -13,6 +13,7 @@ import torch.nn.functional as F
from torch.autograd import Variable
from fairseq import utils
from fairseq.models import FairseqIncrementalDecoder
class SequenceGenerator(object):
......@@ -36,9 +37,7 @@ class SequenceGenerator(object):
self.vocab_size = len(models[0].dst_dict)
self.beam_size = beam_size
self.minlen = minlen
self.maxlen = min(maxlen, *[m.decoder.max_positions() - self.pad - 2 for m in self.models])
self.positions = torch.LongTensor(range(self.pad + 1, self.pad + self.maxlen + 2))
self.decoder_context = models[0].decoder.context_size()
self.maxlen = min(maxlen, *[m.decoder.max_positions() for m in self.models])
self.stop_early = stop_early
self.normalize_scores = normalize_scores
self.len_penalty = len_penalty
......@@ -46,10 +45,9 @@ class SequenceGenerator(object):
def cuda(self):
for model in self.models:
model.cuda()
self.positions = self.positions.cuda()
return self
def generate_batched_itr(self, data_itr, maxlen_a=0.0, maxlen_b=200,
def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
cuda_device=None, timer=None):
"""Iterate over a batched dataset and yield individual translations.
......@@ -63,13 +61,16 @@ class SequenceGenerator(object):
def lstrip_pad(tensor):
return tensor[tensor.eq(self.pad).sum():]
if maxlen_b is None:
maxlen_b = self.maxlen
for sample in data_itr:
s = utils.prepare_sample(sample, volatile=True, cuda_device=cuda_device)
input = s['net_input']
srclen = input['src_tokens'].size(1)
if timer is not None:
timer.start()
hypos = self.generate(input['src_tokens'], input['src_positions'],
hypos = self.generate(input['src_tokens'], beam_size=beam_size,
maxlen=int(maxlen_a*srclen + maxlen_b))
if timer is not None:
timer.stop(s['ntokens'])
......@@ -79,14 +80,15 @@ class SequenceGenerator(object):
ref = lstrip_pad(s['target'].data[i, :])
yield id, src, ref, hypos[i]
def generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
def generate(self, src_tokens, beam_size=None, maxlen=None):
"""Generate a batch of translations."""
with ExitStack() as stack:
for model in self.models:
stack.enter_context(model.decoder.incremental_inference())
return self._generate(src_tokens, src_positions, beam_size, maxlen)
if isinstance(model.decoder, FairseqIncrementalDecoder):
stack.enter_context(model.decoder.incremental_inference())
return self._generate(src_tokens, beam_size, maxlen)
def _generate(self, src_tokens, src_positions, beam_size=None, maxlen=None):
def _generate(self, src_tokens, beam_size=None, maxlen=None):
bsz = src_tokens.size(0)
maxlen = min(maxlen, self.maxlen) if maxlen is not None else self.maxlen
......@@ -97,10 +99,11 @@ class SequenceGenerator(object):
encoder_outs = []
for model in self.models:
model.eval()
model.decoder.start_fresh_sequence(beam_size) # start a fresh sequence
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.set_beam_size(beam_size)
# compute the encoder output and expand to beam size
encoder_out = model.encoder(src_tokens, src_positions)
encoder_out = model.encoder(src_tokens)
encoder_out = self._expand_encoder_out(encoder_out, beam_size)
encoder_outs.append(encoder_out)
......@@ -215,7 +218,8 @@ class SequenceGenerator(object):
# reorder decoder internal states based on the prev choice of beams
if reorder_state is not None:
for model in self.models:
model.decoder.reorder_incremental_state(reorder_state)
if isinstance(model.decoder, FairseqIncrementalDecoder):
model.decoder.reorder_incremental_state(reorder_state)
probs, avg_attn_scores = self._decode(tokens[:, :step+1], encoder_outs)
if step == 0:
......@@ -315,19 +319,16 @@ class SequenceGenerator(object):
return finalized
def _decode(self, tokens, encoder_outs):
length = tokens.size(1)
# repeat the first length positions to fill batch
positions = self.positions[:length].view(1, length)
# wrap in Variables
# wrap in Variable
tokens = Variable(tokens, volatile=True)
positions = Variable(positions, volatile=True)
avg_probs = None
avg_attn = None
for model, encoder_out in zip(self.models, encoder_outs):
decoder_out, attn = model.decoder(tokens, positions, encoder_out)
if isinstance(model.decoder, FairseqIncrementalDecoder):
decoder_out, attn = model.decoder.incremental_forward(tokens, encoder_out)
else:
decoder_out, attn = model.decoder.forward(tokens, encoder_out)
probs = F.softmax(decoder_out[:, -1, :]).data
attn = attn[:, -1, :].data
if avg_probs is None or avg_attn is None:
......
......@@ -151,6 +151,6 @@ def prepare_sample(sample, volatile=False, cuda_device=None):
'target': make_variable(sample['target']),
'net_input': {
key: make_variable(sample[key])
for key in ['src_tokens', 'src_positions', 'input_tokens', 'input_positions']
for key in ['src_tokens', 'input_tokens']
},
}
......@@ -52,19 +52,15 @@ def main():
if not args.interactive:
print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
# Max positions is the model property but it is needed in data reader to be able to
# ignore too long sentences
args.max_positions = min(args.max_positions, *(m.decoder.max_positions() for m in models))
# Optimize ensemble for generation
for model in models:
model.make_generation_fast_(not args.no_beamable_mm)
model.make_generation_fast_(
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)
# Initialize generator
translator = SequenceGenerator(
models, beam_size=args.beam, stop_early=(not args.no_early_stop),
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen
)
normalize_scores=(not args.unnormalized), len_penalty=args.lenpen)
if use_cuda:
translator.cuda()
......@@ -112,12 +108,9 @@ def main():
if args.interactive:
for line in sys.stdin:
tokens = tokenizer.Tokenizer.tokenize(line, dataset.src_dict, add_if_not_exist=False).long()
start = dataset.src_dict.pad() + 1
positions = torch.arange(start, start + len(tokens)).type_as(tokens)
if use_cuda:
positions = positions.cuda()
tokens = tokens.cuda()
translations = translator.generate(Variable(tokens.view(1, -1)), Variable(positions.view(1, -1)))
translations = translator.generate(Variable(tokens.view(1, -1)))
hypos = translations[0]
display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
......@@ -132,8 +125,9 @@ def main():
# Generate and compute BLEU score
scorer = bleu.Scorer(dataset.dst_dict.pad(), dataset.dst_dict.eos(), dataset.dst_dict.unk())
max_positions = min(model.max_encoder_positions() for model in models)
itr = dataset.dataloader(args.gen_subset, batch_size=args.batch_size,
max_positions=args.max_positions,
max_positions=max_positions,
skip_invalid_size_inputs_valid_test=args.skip_invalid_size_inputs_valid_test)
num_sentences = 0
with progress_bar(itr, smoothing=0, leave=False) as t:
......
......@@ -10,7 +10,7 @@ import torch
import unittest
from fairseq.modules import ConvTBC
import torch.nn as nn
from torch.autograd import Variable, gradcheck
from torch.autograd import Variable
class TestConvTBC(unittest.TestCase):
......@@ -31,7 +31,7 @@ class TestConvTBC(unittest.TestCase):
output1d = conv1d(input1d)
self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
grad_tbc = torch.randn(output_tbc.size())
grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
......