Commit d2e2a1d4 authored by Alexei Baevski, committed by Myle Ott

Transformer lm

This implements a Transformer-based language model. It already obtains better perplexity on WikiText-103 without any tuning. I will also train it on GBW, where I also expect to get better perplexity.

Example training command:

python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 80 --save-interval 1 --arch transformer_lm --task language_modeling --optimizer nag --lr 0.008 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.6 --dropout 0.2 --criterion adaptive_loss --adaptive-softmax-cutoff 10000,50000,200000 --max-tokens 512 --tokens-per-sample 512 --seed 1 --sample-break-mode none --log-format json --log-interval 50 --save-interval-updates 2500 --keep-interval-updates 25
The small transformer got to 31.3 ppl on WikiText-103 (compared to 35 with fconv), while @myleott got a big transformer LM to around 27 ppl on WikiText-103.
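
For reference, a minimal sketch of how the --adaptive-softmax-cutoff value in the command above ends up configuring the decoder output layer (mirroring the TransformerDecoder changes in the diff below). The vocabulary size and dimensions here are made-up placeholders, not the real WikiText-103 numbers:

from fairseq import options
from fairseq.modules import AdaptiveSoftmax

cutoff_str = '10000,50000,200000'                      # value passed to --adaptive-softmax-cutoff
cutoffs = options.eval_str_list(cutoff_str, type=int)  # -> [10000, 50000, 200000]

adaptive_softmax = AdaptiveSoftmax(
    270000,        # placeholder for len(dictionary)
    512,           # placeholder for args.decoder_embed_dim (transformer_lm default)
    cutoffs,       # frequency-ordered cluster boundaries
    dropout=0.2,   # same dropout as the training command
)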
parent 0ef2856c
@@ -41,8 +41,8 @@ def main(args):
     itr = data.EpochBatchIterator(
         dataset=task.dataset(args.gen_subset),
-        max_tokens=args.max_tokens,
-        max_sentences=args.max_sentences or 4,
+        max_tokens=args.max_tokens or 36000,
+        max_sentences=args.max_sentences,
         max_positions=models[0].max_positions(),
         num_shards=args.num_shards,
         shard_id=args.shard_id,
@@ -56,19 +56,35 @@ def main(args):
     score_sum = 0.
     count = 0
+
+    if args.remove_bpe is not None:
+        bpe_cont = args.remove_bpe.rstrip()
+        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
+    else:
+        bpe_toks = None
+
     with progress_bar.build_progress_bar(args, itr) as t:
         results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
         wps_meter = TimeMeter()
         for _, src_tokens, __, hypos in results:
             for hypo in hypos:
                 pos_scores = hypo['positional_scores']
+
+                skipped_toks = 0
+                if bpe_toks is not None:
+                    for i in range(len(hypo['tokens']) - 1):
+                        if hypo['tokens'][i].item() in bpe_toks:
+                            skipped_toks += 1
+                            pos_scores[i + 1] += pos_scores[i]
+                            pos_scores[i] = 0
+
                 inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                 if inf_scores.any():
                     print('| Skipping tokens with inf scores:',
                           task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                     pos_scores = pos_scores[(~inf_scores).nonzero()]
                 score_sum += pos_scores.sum()
-                count += pos_scores.numel()
+                count += pos_scores.numel() - skipped_toks
             wps_meter.update(src_tokens.size(0))
             t.log({'wps': round(wps_meter.avg)})
...
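
To make the new BPE handling concrete, here is a small standalone sketch of the score-merging loop above on toy tensors. The token ids and the continuation set are made up; the real code derives bpe_toks from task.dictionary and args.remove_bpe:

import torch

pos_scores = torch.tensor([-1.0, -2.0, -0.5, -3.0])   # per-token log-probs for one hypothesis
tokens = torch.tensor([4, 7, 9, 5])                    # toy token ids; 7 is a BPE continuation
bpe_toks = {7}                                         # ids whose symbol ends with the BPE marker

skipped_toks = 0
for i in range(len(tokens) - 1):
    if tokens[i].item() in bpe_toks:
        skipped_toks += 1
        pos_scores[i + 1] += pos_scores[i]   # fold the sub-word score into the word-final position
        pos_scores[i] = 0

count = pos_scores.numel() - skipped_toks    # count whole words rather than sub-word pieces
print(pos_scores.tolist(), count)            # [-1.0, 0.0, -2.5, -3.0] 3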
@@ -29,12 +29,13 @@ class TokenBlockDataset(torch.utils.data.Dataset):
         include_targets: return next tokens as targets
     """

-    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
+    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False, reverse=False):
         super().__init__()
         self.tokens = tokens
         self.total_size = len(tokens)
         self.include_targets = include_targets
+        self.reverse = reverse
         self.slice_indices = []

         if break_mode is None or break_mode == 'none':
@@ -77,8 +78,19 @@ class TokenBlockDataset(torch.utils.data.Dataset):
     def __getitem__(self, index):
         s, e = self.slice_indices[index]
-        item = torch.LongTensor(self.tokens[s:e])
+
+        if self.reverse:
+            item = torch.LongTensor(np.flip(self.tokens[s:e], 0).copy())
+        else:
+            item = torch.LongTensor(self.tokens[s:e])
+
         if self.include_targets:
+            if self.reverse:
+                if s == 0:
+                    target = np.concatenate([self.tokens[-1:], item.numpy()[1:]])
+                else:
+                    target = np.concatenate([self.tokens[s - 1:s], item.numpy()[:-1]])
+                return item, torch.LongTensor(target)
+
             # target is the sentence, for source, rotate item one token to the left (would start with eos)
             if s == 0:
                 source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
@@ -86,7 +98,6 @@ class TokenBlockDataset(torch.utils.data.Dataset):
                 source = self.tokens[s - 1:e - 1]
             return torch.LongTensor(source), item
         return item

     def __len__(self):
...
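
A toy walk-through of the new reverse=True branch of __getitem__ above, with a made-up token buffer and block bounds; it only reproduces the indexing shown in the diff:

import numpy as np
import torch

tokens = np.array([10, 11, 12, 13, 14, 15])   # pretend token buffer
s, e = 2, 5                                    # one block: tokens[2:5] == [12, 13, 14]

item = torch.LongTensor(np.flip(tokens[s:e], 0).copy())        # reversed block -> [14, 13, 12]
target = np.concatenate([tokens[s - 1:s], item.numpy()[:-1]])  # s > 0 case    -> [11, 14, 13]
print(item.tolist(), target.tolist())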
@@ -19,8 +19,14 @@ class FairseqDecoder(nn.Module):
     def forward(self, prev_output_tokens, encoder_out):
         raise NotImplementedError

-    def get_normalized_probs(self, net_output, log_probs, _):
+    def get_normalized_probs(self, net_output, log_probs, sample):
         """Get normalized probabilities (or log probs) from a net's output."""
+        if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
+            assert sample is not None and 'target' in sample
+            out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
+            return out.exp_() if not log_probs else out
+
         logits = net_output[0].float()
         if log_probs:
             return F.log_softmax(logits, dim=-1)
...
@@ -493,16 +493,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
         return x, avg_attn_scores

-    def get_normalized_probs(self, net_output, log_probs, sample):
-        """Get normalized probabilities (or log probs) from a net's output."""
-        if self.adaptive_softmax is not None:
-            assert sample is not None and 'target' in sample
-            out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
-            return out.exp_() if not log_probs else out
-        else:
-            return super().get_normalized_probs(net_output, log_probs, sample)
-
     def reorder_incremental_state(self, incremental_state, new_order):
         super().reorder_incremental_state(incremental_state, new_order)
         encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
...
@@ -11,16 +11,16 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

+from fairseq import options
 from fairseq import utils
 from fairseq.modules import (
-    LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
+    LearnedPositionalEmbedding, MultiheadAttention, AdaptiveSoftmax,
+    SinusoidalPositionalEmbedding,
 )
 from . import (
-    FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
-    register_model, register_model_architecture,
+    FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
+    register_model_architecture,
 )
@@ -71,13 +71,22 @@ class TransformerModel(FairseqModel):
         parser.add_argument('--share-all-embeddings', action='store_true',
                             help='share encoder, decoder and output embeddings'
                                  ' (requires shared dictionary and embed dim)')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')

     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        # make sure that all args are properly defaulted (in case there are any new ones)
+
+        # make sure all arguments are present in older models
         base_architecture(args)

+        if not hasattr(args, 'max_source_positions'):
+            args.max_source_positions = 1024
+        if not hasattr(args, 'max_target_positions'):
+            args.max_target_positions = 1024
+
         src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

         def build_embedding(dictionary, embed_dim, path=None):
@@ -117,6 +126,56 @@ class TransformerModel(FairseqModel):
         return TransformerModel(encoder, decoder)


+@register_model('transformer_lm')
+class TransformerLanguageModel(FairseqLanguageModel):
+    def __init__(self, decoder):
+        super().__init__(decoder)
+
+    @staticmethod
+    def add_args(parser):
+        """Add model-specific arguments to the parser."""
+        parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
+                            help='dropout probability')
+        parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
+                            help='dropout probability for attention weights')
+        parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
+                            help='dropout probability after ReLU in FFN')
+        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
+                            help='decoder embedding dimension')
+        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
+                            help='decoder embedding dimension for FFN')
+        parser.add_argument('--decoder-layers', type=int, metavar='N',
+                            help='num decoder layers')
+        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
+                            help='num decoder attention heads')
+        parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
+                            help='apply layernorm before each decoder block')
+        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
+                            help='comma separated list of adaptive softmax cutoff points. '
+                                 'Must be used with adaptive_loss criterion')
+        parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
+                            help='if set, disables positional embeddings (outside self attention)')
+        parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
+                            help='share decoder input and output embeddings')
+
+    @classmethod
+    def build_model(cls, args, task):
+        """Build a new model instance."""
+
+        # make sure all arguments are present in older models
+        base_lm_architecture(args)
+
+        if not hasattr(args, 'max_source_positions'):
+            args.max_source_positions = args.tokens_per_sample
+        if not hasattr(args, 'max_target_positions'):
+            args.max_target_positions = args.tokens_per_sample
+
+        embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
+
+        decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
+        return TransformerLanguageModel(decoder)
+
+
 class TransformerEncoder(FairseqEncoder):
     """Transformer encoder."""
@@ -126,14 +185,15 @@ class TransformerEncoder(FairseqEncoder):
         embed_dim = embed_tokens.embedding_dim
         self.padding_idx = embed_tokens.padding_idx
+        self.max_source_positions = args.max_source_positions

         self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
-            1024, embed_dim, self.padding_idx,
+            args.max_source_positions, embed_dim, self.padding_idx,
             left_pad=left_pad,
             learned=args.encoder_learned_pos,
-        )
+        ) if not args.no_token_positional_embeddings else None

         self.layers = nn.ModuleList([])
         self.layers.extend([
@@ -144,6 +204,7 @@ class TransformerEncoder(FairseqEncoder):
     def forward(self, src_tokens, src_lengths):
         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(src_tokens)
-        x += self.embed_positions(src_tokens)
+        if self.embed_positions is not None:
+            x += self.embed_positions(src_tokens)
         x = F.dropout(x, p=self.dropout, training=self.training)
@@ -175,7 +236,9 @@ class TransformerEncoder(FairseqEncoder):
     def max_positions(self):
         """Maximum input length supported by the encoder."""
-        return self.embed_positions.max_positions()
+        if self.embed_positions is None:
+            return self.max_source_positions
+        return min(self.max_source_positions, self.embed_positions.max_positions())

     def upgrade_state_dict(self, state_dict):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -189,63 +252,76 @@ class TransformerEncoder(FairseqEncoder):
 class TransformerDecoder(FairseqIncrementalDecoder):
     """Transformer decoder."""

-    def __init__(self, args, dictionary, embed_tokens, left_pad=False):
+    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False):
         super().__init__(dictionary)
         self.dropout = args.dropout
         self.share_input_output_embed = args.share_decoder_input_output_embed

         embed_dim = embed_tokens.embedding_dim
         padding_idx = embed_tokens.padding_idx
+        self.max_target_positions = args.max_target_positions

         self.embed_tokens = embed_tokens
         self.embed_scale = math.sqrt(embed_dim)
         self.embed_positions = PositionalEmbedding(
-            1024, embed_dim, padding_idx,
+            args.max_target_positions, embed_dim, padding_idx,
             left_pad=left_pad,
             learned=args.decoder_learned_pos,
-        )
+        ) if not args.no_token_positional_embeddings else None

         self.layers = nn.ModuleList([])
         self.layers.extend([
-            TransformerDecoderLayer(args)
-            for i in range(args.decoder_layers)
+            TransformerDecoderLayer(args, no_encoder_attn)
+            for _ in range(args.decoder_layers)
         ])

-        if not self.share_input_output_embed:
+        self.adaptive_softmax = None
+
+        if args.adaptive_softmax_cutoff is not None:
+            self.adaptive_softmax = AdaptiveSoftmax(
+                len(dictionary), args.decoder_embed_dim,
+                options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
+                dropout=args.dropout
+            )
+        elif not self.share_input_output_embed:
             self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
             nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)

-    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+    def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
         # embed positions
         positions = self.embed_positions(
             prev_output_tokens,
             incremental_state=incremental_state,
-        )
+        ) if self.embed_positions is not None else None

         if incremental_state is not None:
             prev_output_tokens = prev_output_tokens[:, -1:]
-            positions = positions[:, -1:]
+            if positions is not None:
+                positions = positions[:, -1:]

         # embed tokens and positions
         x = self.embed_scale * self.embed_tokens(prev_output_tokens)
-        x += positions
+        if positions is not None:
+            x += positions
         x = F.dropout(x, p=self.dropout, training=self.training)

         # B x T x C -> T x B x C
         x = x.transpose(0, 1)
+        attn = None

         # decoder layers
         for layer in self.layers:
             x, attn = layer(
                 x,
-                encoder_out['encoder_out'],
-                encoder_out['encoder_padding_mask'],
+                encoder_out['encoder_out'] if encoder_out is not None else None,
+                encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
                 incremental_state,
             )

         # T x B x C -> B x T x C
         x = x.transpose(0, 1)

-        # project back to size of vocabulary
-        if self.share_input_output_embed:
-            x = F.linear(x, self.embed_tokens.weight)
+        if self.adaptive_softmax is None:
+            # project back to size of vocabulary
+            if self.share_input_output_embed:
+                x = F.linear(x, self.embed_tokens.weight)
@@ -256,7 +332,9 @@ class TransformerDecoder(FairseqIncrementalDecoder):
     def max_positions(self):
         """Maximum output length supported by the decoder."""
-        return self.embed_positions.max_positions()
+        if self.embed_positions is None:
+            return self.max_target_positions
+        return min(self.max_target_positions, self.embed_positions.max_positions())

     def upgrade_state_dict(self, state_dict):
         if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
@@ -264,6 +342,21 @@ class TransformerDecoder(FairseqIncrementalDecoder):
                 del state_dict['decoder.embed_positions.weights']
             if 'decoder.embed_positions._float_tensor' not in state_dict:
                 state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
+
+        for i in range(len(self.layers)):
+            # update layer norms
+            layer_norm_map = {
+                '0': 'self_attn_layer_norm',
+                '1': 'encoder_attn_layer_norm',
+                '2': 'final_layer_norm'
+            }
+            for old, new in layer_norm_map.items():
+                for m in ('weight', 'bias'):
+                    k = 'decoder.layers.{}.layer_norms.{}.{}'.format(i, old, m)
+                    if k in state_dict:
+                        state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
+                        del state_dict[k]
+
         return state_dict
@@ -322,7 +415,7 @@ class TransformerEncoderLayer(nn.Module):
 class TransformerDecoderLayer(nn.Module):
     """Decoder layer block."""

-    def __init__(self, args):
+    def __init__(self, args, no_encoder_attn=False):
         super().__init__()
         self.embed_dim = args.decoder_embed_dim
         self.self_attn = MultiheadAttention(
@@ -332,18 +425,28 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout = args.dropout
         self.relu_dropout = args.relu_dropout
         self.normalize_before = args.decoder_normalize_before
-        self.encoder_attn = MultiheadAttention(
-            self.embed_dim, args.decoder_attention_heads,
-            dropout=args.attention_dropout,
-        )
+
+        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
+
+        if no_encoder_attn:
+            self.encoder_attn = None
+            self.encoder_attn_layer_norm = None
+        else:
+            self.encoder_attn = MultiheadAttention(
+                self.embed_dim, args.decoder_attention_heads,
+                dropout=args.attention_dropout,
+            )
+            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
+
         self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
         self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
-        self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
+
+        self.final_layer_norm = LayerNorm(self.embed_dim)
         self.need_attn = True

     def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
         residual = x
-        x = self.maybe_layer_norm(0, x, before=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
         x, _ = self.self_attn(
             query=x,
             key=x,
@@ -354,10 +457,12 @@ class TransformerDecoderLayer(nn.Module):
         )
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(0, x, after=True)
+        x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)

-        residual = x
-        x = self.maybe_layer_norm(1, x, before=True)
-        x, attn = self.encoder_attn(
-            query=x,
-            key=encoder_out,
+        attn = None
+        if self.encoder_attn is not None:
+            residual = x
+            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
+            x, attn = self.encoder_attn(
+                query=x,
+                key=encoder_out,
@@ -369,22 +474,22 @@ class TransformerDecoderLayer(nn.Module):
-        )
-        x = F.dropout(x, p=self.dropout, training=self.training)
-        x = residual + x
-        x = self.maybe_layer_norm(1, x, after=True)
+            )
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = residual + x
+            x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)

         residual = x
-        x = self.maybe_layer_norm(2, x, before=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
         x = F.relu(self.fc1(x))
         x = F.dropout(x, p=self.relu_dropout, training=self.training)
         x = self.fc2(x)
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
-        x = self.maybe_layer_norm(2, x, after=True)
+        x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
         return x, attn

-    def maybe_layer_norm(self, i, x, before=False, after=False):
+    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
         assert before ^ after
         if after ^ self.normalize_before:
-            return self.layer_norms[i](x)
+            return layer_norm(x)
         else:
             return x
@@ -395,6 +500,7 @@ class TransformerDecoderLayer(nn.Module):
 def Embedding(num_embeddings, embedding_dim, padding_idx):
     m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
     nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+    nn.init.constant_(m.weight[padding_idx], 0)
     return m
@@ -412,14 +518,48 @@ def Linear(in_features, out_features, bias=True):
 def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
     if learned:
-        m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
+        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
         nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
         nn.init.constant_(m.weight[padding_idx], 0)
     else:
-        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
+        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
     return m


+@register_model_architecture('transformer_lm', 'transformer_lm')
+def base_lm_architecture(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
+    args.decoder_layers = getattr(args, 'decoder_layers', 6)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
+
+    # The model training is not stable without this
+    args.decoder_normalize_before = True
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_big')
+def transformer_lm_big(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
+    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
+    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
+    base_lm_architecture(args)
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
+def transformer_lm_wiki103(args):
+    args.dropout = getattr(args, 'dropout', 0.3)
+    base_lm_architecture(args)
+
+
+@register_model_architecture('transformer_lm', 'transformer_lm_gbw')
+def transformer_lm_gbw(args):
+    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
+    args.dropout = getattr(args, 'dropout', 0.1)
+    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
+    transformer_lm_big(args)
+
+
 @register_model_architecture('transformer', 'transformer')
 def base_architecture(args):
     args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
@@ -439,8 +579,10 @@ def base_architecture(args):
     args.attention_dropout = getattr(args, 'attention_dropout', 0.)
     args.relu_dropout = getattr(args, 'relu_dropout', 0.)
     args.dropout = getattr(args, 'dropout', 0.1)
+    args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
     args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
     args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
+    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)


 @register_model_architecture('transformer', 'transformer_iwslt_de_en')
...
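
Given the architecture registrations above, the preset variants should be selectable with --arch just like transformer_lm in the example command, e.g. (data path and the exact hyperparameters are placeholders):

python train.py /path/to/wiki103 --task language_modeling --arch transformer_lm_wiki103 \
  --optimizer nag --lr 0.008 --criterion adaptive_loss --adaptive-softmax-cutoff 10000,50000,200000 \
  --max-tokens 512 --tokens-per-sample 512 --sample-break-mode none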
@@ -28,7 +28,7 @@ class FairseqTask(object):
     def setup_task(cls, args, **kwargs):
         raise NotImplementedError

-    def load_dataset(self, split):
+    def load_dataset(self, split, combine=False):
         raise NotImplementedError

     def dataset(self, split):
...
@@ -5,8 +5,12 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import itertools
+import numpy as np
 import os

+from torch.utils.data import ConcatDataset
+
 from fairseq.data import (
     Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
     MonolingualDataset, TokenBlockDataset,
@@ -32,9 +36,13 @@ class LanguageModelingTask(FairseqTask):
                             help='max number of tokens per sample for LM dataset')
         parser.add_argument('--raw-text', default=False, action='store_true',
                             help='load raw text dataset')
+        parser.add_argument('--right-to-left', default=False, action='store_true',
+                            help='if set, trains a language model right-to-left (instead of left-to-right)')

     def __init__(self, args, dictionary):
         super().__init__(args)
+        args.right_to_left = getattr(args, 'right_to_left', False)
         self.dictionary = dictionary

     @classmethod
@@ -43,23 +51,46 @@ class LanguageModelingTask(FairseqTask):
         print('| dictionary: {} types'.format(len(dictionary)))
         return cls(args, dictionary)

-    def load_dataset(self, split):
+    def load_dataset(self, split, combine=False):
         """Load a dataset split."""
-        path = os.path.join(self.args.data, split)
-        if self.args.raw_text and IndexedRawTextDataset.exists(path):
-            ds = IndexedRawTextDataset(path, self.dictionary)
-            tokens = [t for l in ds.tokens_list for t in l]
-        elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
-            ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
-            tokens = ds.buffer
-        else:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
-        dataset = TokenBlockDataset(
-            tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
-            include_targets=True,  # return next tokens as targets
-        )
-        self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
+
+        loaded_datasets = []
+
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+            path = os.path.join(self.args.data, split_k)
+
+            if self.args.raw_text and IndexedRawTextDataset.exists(path):
+                ds = IndexedRawTextDataset(path, self.dictionary)
+                tokens = [t for l in ds.tokens_list for t in l]
+            elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
+                ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
+                tokens = ds.buffer
+            else:
+                if k > 0:
+                    break
+                else:
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+
+            loaded_datasets.append(
+                TokenBlockDataset(
+                    tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
+                    include_targets=True, reverse=self.args.right_to_left,
+                ))
+
+            print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
+
+            if not combine:
+                break
+
+        if len(loaded_datasets) == 1:
+            dataset = loaded_datasets[0]
+            sizes = dataset.sizes
+        else:
+            dataset = ConcatDataset(loaded_datasets)
+            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
+
+        self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)

     @property
     def target_dictionary(self):
...
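
A small sketch of the shard-naming convention implied by the loop above: with combine=True, splits named train, train1, train2, ... are loaded until one is missing and their size arrays are concatenated. Everything below is made up for illustration:

import itertools
import numpy as np

available = {'train', 'train1', 'train2'}       # pretend these index files exist on disk
loaded = []
for k in itertools.count():
    split_k = 'train' + (str(k) if k > 0 else '')
    if split_k not in available:
        break
    loaded.append(np.full(3, k))                # stand-in for each shard's sizes array

sizes = np.concatenate(loaded)                  # what gets handed to MonolingualDataset
print(sizes)                                    # [0 0 0 1 1 1 2 2 2]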
@@ -5,8 +5,12 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

+import itertools
+import numpy as np
 import os

+from torch.utils.data import ConcatDataset
+
 from fairseq import options
 from fairseq.data import (
     data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
@@ -65,10 +69,10 @@ class TranslationTask(FairseqTask):
         return cls(args, src_dict, tgt_dict)

-    def load_dataset(self, split):
+    def load_dataset(self, split, combine=False):
         """Load a dataset split."""

-        def split_exists(src, tgt, lang):
+        def split_exists(split, src, tgt, lang):
             filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
@@ -76,15 +80,6 @@ class TranslationTask(FairseqTask):
                 return True
             return False

-        # infer langcode
-        src, tgt = self.args.source_lang, self.args.target_lang
-        if split_exists(src, tgt, src):
-            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
-        elif split_exists(tgt, src, src):
-            prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
-        else:
-            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
-
         def indexed_dataset(path, dictionary):
             if self.args.raw_text:
                 return IndexedRawTextDataset(path, dictionary)
@@ -92,11 +87,48 @@ class TranslationTask(FairseqTask):
                 return IndexedInMemoryDataset(path, fix_lua_indexing=True)
             return None

-        src_dataset = indexed_dataset(prefix + src, self.src_dict)
-        tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
+        src_datasets = []
+        tgt_datasets = []
+
+        for k in itertools.count():
+            split_k = split + (str(k) if k > 0 else '')
+
+            # infer langcode
+            src, tgt = self.args.source_lang, self.args.target_lang
+            if split_exists(split_k, src, tgt, src):
+                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
+            elif split_exists(split_k, tgt, src, src):
+                prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
+            else:
+                if k > 0:
+                    break
+                else:
+                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
+
+            src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
+            tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
+
+            print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
+
+            if not combine:
+                break
+
+        assert len(src_datasets) == len(tgt_datasets)
+
+        if len(src_datasets) == 1:
+            src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
+            src_sizes = src_dataset.sizes
+            tgt_sizes = tgt_dataset.sizes
+        else:
+            src_dataset = ConcatDataset(src_datasets)
+            tgt_dataset = ConcatDataset(tgt_datasets)
+            src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
+            tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
+
         self.datasets[split] = LanguagePairDataset(
-            src_dataset, src_dataset.sizes, self.src_dict,
-            tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
+            src_dataset, src_sizes, self.src_dict,
+            tgt_dataset, tgt_sizes, self.tgt_dict,
             left_pad_source=self.args.left_pad_source,
             left_pad_target=self.args.left_pad_target,
             max_source_positions=self.args.max_source_positions,
...
@@ -32,7 +32,7 @@ def main(args):
     task = tasks.setup_task(args)

     # Load dataset splits
-    load_dataset_splits(args, task, ['train', 'valid'])
+    load_dataset_splits(task, ['train', 'valid'])

     # Build model and criterion
     model = task.build_model(args)
@@ -316,13 +316,15 @@ def load_checkpoint(args, trainer, epoch_itr):
             save_checkpoint.best = extra_state['best']


-def load_dataset_splits(args, task, splits):
+def load_dataset_splits(task, splits):
     for split in splits:
-        for k in itertools.count():
-            split_k = split + (str(k) if k > 0 else '')
-            try:
-                task.load_dataset(split_k)
-                print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
-            except FileNotFoundError as e:
-                if k > 0:
-                    break
+        if split == 'train':
+            task.load_dataset(split, combine=True)
+        else:
+            for k in itertools.count():
+                split_k = split + (str(k) if k > 0 else '')
+                try:
+                    task.load_dataset(split_k, combine=False)
+                except FileNotFoundError as e:
+                    if k > 0:
+                        break
...