Commit b59815bc authored by Angela Fan, committed by Myle Ott

added multiscale gated self-attention layer with multiple heads, and pretrained fusion models

parent 50931d69
fairseq/models/__init__.py
@@ -12,6 +12,7 @@ from .fairseq_decoder import FairseqDecoder  # noqa: F401
 from .fairseq_encoder import FairseqEncoder  # noqa: F401
 from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
 from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
+from .composite_encoder import CompositeEncoder  # noqa: F401

 MODEL_REGISTRY = {}
 ARCH_MODEL_REGISTRY = {}

fairseq/models/composite_encoder.py (new file)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqEncoder

class CompositeEncoder(FairseqEncoder):
    """
    Encoder class that forwards on multiple encoders, for example for a fusion
    model or question answering.

    Accepts a dictionary of encoders; the first encoder's dictionary is used
    for initialization.
    """
    def __init__(self, encoders):
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        for key in self.encoders:
            self.add_module(key, self.encoders[key])

    def forward(self, src_tokens, src_lengths):
        encoder_out = {}
        for key in self.encoders:
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

    def max_positions(self):
        return min([self.encoders[key].max_positions() for key in self.encoders])

    def upgrade_state_dict(self, state_dict):
        for key in self.encoders:
            self.encoders[key].upgrade_state_dict(state_dict)
        return state_dict
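
A minimal usage sketch (hedged: `enc_a` and `enc_b` are hypothetical FairseqEncoder instances, and the dictionary keys are illustrative — the keys chosen here determine the keys of the returned dict):

    # enc_a and enc_b are placeholder FairseqEncoder instances
    encoder = CompositeEncoder({'encoder': enc_a, 'pretrained': enc_b})
    encoder_out = encoder(src_tokens, src_lengths)
    out_a = encoder_out['encoder']      # output of enc_a
    out_b = encoder_out['pretrained']   # output of enc_b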
[One file's diff is collapsed in the web view and is not shown here.]
fairseq/modules/__init__.py
@@ -8,19 +8,23 @@
 from .adaptive_softmax import AdaptiveSoftmax
 from .beamable_mm import BeamableMM
 from .conv_tbc import ConvTBC
+from .downsampled_multihead_attention import DownsampledMultiHeadAttention
 from .grad_multiply import GradMultiply
 from .learned_positional_embedding import LearnedPositionalEmbedding
 from .linearized_convolution import LinearizedConvolution
 from .multihead_attention import MultiheadAttention
+from .scalar_bias import ScalarBias
 from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

 __all__ = [
     'AdaptiveSoftmax',
     'BeamableMM',
     'ConvTBC',
+    'DownsampledMultiHeadAttention',
     'GradMultiply',
     'LearnedPositionalEmbedding',
     'LinearizedConvolution',
     'MultiheadAttention',
+    'ScalarBias',
     'SinusoidalPositionalEmbedding',
 ]

fairseq/modules/downsampled_multihead_attention.py (new file)
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from fairseq.modules.scalar_bias import scalar_bias

class SingleHeadAttention(nn.Module):
    """
    Single-head attention that supports Gating and Downsampling
    """
    def __init__(
        self,
        out_channels,
        embed_dim,
        head_dim,
        head_index,
        dropout=0.,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
        num_heads=1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.head_index = head_index
        self.head_dim = head_dim
        self.project_input = project_input
        self.gated = gated
        self.downsample = downsample
        self.num_heads = num_heads
        self.projection = None

        k_layers = []
        v_layers = []
        if self.downsample:
            # a downsampled head strides its keys/values and projects to a
            # single head_dim-sized output
            k_layers.append(Downsample(self.head_index))
            v_layers.append(Downsample(self.head_index))
            out_proj_size = self.head_dim
        else:
            # without downsampling, all heads share one projection of size
            # head_dim * num_heads
            out_proj_size = self.head_dim * self.num_heads
        if self.gated:
            k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
        else:
            k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
            self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
            v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))

        self.in_proj_k = nn.Sequential(*k_layers)
        self.in_proj_v = nn.Sequential(*v_layers)

        if self.downsample:
            self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
        else:
            self.out_proj = Linear(out_proj_size, out_channels, bias=bias)
        self.scaling = self.head_dim ** -0.5
    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        src_len, bsz, out_channels = key.size()
        tgt_len = query.size(0)
        assert list(query.size()) == [tgt_len, bsz, out_channels]
        assert key.size() == value.size()

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len
        if self.downsample:
            size = bsz
        else:
            size = bsz * self.num_heads

        k = key
        v = value
        q = query
        if self.project_input:
            q = self.in_proj_q(q)
            k = self.in_proj_k(k)
            v = self.in_proj_v(v)
            src_len = k.size()[0]
        q *= self.scaling

        if not self.downsample:
            # fold the head dimension into the batch dimension
            q = q.view(tgt_len, size, self.head_dim)
            k = k.view(src_len, size, self.head_dim)
            v = v.view(src_len, size, self.head_dim)

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        if mask_future_timesteps:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            # allow attention only to strictly earlier timesteps: zero out the
            # weights on and above the diagonal, then add -inf there; the
            # column stride matches the key stride of a downsampled head
            attn_weights *= Variable(torch.tril(
                attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
                diagonal=-1,
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
            attn_weights += Variable(torch.triu(
                attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
                diagonal=0,
            )[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0))
        tgt_size = tgt_len
        if use_scalar_bias:
            attn_weights = scalar_bias(attn_weights, 2)
            v = scalar_bias(v, 1)
            tgt_size += 1

        if key_padding_mask is not None:
            # don't attend to padding symbols
            if key_padding_mask.max() > 0:
                if self.downsample:
                    attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
                else:
                    attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2),
                    -math.inf,
                )
                attn_weights = attn_weights.view(size, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        if self.downsample:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
        attn = self.out_proj(attn)

        return attn, attn_weights
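
For intuition, here is the strided causal mask that the tril/triu pair above constructs, as a standalone snippet (illustrative values only; `stride` plays the role of `head_index + 1` for a downsampled head):

    import math
    import torch

    tgt_len, stride = 4, 2
    # allowed positions: strictly earlier timesteps, at the head's key stride
    keep = torch.tril(torch.ones(tgt_len, tgt_len), diagonal=-1)[:, ::stride]
    # disallowed positions are driven to -inf, so softmax gives them zero mass
    block = torch.triu(torch.full((tgt_len, tgt_len), -math.inf), diagonal=0)[:, ::stride]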

class DownsampledMultiHeadAttention(nn.ModuleList):
    """
    Multi-headed attention with Gating and Downsampling
    """
    def __init__(
        self,
        out_channels,
        embed_dim,
        num_heads,
        dropout=0.,
        bias=True,
        project_input=True,
        gated=False,
        downsample=False,
    ):
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.downsample = downsample
        self.gated = gated
        self.project_input = project_input
        assert self.head_dim * num_heads == embed_dim

        if self.downsample:
            # one SingleHeadAttention per head, each with its own stride
            attention_heads = []
            for index in range(self.num_heads):
                attention_heads.append(
                    SingleHeadAttention(
                        out_channels, self.embed_dim, self.head_dim, index,
                        self.dropout, bias, self.project_input, self.gated,
                        self.downsample, self.num_heads,
                    )
                )
            super().__init__(modules=attention_heads)
            self.out_proj = Linear(embed_dim, out_channels, bias=bias)
        else:
            # if not being downsampled, all heads can be computed with a single
            # set of linear layers instead of one module per head
            super().__init__()
            self.attention_module = SingleHeadAttention(
                out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
                bias, self.project_input, self.gated, self.downsample,
                self.num_heads,
            )
    def forward(
        self,
        query,
        key,
        value,
        mask_future_timesteps=False,
        key_padding_mask=None,
        use_scalar_bias=False,
    ):
        src_len, bsz, embed_dim = key.size()
        tgt_len = query.size(0)
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        tgt_size = tgt_len
        if use_scalar_bias:
            tgt_size += 1

        attn = []
        attn_weights = []
        if self.downsample:
            for attention_head_number in range(self.num_heads):
                # call the forward of each attention head
                _attn, _attn_weight = self[attention_head_number](
                    query, key, value, mask_future_timesteps,
                    key_padding_mask, use_scalar_bias,
                )
                attn.append(_attn)
                attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn = self.out_proj(full_attn)
            return full_attn, attn_weights[0].clone()
        else:
            _attn, _attn_weight = self.attention_module(
                query, key, value, mask_future_timesteps,
                key_padding_mask, use_scalar_bias,
            )
            attn.append(_attn)
            attn_weights.append(_attn_weight)
            full_attn = torch.cat(attn, dim=2)
            full_attn_weights = torch.cat(attn_weights)
            full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
            full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
            return full_attn, full_attn_weights

class Downsample(nn.Module):
    """
    Selects every nth element along the first (time) dimension, where
    n = index + 1
    """
    def __init__(self, index):
        super().__init__()
        self.index = index

    def forward(self, x):
        return x[::self.index + 1]
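
For intuition (illustrative values): head 0 keeps every timestep, head 1 every second, head 2 every third, which is what makes the gated attention multiscale:

    import torch
    x = torch.arange(6).view(6, 1, 1)   # (time, batch, channel)
    Downsample(0)(x).view(-1)           # tensor([0, 1, 2, 3, 4, 5])
    Downsample(1)(x).view(-1)           # tensor([0, 2, 4])
    Downsample(2)(x).view(-1)           # tensor([0, 3])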

def Linear(in_features, out_features, dropout=0., bias=True):
    """Weight-normalized Linear layer (input: B x T x C)"""
    m = nn.Linear(in_features, out_features, bias=bias)
    m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
    m.bias.data.zero_()
    return nn.utils.weight_norm(m)


def GatedLinear(in_features, out_features, dropout=0., bias=True):
    """Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
    return nn.Sequential(
        Linear(in_features, out_features * 4, dropout, bias),  # -> 4 * out_features
        nn.GLU(),                                              # GLU halves it to 2 * out_features
        Linear(out_features * 2, out_features * 2, dropout, bias),
        nn.GLU(),                                              # halves it to out_features
        Linear(out_features, out_features, dropout, bias),
    )
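
A minimal end-to-end sketch of the new module on random inputs (hedged: the sizes are illustrative; shapes follow the Time x Batch x Channel convention documented above):

    import torch
    from fairseq.modules import DownsampledMultiHeadAttention

    attn = DownsampledMultiHeadAttention(
        out_channels=16, embed_dim=16, num_heads=4,
        gated=True, downsample=False,
    )
    x = torch.rand(10, 2, 16)   # (seq_len, bsz, embed_dim)
    out, weights = attn(x, x, x, mask_future_timesteps=True)
    # out: (10, 2, 16); weights: (2, 10, 10), averaged over the 4 heads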

fairseq/modules/scalar_bias.py (new file)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch

class ScalarBias(torch.autograd.Function):
    """
    Adds a vector of scalars, used in self-attention mechanism to allow
    the model to optionally attend to this vector instead of the past
    """

    @staticmethod
    def forward(ctx, input, dim, bias_init):
        size = list(input.size())
        size[dim] += 1
        output = input.new(*size).fill_(bias_init)
        output.narrow(dim, 1, size[dim] - 1).copy_(input)
        ctx.dim = dim
        return output

    @staticmethod
    def backward(ctx, grad):
        # the gradient for the prepended bias slot is discarded
        return grad.narrow(ctx.dim, 1, grad.size(ctx.dim) - 1), None, None


def scalar_bias(input, dim, bias_init=0):
    return ScalarBias.apply(input, dim, bias_init)
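
Concretely (illustrative values): one extra slot filled with `bias_init` is prepended along `dim`; in the self-attention above, this gives the model a "null" position it can attend to instead of the past:

    import torch
    x = torch.ones(2, 3)
    y = scalar_bias(x, dim=1)   # y.shape == (2, 4)
    # y[:, 0] is the bias slot (all zeros here); y[:, 1:] equals x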
fairseq/options.py
@@ -232,8 +232,8 @@ def add_checkpoint_args(parser):

 def add_common_eval_args(group):
-    group.add_argument('--path', metavar='FILE', action='append',
-                       help='path(s) to model file(s)')
+    group.add_argument('--path', metavar='FILE',
+                       help='path(s) to model file(s), comma separated')
     group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                        help='remove BPE tokens before scoring')
     group.add_argument('--cpu', action='store_true', help='generate on CPU')
@@ -259,6 +259,8 @@ def add_generation_args(parser):
     group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                        help=('generate sequences of maximum length ax + b, '
                              'where x is the source length'))
+    group.add_argument('--min-len', default=1, type=float, metavar='N',
+                       help=('minimum generation length'))
     group.add_argument('--no-early-stop', action='store_true',
                        help=('continue searching even after finalizing k=beam '
                              'hypotheses; this is more correct, but increases '
@@ -279,6 +281,10 @@ def add_generation_args(parser):
                        help='initialize generation by target prefix of given length')
     group.add_argument('--sampling', action='store_true',
                        help='sample hypotheses instead of using beam search')
+    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
+                       help='sample from top K likely next words instead of all words')
+    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
+                       help='temperature for random sampling')
     return group

fairseq/sequence_generator.py
@@ -15,9 +15,9 @@ from fairseq.models import FairseqIncrementalDecoder
 class SequenceGenerator(object):
     def __init__(self, models, beam_size=1, minlen=1, maxlen=None,
                  stop_early=True, normalize_scores=True, len_penalty=1,
-                 unk_penalty=0, retain_dropout=False, sampling=False):
+                 unk_penalty=0, retain_dropout=False, sampling=False, sampling_topk=-1,
+                 sampling_temperature=1):
         """Generates translations of a given source sentence.

         Args:
             min/maxlen: The length of the generated output will be bounded by
                 minlen and maxlen (not including the end-of-sentence marker).
@@ -45,6 +45,8 @@ class SequenceGenerator(object):
         self.unk_penalty = unk_penalty
         self.retain_dropout = retain_dropout
         self.sampling = sampling
+        self.sampling_topk = sampling_topk
+        self.sampling_temperature = sampling_temperature

     def cuda(self):
         for model in self.models:
@@ -54,7 +56,6 @@ class SequenceGenerator(object):
     def generate_batched_itr(self, data_itr, beam_size=None, maxlen_a=0.0, maxlen_b=None,
                              cuda=False, timer=None, prefix_size=0):
         """Iterate over a batched dataset and yield individual translations.
-
         Args:
             maxlen_a/b: generate sequences of maximum length ax + b,
                 where x is the source sentence length.
@@ -169,11 +170,9 @@ class SequenceGenerator(object):
         """
         Finalize the given hypotheses at this step, while keeping the total
         number of finalized hypotheses per sentence <= beam_size.
-
         Note: the input must be in the desired finalization order, so that
             hypotheses that appear earlier in the input are preferred to those
             that appear later.
-
         Args:
             step: current time step
             bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
@@ -221,7 +220,6 @@ class SequenceGenerator(object):
                 # remove padding tokens from attn scores
                 nonpad_idxs = src_tokens[sent].ne(self.pad)
                 hypo_attn = attn_clone[i][nonpad_idxs]
-
                 _, alignment = hypo_attn.max(dim=0)

                 return {
@@ -303,15 +301,29 @@ class SequenceGenerator(object):
                 cand_beams.resize_as_(cand_indices).fill_(0)
             elif self.sampling:
                 assert self.pad == 1, 'sampling assumes the first two symbols can be ignored'
-                exp_probs = probs.exp_().view(-1, self.vocab_size)
+                if self.sampling_topk > 0:
+                    values, indices = probs[:, 2:].topk(self.sampling_topk)
+                    exp_probs = values.div_(self.sampling_temperature).exp()
+                    if step == 0:
+                        torch.multinomial(exp_probs, beam_size, replacement=True, out=cand_indices)
+                    else:
+                        torch.multinomial(exp_probs, 1, replacement=True, out=cand_indices)
+                    torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
+                    torch.gather(indices, dim=1, index=cand_indices, out=cand_indices)
+                    cand_indices.add_(2)
+                else:
+                    exp_probs = probs.div_(self.sampling_temperature).exp_().view(-1, self.vocab_size)
                     if step == 0:
                         # we exclude the first two vocab items, one of which is pad
                         torch.multinomial(exp_probs[:, 2:], beam_size, replacement=True, out=cand_indices)
+                        cand_indices.add_(2)
                     else:
                         torch.multinomial(exp_probs[:, 2:], 1, replacement=True, out=cand_indices)
+                        cand_indices.add_(2)
                     torch.gather(exp_probs, dim=1, index=cand_indices, out=cand_scores)
                 cand_scores.log_()
                 cand_indices = cand_indices.view(bsz, -1).repeat(1, 2)
                 cand_scores = cand_scores.view(bsz, -1).repeat(1, 2)
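
The same top-k/temperature idea as a standalone sketch in current PyTorch (illustrative only, not the fairseq code path; `sample_topk` is a made-up helper name):

    import torch

    def sample_topk(lprobs, k, temperature=1.0, num_samples=1):
        # lprobs: (bsz, vocab) log-probabilities
        values, vocab_idx = lprobs.topk(k, dim=1)     # restrict to the k most likely words
        weights = (values / temperature).exp()        # temperature-scaled, unnormalized
        choice = torch.multinomial(weights, num_samples, replacement=True)
        tokens = vocab_idx.gather(1, choice)          # map back to full-vocab indices
        scores = weights.gather(1, choice).log()
        return tokens, scores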
@@ -489,6 +501,7 @@ class SequenceGenerator(object):
         avg_probs = None
         avg_attn = None
         for model, encoder_out in zip(self.models, encoder_outs):
             with utils.maybe_no_grad():
                 if incremental_states[model] is not None:
@@ -497,6 +510,7 @@ class SequenceGenerator(object):
                 decoder_out = list(model.decoder(tokens, encoder_out))
                 decoder_out[0] = decoder_out[0][:, -1, :]
                 attn = decoder_out[1]
             probs = model.get_normalized_probs(decoder_out, log_probs=False).data
             if avg_probs is None:
                 avg_probs = probs

fairseq/utils.py
@@ -157,7 +157,7 @@ def load_ensemble_for_inference(filenames, src_dict=None, dst_dict=None,
     ensemble = []
     for state in states:
         model = models.build_model(args, src_dict, dst_dict)
-        model.load_state_dict(state['model'])
+        model.load_state_dict(state['model'], strict=True)
         ensemble.append(model)

     return ensemble, args

generate.py
@@ -31,8 +31,9 @@ def main(args):
     dataset = data_loaders.load_dataset(args, [args.gen_subset], args.replace_unk is not None)

     # Load ensemble
-    print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)
+    print('| loading model(s) from {}'.format(args.path))
+    model_paths = args.path.split(',')
+    models, _ = utils.load_ensemble_for_inference(model_paths, dataset.src_dict, dataset.dst_dict)

     print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
     print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))
@@ -70,7 +71,9 @@ def main(args):
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
-        unk_penalty=args.unkpen, sampling=args.sampling)
+        unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
+        sampling_temperature=args.sampling_temperature,
+        minlen=args.min_len)
     if use_cuda:
         translator.cuda()

interactive.py
@@ -64,8 +64,9 @@ def main(args):
     use_cuda = torch.cuda.is_available() and not args.cpu

     # Load ensemble
-    print('| loading model(s) from {}'.format(', '.join(args.path)))
-    models, model_args = utils.load_ensemble_for_inference(args.path, data_dir=args.data)
+    print('| loading model(s) from {}'.format(args.path))
+    model_paths = args.path.split(',')
+    models, model_args = utils.load_ensemble_for_inference(model_paths, data_dir=args.data)
     src_dict, dst_dict = models[0].src_dict, models[0].dst_dict

     print('| [{}] dictionary: {} types'.format(model_args.source_lang, len(src_dict)))
@@ -81,7 +82,9 @@ def main(args):
     translator = SequenceGenerator(
         models, beam_size=args.beam, stop_early=(not args.no_early_stop),
         normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
-        unk_penalty=args.unkpen, sampling=args.sampling)
+        unk_penalty=args.unkpen, sampling=args.sampling, sampling_topk=args.sampling_topk,
+        sampling_temperature=args.sampling_temperature,
+        minlen=args.min_len)
     if use_cuda:
         translator.cuda()

train.py
@@ -40,8 +40,8 @@ def main(args):
     for split in splits:
         print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

     # Build model and criterion
     model = models.build_model(args, dataset.src_dict, dataset.dst_dict)
     criterion = criterions.build_criterion(args, dataset.src_dict, dataset.dst_dict)
     print('| model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
     print('| num. model params: {}'.format(sum(p.data.numel() for p in model.parameters())))