Commit cfd2a3a0 authored by Alexei Baevski, committed by Myle Ott

core changes to support latte collab

parent fbe8ce65
@@ -42,11 +42,14 @@ class AdaptiveLoss(FairseqCriterion):
         adaptive_softmax = model.decoder.adaptive_softmax
         net_output = model(**sample['net_input'])
-        target = model.get_targets(sample, net_output).view(-1)
-        bsz = target.size(0)
+        orig_target = model.get_targets(sample, net_output)
+
+        nsentences = orig_target.size(0)
+        orig_target = orig_target.view(-1)
 
-        logits, target = adaptive_softmax(net_output[0], target)
+        bsz = orig_target.size(0)
+
+        logits, target = adaptive_softmax(net_output[0], orig_target)
         assert len(target) == len(logits)
 
         loss = net_output[0].new(1 if reduce else bsz).zero_()
@@ -57,11 +60,13 @@ class AdaptiveLoss(FairseqCriterion):
             loss += F.cross_entropy(logits[i], target[i], size_average=False, ignore_index=self.padding_idx,
                                     reduce=reduce)
 
-        sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
+        orig = utils.strip_pad(orig_target, self.padding_idx)
+        ntokens = orig.numel()
+        sample_size = sample['target'].size(0) if self.args.sentence_avg else ntokens
        logging_output = {
             'loss': utils.item(loss.data) if reduce else loss.data,
-            'ntokens': sample['ntokens'],
-            'nsentences': sample['target'].size(0),
+            'ntokens': ntokens,
+            'nsentences': nsentences,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
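The criterion now counts tokens by stripping padding from the flattened targets instead of trusting sample['ntokens']. A quick standalone sketch of that computation (padding index assumed to be 1):

    import torch
    from fairseq import utils

    orig_target = torch.LongTensor([5, 6, 2, 1, 1])    # flattened targets, padded with 1s
    ntokens = utils.strip_pad(orig_target, 1).numel()  # strip_pad drops every pad entry
    assert ntokens == 3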
@@ -5,7 +5,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-from .dictionary import Dictionary
+from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
 from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
@@ -199,3 +199,20 @@ class Dictionary(object):
         t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
         t[-1] = self.eos()
         return t
+
+
+class TruncatedDictionary(object):
+    def __init__(self, wrapped_dict, length):
+        self.__class__ = type(wrapped_dict.__class__.__name__,
+                              (self.__class__, wrapped_dict.__class__), {})
+        self.__dict__ = wrapped_dict.__dict__
+        self.wrapped_dict = wrapped_dict
+        self.length = min(len(self.wrapped_dict), length)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        if i < self.length:
+            return self.wrapped_dict[i]
+        return self.wrapped_dict.unk()
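A small usage sketch of TruncatedDictionary (illustrative only, not part of the commit): the wrapped dictionary keeps its full input-side behaviour, while lookups past the cutoff fall back to the unk index.

    from fairseq.data import Dictionary, TruncatedDictionary

    d = Dictionary()
    for w in ['apple', 'banana', 'cherry']:
        d.add_symbol(w)

    t = TruncatedDictionary(d, length=len(d) - 1)  # drop the last symbol from the output side
    assert len(t) == len(d) - 1
    assert t[0] == d[0]              # in-range indices pass through to the wrapped dictionary
    assert t[len(d) - 1] == d.unk()  # out-of-range indices come back as the unk index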
@@ -61,6 +61,7 @@ def collate(
             'src_lengths': src_lengths,
         },
         'target': target,
+        'nsentences': samples[0]['source'].size(0),
     }
     if prev_output_tokens is not None:
         batch['net_input']['prev_output_tokens'] = prev_output_tokens
@@ -9,27 +9,39 @@ import numpy as np
 import torch
 
 from . import data_utils, FairseqDataset
+from typing import List
 
 
 def collate(samples, pad_idx, eos_idx):
     if len(samples) == 0:
         return {}
 
-    def merge(key):
-        return data_utils.collate_tokens(
-            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
-        )
+    def merge(key, is_list=False):
+        if is_list:
+            res = []
+            for i in range(len(samples[0][key])):
+                res.append(data_utils.collate_tokens(
+                    [s[key][i] for s in samples], pad_idx, eos_idx, left_pad=False,
+                ))
+            return res
+        else:
+            return data_utils.collate_tokens(
+                [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
+            )
+
+    is_target_list = isinstance(samples[0]['target'], list)
 
     return {
         'id': torch.LongTensor([s['id'] for s in samples]),
-        'ntokens': sum(len(s['target']) for s in samples),
+        'ntokens': sum(len(s['source']) for s in samples),
         'net_input': {
             'src_tokens': merge('source'),
             'src_lengths': torch.LongTensor([
                 s['source'].numel() for s in samples
             ]),
         },
-        'target': merge('target'),
+        'target': merge('target', is_target_list),
+        'nsentences': samples[0]['source'].size(0),
     }
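The merge(..., is_list=True) branch pads each target position separately, so a sample can carry several targets (e.g. future and past) at once. A rough sketch of what that produces, with pad=1 and eos=2 assumed:

    import torch
    from fairseq.data import data_utils

    targets_per_sample = [
        [torch.LongTensor([5, 6, 2]), torch.LongTensor([1, 4, 5])],  # sample 0: [future, past]
        [torch.LongTensor([8, 2]), torch.LongTensor([1, 7])],        # sample 1: [future, past]
    ]
    batched = [
        data_utils.collate_tokens([t[i] for t in targets_per_sample], 1, 2, left_pad=False)
        for i in range(len(targets_per_sample[0]))
    ]
    assert len(batched) == 2 and batched[0].shape == (2, 3)  # one padded tensor per target type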
@@ -42,22 +54,78 @@ class MonolingualDataset(FairseqDataset):
         sizes (List[int]): sentence lengths
         vocab (~fairseq.data.Dictionary): vocabulary
         shuffle (bool, optional): shuffle the elements before batching.
             Default: ``True``
     """
 
-    def __init__(self, dataset, sizes, vocab, shuffle=True):
+    def __init__(self, dataset, sizes, src_vocab, tgt_vocab, add_eos_for_other_targets, shuffle,
+                 targets=None):
         self.dataset = dataset
         self.sizes = np.array(sizes)
-        self.vocab = vocab
+        self.vocab = src_vocab
+        self.tgt_vocab = tgt_vocab
+        self.add_eos_for_other_targets = add_eos_for_other_targets
         self.shuffle = shuffle
 
+        assert targets is None or all(
+            t in {'self', 'future', 'past'} for t in targets), "targets must be none or one of 'self', 'future', 'past'"
+        if targets is not None and len(targets) == 0:
+            targets = None
+        self.targets = targets
+
     def __getitem__(self, index):
-        source, target = self.dataset[index]
+        source, future_target, past_target = self.dataset[index]
+        source, target = self._make_source_target(source, future_target, past_target)
         return {'id': index, 'source': source, 'target': target}
 
     def __len__(self):
         return len(self.dataset)
 
+    def _make_source_target(self, source, future_target, past_target):
+        if self.targets is not None:
+            target = []
+
+            if self.add_eos_for_other_targets and (('self' in self.targets) or ('past' in self.targets)) \
+                    and source[-1] != self.vocab.eos():
+                # append eos at the end of source
+                source = torch.cat([source, source.new([self.vocab.eos()])])
+
+                if 'future' in self.targets:
+                    future_target = torch.cat([future_target, future_target.new([self.vocab.pad()])])
+                if 'past' in self.targets:
+                    # first token is before the start of sentence which is only used in "none" break mode when
+                    # add_eos_for_other_targets is False
+                    past_target = torch.cat([past_target.new([self.vocab.pad()]), past_target[1:], source[-2, None]])
+
+            for t in self.targets:
+                if t == 'self':
+                    target.append(source)
+                elif t == 'future':
+                    target.append(future_target)
+                elif t == 'past':
+                    target.append(past_target)
+                else:
+                    raise Exception('invalid target ' + t)
+
+            if len(target) == 1:
+                target = target[0]
+        else:
+            target = future_target
+
+        return source, self._filter_vocab(target)
+
+    def _filter_vocab(self, target):
+        if len(self.tgt_vocab) != len(self.vocab):
+            def _filter(target):
+                mask = target.ge(len(self.tgt_vocab))
+                if mask.any():
+                    target[mask] = self.tgt_vocab.unk()
+                return target
+
+            if isinstance(target, list):
+                return [_filter(t) for t in target]
+            return _filter(target)
+        return target
+
     def collater(self, samples):
         """Merge a list of samples to form a mini-batch.
@@ -86,8 +154,10 @@ class MonolingualDataset(FairseqDataset):
         if isinstance(max_positions, float) or isinstance(max_positions, int):
             tgt_len = min(tgt_len, max_positions)
         bsz = num_tokens // tgt_len
-        target = self.vocab.dummy_sentence(tgt_len + 1)
-        source, target = target[:-1], target[1:]
+        target = self.vocab.dummy_sentence(tgt_len + 2)
+        source, past_target, future_target = target[1:-1], target[2:], target[:-2]
+        source, target = self._make_source_target(source, past_target, future_target)
         return self.collater([
             {'id': i, 'source': source, 'target': target}
             for i in range(bsz)
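When --output-dictionary-size is in effect, _filter_vocab maps every target id that falls outside the smaller output dictionary to unk. A minimal sketch of that masking with assumed sizes:

    import torch

    output_vocab_size, unk = 8, 3
    target = torch.LongTensor([4, 7, 11, 9, 2])
    mask = target.ge(output_vocab_size)  # ids outside the output dictionary...
    target[mask] = unk                   # ...are replaced by unk, as in _filter_vocab
    assert target.tolist() == [4, 7, 3, 3, 2]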
@@ -29,11 +29,13 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         include_targets: return next tokens as targets
     """
 
-    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
+    def __init__(self, tokens, sizes, block_size, pad, eos, break_mode=None, include_targets=False):
         super().__init__()
 
         self.tokens = tokens
         self.total_size = len(tokens)
+        self.pad = pad
+        self.eos = eos
         self.include_targets = include_targets
         self.slice_indices = []
@@ -81,12 +83,18 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         if self.include_targets:
             # target is the sentence, for source, rotate item one token to the left (would start with eos)
+            # past target is rotated to the left by 2 (padded if its first)
             if s == 0:
-                source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
+                source = np.concatenate([[self.eos], self.tokens[0:e - 1]])
+                past_target = np.concatenate([[self.pad, self.eos], self.tokens[0:e - 2]])
             else:
                 source = self.tokens[s - 1:e - 1]
+                if s == 1:
+                    past_target = np.concatenate([[self.eos], self.tokens[0:e - 2]])
+                else:
+                    past_target = self.tokens[s - 2:e - 2]
 
-            return torch.LongTensor(source), item
+            return torch.LongTensor(source), item, torch.LongTensor(past_target)
         return item
 
     def __len__(self):
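A toy check (assumed token stream, eos=2) of how source, the target block, and the new past_target line up for a block in the middle of the stream; each is the block shifted back by zero, one, and two tokens respectively:

    import numpy as np

    tokens = np.array([4, 5, 2, 6, 7, 8, 2])  # "4 5 </s> 6 7 8 </s>"
    s, e = 3, 7                               # block covering "6 7 8 </s>"
    item = tokens[s:e]                        # future target: 6 7 8 </s>
    source = tokens[s - 1:e - 1]              # shifted back by one: </s> 6 7 8
    past_target = tokens[s - 2:e - 2]         # shifted back by two: 5 </s> 6 7
    assert source.tolist() == [2, 6, 7, 8]
    assert past_target.tolist() == [5, 2, 6, 7]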
@@ -65,6 +65,9 @@ class BaseFairseqModel(nn.Module):
     def upgrade_state_dict(self, state_dict):
         """Upgrade old state dicts to work with newer code."""
+        self.upgrade_state_dict_named(state_dict, '')
+
+    def upgrade_state_dict_named(self, state_dict, name):
         assert state_dict is not None
 
         def do_upgrade(m, prefix):
@@ -79,7 +82,7 @@ class BaseFairseqModel(nn.Module):
                     c.upgrade_state_dict(state_dict)
                     do_upgrade(c, name)
 
-        do_upgrade(self, '')
+        do_upgrade(self, name)
 
     def make_generation_fast_(self, **kwargs):
         """Optimize model for faster generation."""
@@ -196,3 +199,7 @@ class FairseqLanguageModel(BaseFairseqModel):
     def max_positions(self):
         """Maximum length supported by the model."""
         return self.decoder.max_positions()
+
+    @property
+    def supported_targets(self):
+        return {'future'}
@@ -213,7 +213,7 @@ class TransformerLanguageModel(FairseqLanguageModel):
         else:
             embed_tokens = Embedding(len(task.dictionary), args.decoder_input_dim, task.dictionary.pad())
 
-        decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
+        decoder = TransformerDecoder(args, task.output_dictionary, embed_tokens, no_encoder_attn=True, final_norm=False)
         return TransformerLanguageModel(decoder)
@@ -442,6 +442,8 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         x = x.transpose(0, 1)
         attn = None
 
+        inner_states = [x]
+
         # decoder layers
         for layer in self.layers:
             x, attn = layer(
@@ -449,7 +451,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 encoder_out['encoder_out'] if encoder_out is not None else None,
                 encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
+                self_attn_mask=self.buffered_future_mask(x) if incremental_state is None else None,
             )
+            inner_states.append(x)
 
         if self.normalize:
             x = self.layer_norm(x)
@@ -467,7 +471,7 @@ class TransformerDecoder(FairseqIncrementalDecoder):
         else:
             x = F.linear(x, self.embed_out)
 
-        return x, attn
+        return x, {'attn': attn, 'inner_states': inner_states}
 
     def max_positions(self):
         """Maximum output length supported by the decoder."""
@@ -475,6 +479,14 @@ class TransformerDecoder(FairseqIncrementalDecoder):
             return self.max_target_positions
         return min(self.max_target_positions, self.embed_positions.max_positions())
 
+    def buffered_future_mask(self, tensor):
+        dim = tensor.size(0)
+        if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
+        if self._future_mask.size(0) < dim:
+            self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1)
+        return self._future_mask[:dim, :dim]
+
     def upgrade_state_dict(self, state_dict):
         """Upgrade a (possibly old) state dict for new versions of fairseq."""
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -615,7 +627,8 @@ class TransformerDecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)
         self.need_attn = True
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask=None,
+                self_attn_padding_mask=None):
         """
         Args:
             x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -631,9 +644,10 @@ class TransformerDecoderLayer(nn.Module):
             query=x,
             key=x,
             value=x,
-            mask_future_timesteps=True,
+            key_padding_mask=self_attn_padding_mask,
             incremental_state=incremental_state,
             need_weights=False,
+            attn_mask=self_attn_mask,
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
@@ -728,7 +742,6 @@ def base_lm_architecture(args):
     # The model training is not stable without this
     args.decoder_normalize_before = True
 
-
 @register_model_architecture('transformer_lm', 'transformer_lm_big')
 def transformer_lm_big(args):
     args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
@@ -740,7 +753,7 @@ def transformer_lm_big(args):
 @register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
 def transformer_lm_wiki103(args):
     args.dropout = getattr(args, 'dropout', 0.3)
-    base_lm_architecture(args)
+    transformer_lm_big(args)
 
 
 @register_model_architecture('transformer_lm', 'transformer_lm_gbw')
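The causal mask now lives in the decoder (buffered_future_mask) and is handed to each layer as self_attn_mask. A standalone sketch of the mask it builds, using torch.full in place of utils.fill_with_neg_inf:

    import torch

    dim = 4
    future_mask = torch.triu(torch.full((dim, dim), float('-inf')), 1)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    # Added to the (tgt_len x tgt_len) attention scores before the softmax, it blocks
    # attention to future positions; with incremental_state set, no mask is passed.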
@@ -18,23 +18,31 @@ class MultiheadAttention(nn.Module):
     See "Attention Is All You Need" for more details.
     """
 
-    def __init__(self, embed_dim, num_heads, dropout=0., bias=True):
+    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
         super().__init__()
         self.embed_dim = embed_dim
         self.num_heads = num_heads
         self.dropout = dropout
         self.head_dim = embed_dim // num_heads
         assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
-        self.scaling = self.head_dim**-0.5
-        self._mask = None
+        self.scaling = self.head_dim ** -0.5
 
-        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
+        if add_bias_kv:
+            self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+            self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+        else:
+            self.bias_k = self.bias_v = None
+
+        self.add_zero_attn = add_zero_attn
+
         self.reset_parameters()
 
     def reset_parameters(self):
@@ -43,15 +51,18 @@ class MultiheadAttention(nn.Module):
         if self.in_proj_bias is not None:
             nn.init.constant_(self.in_proj_bias, 0.)
             nn.init.constant_(self.out_proj.bias, 0.)
+        if self.bias_k is not None:
+            nn.init.xavier_normal_(self.bias_k)
+        if self.bias_v is not None:
+            nn.init.xavier_normal_(self.bias_v)
 
-    def forward(self, query, key, value, mask_future_timesteps=False,
-                key_padding_mask=None, incremental_state=None,
-                need_weights=True, static_kv=False):
+    def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
+                need_weights=True, static_kv=False, attn_mask=None):
         """Input shape: Time x Batch x Channel
 
         Self-attention can be implemented by passing in the same arguments for
-        query, key and value. Future timesteps can be masked with the
-        `mask_future_timesteps` argument. Padding elements can be excluded from
+        query, key and value. Timesteps can be masked by supplying a T x T mask in the
+        `attn_mask` argument. Padding elements can be excluded from
         the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
         batch x src_len, where padding elements are indicated by 1s.
         """
@@ -103,24 +114,40 @@ class MultiheadAttention(nn.Module):
             saved_state['prev_value'] = v
             self._set_input_buffer(incremental_state, saved_state)
 
+        if self.bias_k is not None:
+            assert self.bias_v is not None
+            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat(
+                    [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
         src_len = k.size(0)
 
         if key_padding_mask is not None:
             assert key_padding_mask.size(0) == bsz
             assert key_padding_mask.size(1) == src_len
 
-        q = q.contiguous().view(tgt_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
-        k = k.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
-        v = v.contiguous().view(src_len, bsz*self.num_heads, self.head_dim).transpose(0, 1)
+        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+        if self.add_zero_attn:
+            src_len += 1
+            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+            if attn_mask is not None:
+                attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+            if key_padding_mask is not None:
+                key_padding_mask = torch.cat([key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
 
         attn_weights = torch.bmm(q, k.transpose(1, 2))
         assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
 
-        # only apply masking at training time (when incremental state is None)
-        if mask_future_timesteps and incremental_state is None:
-            assert query.size() == key.size(), \
-                'mask_future_timesteps only applies to self-attention'
-            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
+        if attn_mask is not None:
+            attn_weights += attn_mask.unsqueeze(0)
+
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
@@ -129,6 +156,7 @@ class MultiheadAttention(nn.Module):
                 float('-inf'),
             ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
         attn_weights = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
@@ -156,10 +184,10 @@ class MultiheadAttention(nn.Module):
         return self._in_proj(query, end=self.embed_dim)
 
     def in_proj_k(self, key):
-        return self._in_proj(key, start=self.embed_dim, end=2*self.embed_dim)
+        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
 
     def in_proj_v(self, value):
-        return self._in_proj(value, start=2*self.embed_dim)
+        return self._in_proj(value, start=2 * self.embed_dim)
 
     def _in_proj(self, input, start=0, end=None):
         weight = self.in_proj_weight
@@ -169,14 +197,6 @@ class MultiheadAttention(nn.Module):
             bias = bias[start:end]
         return F.linear(input, weight, bias)
 
-    def buffered_mask(self, tensor):
-        dim = tensor.size(-1)
-        if self._mask is None:
-            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
-        if self._mask.size(0) < dim:
-            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
-        return self._mask[:dim, :dim]
-
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
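A hedged usage sketch of the reworked MultiheadAttention interface (shapes only, keyword names following the signature above): the caller now supplies an explicit attn_mask instead of the removed mask_future_timesteps flag.

    import torch
    from fairseq.modules import MultiheadAttention

    seq_len, bsz, embed_dim, heads = 5, 2, 16, 4
    attn = MultiheadAttention(embed_dim, heads, dropout=0.0, add_bias_kv=False, add_zero_attn=False)
    x = torch.rand(seq_len, bsz, embed_dim)
    future_mask = torch.triu(torch.full((seq_len, seq_len), float('-inf')), 1)
    out, _ = attn(query=x, key=x, value=x, attn_mask=future_mask, need_weights=False)
    assert out.shape == (seq_len, bsz, embed_dim)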
@@ -106,8 +106,9 @@ class FP16Optimizer(optim.FairseqOptimizer):
         for p in self.params:
             if not p.requires_grad:
                 continue
-            numel = p.grad.data.numel()
-            self.fp32_params.grad.data[offset:offset+numel].copy_(p.grad.data.view(-1))
+            grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
+            numel = grad_data.numel()
+            self.fp32_params.grad.data[offset:offset+numel].copy_(grad_data.view(-1))
             offset += numel
 
         # correct for dynamic loss scaler
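The fp16 optimizer now tolerates parameters whose .grad is None on a given step by substituting zeros of the right shape, so the flattened fp32 gradient buffer stays aligned. A self-contained sketch of that fallback:

    import torch

    p = torch.nn.Parameter(torch.rand(3, 2))  # p.grad is None until backward touches it
    grad_data = p.grad.data if p.grad is not None else p.data.new_zeros(p.data.shape)
    assert grad_data.shape == p.shape and grad_data.sum() == 0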
@@ -507,7 +507,11 @@ class SequenceGenerator(object):
             decoder_out = list(model.decoder(tokens, encoder_out))
             decoder_out[0] = decoder_out[0][:, -1, :]
             attn = decoder_out[1]
+            if type(attn) is dict:
+                attn = attn['attn']
             if attn is not None:
+                if type(attn) is dict:
+                    attn = attn['attn']
                 attn = attn[:, -1, :]
             probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
             return probs, attn
@@ -13,7 +13,7 @@ from torch.utils.data import ConcatDataset
 from fairseq.data import (
     Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
-    MonolingualDataset, TokenBlockDataset,
+    MonolingualDataset, TokenBlockDataset, TruncatedDictionary
 )
 
 from . import FairseqTask, register_task
@@ -25,7 +25,14 @@ class LanguageModelingTask(FairseqTask):
     Train a language model.
 
     Args:
-        dictionary (Dictionary): the dictionary for the language model
+        dictionary (Dictionary): the dictionary for the input of the language model
+        output_dictionary (Dictionary): the dictionary for the output of the language model.
+            In most cases it will be the same as dictionary, but could possibly be a more limited
+            version of the dictionary (if --output-dictionary-size is used).
+        targets (List[str]): list of the target types that the language model should predict.
+            Can be one of "self", "future", and "past". Defaults to "future".
 
     .. note::
@@ -55,10 +62,23 @@ class LanguageModelingTask(FairseqTask):
                             help='max number of tokens per sample for LM dataset')
         parser.add_argument('--raw-text', default=False, action='store_true',
                             help='load raw text dataset')
+        parser.add_argument('--output-dictionary-size', default=-1, type=int,
+                            help='limit the size of output dictionary')
+        parser.add_argument('--self-target', action='store_true',
+                            help='include self target')
+        parser.add_argument('--future-target', action='store_true',
+                            help='include future target')
+        parser.add_argument('--past-target', action='store_true',
+                            help='include past target')
 
-    def __init__(self, args, dictionary):
+    def __init__(self, args, dictionary, output_dictionary, targets=None):
         super().__init__(args)
         self.dictionary = dictionary
+        self.output_dictionary = output_dictionary
+
+        if targets is None:
+            targets = ['future']
+        self.targets = targets
 
     @classmethod
     def setup_task(cls, args, **kwargs):
@@ -69,7 +89,36 @@ class LanguageModelingTask(FairseqTask):
         """
         dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
         print('| dictionary: {} types'.format(len(dictionary)))
-        return cls(args, dictionary)
+
+        output_dictionary = dictionary
+        if args.output_dictionary_size >= 0:
+            output_dictionary = TruncatedDictionary(dictionary, args.output_dictionary_size)
+
+        # upgrade old checkpoints
+        if hasattr(args, 'exclude_self_target'):
+            args.self_target = not args.exclude_self_target
+
+        targets = []
+        if args.self_target:
+            targets.append('self')
+        if args.future_target:
+            targets.append('future')
+        if args.past_target:
+            targets.append('past')
+        if len(targets) == 0:
+            # standard language modeling
+            targets = ['future']
+
+        return cls(args, dictionary, output_dictionary, targets=targets)
+
+    def build_model(self, args):
+        model = super().build_model(args)
+
+        for target in self.targets:
+            if target not in model.supported_targets:
+                raise ValueError('Unsupported language modeling target: {}'.format(target))
+
+        return model
 
     def load_dataset(self, split, combine=False):
         """Load a given dataset split.
@@ -98,8 +147,8 @@ class LanguageModelingTask(FairseqTask):
             loaded_datasets.append(
                 TokenBlockDataset(
-                    tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
-                    include_targets=True
+                    tokens, ds.sizes, self.args.tokens_per_sample, pad=self.dictionary.pad(), eos=self.dictionary.eos(),
+                    break_mode=self.args.sample_break_mode, include_targets=True,
                 ))
 
             print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
@@ -114,10 +163,16 @@ class LanguageModelingTask(FairseqTask):
             dataset = ConcatDataset(loaded_datasets)
             sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
 
-        self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
+        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'
+
+        self.datasets[split] = MonolingualDataset(
+            dataset, sizes, self.dictionary, self.output_dictionary,
+            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=False,
+            targets=self.targets,
+        )
 
     @property
     def target_dictionary(self):
         """Return the :class:`~fairseq.data.Dictionary` for the language
         model."""
-        return self.dictionary
+        return self.output_dictionary
@@ -40,7 +40,7 @@ def mock_dict():
 def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
     tokens = torch.LongTensor(list(range(epoch_size)))
-    tokens_ds = data.TokenBlockDataset(tokens, [len(tokens)], 1, include_targets=False)
+    tokens_ds = data.TokenBlockDataset(tokens, sizes=[len(tokens)], block_size=1, pad=0, eos=1, include_targets=False)
     trainer = mock_trainer(epoch, num_updates, iterations_in_epoch)
     dataset = data.LanguagePairDataset(tokens_ds, tokens_ds.sizes, mock_dict(), shuffle=False)
     epoch_itr = data.EpochBatchIterator(