Commit d2e2a1d4 authored by Alexei Baevski, committed by Myle Ott

Transformer LM

This implements a Transformer-based language model. It already obtains better perplexity on WikiText-103 without any tuning. I will also train it on GBW, where I also expect better perplexity.

Example training command:

python train.py /private/home/abaevski/data/wiki103 --save-dir /tmp --fp16 --max-epoch 80 --save-interval 1 --arch transformer_lm --task language_modeling --optimizer nag --lr 0.008 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.6 --dropout 0.2 --criterion adaptive_loss --adaptive-softmax-cutoff 10000,50000,200000 --max-tokens 512 --tokens-per-sample 512 --seed 1 --sample-break-mode none --log-format json --log-interval 50 --save-interval-updates 2500 --keep-interval-updates 25
The small Transformer reached 31.3 perplexity on WikiText-103 (compared to 35 with fconv), while @myleott got a big Transformer LM to around 27 perplexity on WikiText-103.
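
For reference, these numbers are perplexities, i.e. the exponential of the average negative log-likelihood per token. A minimal sketch of that relationship (illustrative only, mirroring the score_sum/count accumulation in the scoring hunk below):

import math

# Illustrative only: relate a summed log-probability to perplexity.
# `score_sum` is the total log-probability (in nats) over `count` scored tokens.
def perplexity(score_sum, count):
    avg_nll = -score_sum / count   # average negative log-likelihood per token
    return math.exp(avg_nll)

# an average NLL of about 3.44 nats corresponds to ppl ~= 31.3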
parent 0ef2856c
......@@ -41,8 +41,8 @@ def main(args):
itr = data.EpochBatchIterator(
dataset=task.dataset(args.gen_subset),
max_tokens=args.max_tokens,
max_sentences=args.max_sentences or 4,
max_tokens=args.max_tokens or 36000,
max_sentences=args.max_sentences,
max_positions=models[0].max_positions(),
num_shards=args.num_shards,
shard_id=args.shard_id,
......@@ -56,19 +56,35 @@ def main(args):
score_sum = 0.
count = 0
if args.remove_bpe is not None:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
else:
bpe_toks = None
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
for _, src_tokens, __, hypos in results:
for hypo in hypos:
pos_scores = hypo['positional_scores']
skipped_toks = 0
if bpe_toks is not None:
for i in range(len(hypo['tokens']) - 1):
if hypo['tokens'][i].item() in bpe_toks:
skipped_toks += 1
pos_scores[i + 1] += pos_scores[i]
pos_scores[i] = 0
inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
if inf_scores.any():
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
count += pos_scores.numel()
count += pos_scores.numel() - skipped_toks
wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)})
......
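The scoring change above folds the score of each BPE continuation token into the following position, so perplexity is counted per word rather than per subword piece. A standalone sketch of that folding with toy data (not fairseq code; the '@@' continuation marker is just an example --remove-bpe value):

import torch

# Toy illustration of the BPE score folding above.
tokens = ['new@@', 'ton', 'was', 'here']
pos_scores = torch.tensor([-1.0, -0.5, -2.0, -1.5])

bpe_cont = '@@'
skipped_toks = 0
for i in range(len(tokens) - 1):
    if tokens[i].endswith(bpe_cont):
        skipped_toks += 1
        pos_scores[i + 1] += pos_scores[i]
        pos_scores[i] = 0

score_sum = pos_scores.sum()               # total is unchanged: -5.0
count = pos_scores.numel() - skipped_toks  # 3 word positions instead of 4 subwords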
......@@ -29,12 +29,13 @@ class TokenBlockDataset(torch.utils.data.Dataset):
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False, reverse=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.include_targets = include_targets
self.reverse = reverse
self.slice_indices = []
if break_mode is None or break_mode == 'none':
......@@ -77,8 +78,19 @@ class TokenBlockDataset(torch.utils.data.Dataset):
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.reverse:
item = torch.LongTensor(np.flip(self.tokens[s:e], 0).copy())
else:
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if self.reverse:
if s == 0:
target = np.concatenate([self.tokens[-1:], item.numpy()[1:]])
else:
target = np.concatenate([self.tokens[s - 1:s], item.numpy()[:-1]])
return item, torch.LongTensor(target)
# target is the sentence, for source, rotate item one token to the left (would start with eos)
if s == 0:
source = np.concatenate([self.tokens[-1:], self.tokens[0:e - 1]])
......@@ -86,7 +98,6 @@ class TokenBlockDataset(torch.utils.data.Dataset):
source = self.tokens[s - 1:e - 1]
return torch.LongTensor(source), item
return item
def __len__(self):
......
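In TokenBlockDataset above, the target for a block is the block itself, and the source is the same window shifted one position earlier in the token stream (for the very first block, the last token of the stream, typically EOS, is prepended). A toy sketch of that construction (illustrative, not fairseq code):

import numpy as np

tokens = np.array([10, 11, 12, 13, 14, 15])

# interior block: tokens[2:5]
s, e = 2, 5
target = tokens[s:e]                      # [12, 13, 14]
source = tokens[s - 1:e - 1]              # [11, 12, 13]

# first block: start the source with the stream's last token
s0, e0 = 0, 3
source0 = np.concatenate([tokens[-1:], tokens[0:e0 - 1]])   # [15, 10, 11]

# with reverse=True the block is simply read right-to-left (np.flip)
# before the same shift is applied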
......@@ -19,8 +19,14 @@ class FairseqDecoder(nn.Module):
def forward(self, prev_output_tokens, encoder_out):
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs, _):
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if hasattr(self, 'adaptive_softmax') and self.adaptive_softmax is not None:
assert sample is not None and 'target' in sample
out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
return out.exp_() if not log_probs else out
logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
......
......@@ -493,16 +493,6 @@ class FConvDecoder(FairseqIncrementalDecoder):
return x, avg_attn_scores
def get_normalized_probs(self, net_output, log_probs, sample):
"""Get normalized probabilities (or log probs) from a net's output."""
if self.adaptive_softmax is not None:
assert sample is not None and 'target' in sample
out = self.adaptive_softmax.get_log_prob(net_output[0], sample['target'])
return out.exp_() if not log_probs else out
else:
return super().get_normalized_probs(net_output, log_probs, sample)
def reorder_incremental_state(self, incremental_state, new_order):
super().reorder_incremental_state(incremental_state, new_order)
encoder_out = utils.get_incremental_state(self, incremental_state, 'encoder_out')
......
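The adaptive-softmax log-prob path that previously lived only in FConvDecoder now sits in the base FairseqDecoder, so the new Transformer decoder gets it for free. As a reminder of what --adaptive-softmax-cutoff 10000,50000,200000 means, here is an illustrative sketch of the vocabulary partition (not fairseq's AdaptiveSoftmax implementation):

# Illustrative only: the cutoffs split the frequency-sorted vocabulary into a
# "head" cluster handled by a full softmax and progressively rarer tail
# clusters handled through smaller projections.
def cluster_of(token_id, cutoffs=(10000, 50000, 200000)):
    for cluster, cutoff in enumerate(cutoffs):
        if token_id < cutoff:
            return cluster
    return len(cutoffs)  # ids beyond the last cutoff

assert cluster_of(42) == 0       # frequent word -> head
assert cluster_of(25000) == 1    # first tail cluster
assert cluster_of(300000) == 3   # rarest tail cluster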
......@@ -11,16 +11,16 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options
from fairseq import utils
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
AdaptiveSoftmax, LearnedPositionalEmbedding, MultiheadAttention, SinusoidalPositionalEmbedding
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
FairseqIncrementalDecoder, FairseqEncoder, FairseqLanguageModel, FairseqModel, register_model,
register_model_architecture,
)
......@@ -71,13 +71,22 @@ class TransformerModel(FairseqModel):
parser.add_argument('--share-all-embeddings', action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure that all args are properly defaulted (in case there are any new ones)
# make sure all arguments are present in older models
base_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = 1024
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = 1024
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim, path=None):
......@@ -117,6 +126,56 @@ class TransformerModel(FairseqModel):
return TransformerModel(encoder, decoder)
@register_model('transformer_lm')
class TransformerLanguageModel(FairseqLanguageModel):
def __init__(self, decoder):
super().__init__(decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', default=0., type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', default=0., type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
help='comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion')
parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true',
help='if set, disables positional embeddings (outside self attention)')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_lm_architecture(args)
if not hasattr(args, 'max_source_positions'):
args.max_source_positions = args.tokens_per_sample
if not hasattr(args, 'max_target_positions'):
args.max_target_positions = args.tokens_per_sample
embed_tokens = Embedding(len(task.dictionary), args.decoder_embed_dim, task.dictionary.pad())
decoder = TransformerDecoder(args, task.dictionary, embed_tokens, no_encoder_attn=True)
return TransformerLanguageModel(decoder)
class TransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
......@@ -126,14 +185,15 @@ class TransformerEncoder(FairseqEncoder):
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.max_source_positions = args.max_source_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
args.max_source_positions, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
)
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
......@@ -144,7 +204,8 @@ class TransformerEncoder(FairseqEncoder):
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
x += self.embed_positions(src_tokens)
if self.embed_positions is not None:
x += self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
......@@ -175,7 +236,9 @@ class TransformerEncoder(FairseqEncoder):
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.max_positions()
if self.embed_positions is None:
return self.max_source_positions
return min(self.max_source_positions, self.embed_positions.max_positions())
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
......@@ -189,74 +252,89 @@ class TransformerEncoder(FairseqEncoder):
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=False):
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False, left_pad=False):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.max_target_positions = args.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
args.max_target_positions, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
) if not args.no_token_positional_embeddings else None
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args)
for i in range(args.decoder_layers)
TransformerDecoderLayer(args, no_encoder_attn)
for _ in range(args.decoder_layers)
])
if not self.share_input_output_embed:
self.adaptive_softmax = None
if args.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary), args.decoder_embed_dim,
options.eval_str_list(args.adaptive_softmax_cutoff, type=int),
dropout=args.dropout
)
elif not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
def forward(self, prev_output_tokens, encoder_out=None, incremental_state=None):
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
) if self.embed_positions is not None else None
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
positions = positions[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
x += positions
if positions is not None:
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
attn = None
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
encoder_out['encoder_out'] if encoder_out is not None else None,
encoder_out['encoder_padding_mask'] if encoder_out is not None else None,
incremental_state,
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
if self.adaptive_softmax is None:
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, attn
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions())
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
......@@ -264,6 +342,21 @@ class TransformerDecoder(FairseqIncrementalDecoder):
del state_dict['decoder.embed_positions.weights']
if 'decoder.embed_positions._float_tensor' not in state_dict:
state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
for i in range(len(self.layers)):
# update layer norms
layer_norm_map = {
'0': 'self_attn_layer_norm',
'1': 'encoder_attn_layer_norm',
'2': 'final_layer_norm'
}
for old, new in layer_norm_map.items():
for m in ('weight', 'bias'):
k = 'decoder.layers.{}.layer_norms.{}.{}'.format(i, old, m)
if k in state_dict:
state_dict['decoder.layers.{}.{}.{}'.format(i, new, m)] = state_dict[k]
del state_dict[k]
return state_dict
......@@ -322,7 +415,7 @@ class TransformerEncoderLayer(nn.Module):
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(self, args):
def __init__(self, args, no_encoder_attn=False):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
......@@ -332,18 +425,28 @@ class TransformerDecoderLayer(nn.Module):
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.decoder_normalize_before
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.self_attn_layer_norm = LayerNorm(self.embed_dim)
if no_encoder_attn:
self.encoder_attn = None
self.encoder_attn_layer_norm = None
else:
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.encoder_attn_layer_norm = LayerNorm(self.embed_dim)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
self.final_layer_norm = LayerNorm(self.embed_dim)
self.need_attn = True
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
x, _ = self.self_attn(
query=x,
key=x,
......@@ -354,37 +457,39 @@ class TransformerDecoderLayer(nn.Module):
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
x = self.maybe_layer_norm(self.self_attn_layer_norm, x, after=True)
attn = None
if self.encoder_attn is not None:
residual = x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
need_weights=(not self.training and self.need_attn),
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, after=True)
residual = x
x = self.maybe_layer_norm(2, x, before=True)
x = self.maybe_layer_norm(self.final_layer_norm, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(2, x, after=True)
x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
return x, attn
def maybe_layer_norm(self, i, x, before=False, after=False):
def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
return layer_norm(x)
else:
return x
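maybe_layer_norm now takes the layer-norm module directly and, via the before/after flags, switches between post-norm (normalize after the residual add, the default) and pre-norm (normalize before the sublayer, enabled by --decoder-normalize-before). A toy sketch of the two orderings (assumed illustration, not the fairseq layer):

import torch
import torch.nn as nn

def sublayer_step(x, sublayer, layer_norm, normalize_before):
    residual = x
    if normalize_before:
        x = layer_norm(x)        # pre-norm: normalize before the sublayer
    x = sublayer(x)
    x = residual + x
    if not normalize_before:
        x = layer_norm(x)        # post-norm: normalize after the residual add
    return x

x = torch.randn(2, 8)
y = sublayer_step(x, nn.Linear(8, 8), nn.LayerNorm(8), normalize_before=True)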
......@@ -395,6 +500,7 @@ class TransformerDecoderLayer(nn.Module):
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
return m
......@@ -412,14 +518,48 @@ def Linear(in_features, out_features, bias=True):
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings)
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
return m
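The extra padding_idx + 1 embeddings account for fairseq's position numbering, where indices up to padding_idx are reserved and real positions start at padding_idx + 1. A sketch of that convention (illustrative, not fairseq's make_positions):

import torch

padding_idx = 1
tokens = torch.tensor([[5, 6, 7, padding_idx]])

mask = tokens.ne(padding_idx).long()
positions = torch.cumsum(mask, dim=1) * mask + padding_idx
# -> tensor([[2, 3, 4, 1]]): real positions start at padding_idx + 1,
#    so a table covering num_embeddings positions needs
#    num_embeddings + padding_idx + 1 rows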
@register_model_architecture('transformer_lm', 'transformer_lm')
def base_lm_architecture(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 2048)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
# The model training is not stable without this
args.decoder_normalize_before = True
@register_model_architecture('transformer_lm', 'transformer_lm_big')
def transformer_lm_big(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_wiki103')
def transformer_lm_wiki103(args):
args.dropout = getattr(args, 'dropout', 0.3)
base_lm_architecture(args)
@register_model_architecture('transformer_lm', 'transformer_lm_gbw')
def transformer_lm_gbw(args):
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
args.dropout = getattr(args, 'dropout', 0.1)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
transformer_lm_big(args)
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
......@@ -439,8 +579,10 @@ def base_architecture(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
args.adaptive_softmax_cutoff = getattr(args, 'adaptive_softmax_cutoff', None)
args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
......
......@@ -28,7 +28,7 @@ class FairseqTask(object):
def setup_task(cls, args, **kwargs):
raise NotImplementedError
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
raise NotImplementedError
def dataset(self, split):
......
......@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq.data import (
Dictionary, IndexedInMemoryDataset, IndexedRawTextDataset,
MonolingualDataset, TokenBlockDataset,
......@@ -32,9 +36,13 @@ class LanguageModelingTask(FairseqTask):
help='max number of tokens per sample for LM dataset')
parser.add_argument('--raw-text', default=False, action='store_true',
help='load raw text dataset')
parser.add_argument('--right-to-left', default=False, action='store_true',
help='if set, trains a language model right-to-left (instead of left-to-right)')
def __init__(self, args, dictionary):
super().__init__(args)
args.right_to_left = getattr(args, 'right_to_left', False)
self.dictionary = dictionary
@classmethod
......@@ -43,23 +51,46 @@ class LanguageModelingTask(FairseqTask):
print('| dictionary: {} types'.format(len(dictionary)))
return cls(args, dictionary)
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
path = os.path.join(self.args.data, split)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
tokens = ds.buffer
loaded_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
path = os.path.join(self.args.data, split_k)
if self.args.raw_text and IndexedRawTextDataset.exists(path):
ds = IndexedRawTextDataset(path, self.dictionary)
tokens = [t for l in ds.tokens_list for t in l]
elif not self.args.raw_text and IndexedInMemoryDataset.exists(path):
ds = IndexedInMemoryDataset(path, fix_lua_indexing=True)
tokens = ds.buffer
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
loaded_datasets.append(
TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, reverse=self.args.right_to_left,
))
print('| {} {} {} examples'.format(self.args.data, split_k, len(loaded_datasets[-1])))
if not combine:
break
if len(loaded_datasets) == 1:
dataset = loaded_datasets[0]
sizes = dataset.sizes
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
dataset = ConcatDataset(loaded_datasets)
sizes = np.concatenate([ds.sizes for ds in loaded_datasets])
dataset = TokenBlockDataset(
tokens, ds.sizes, self.args.tokens_per_sample, self.args.sample_break_mode,
include_targets=True, # return next tokens as targets
)
self.datasets[split] = MonolingualDataset(dataset, dataset.sizes, self.dictionary, shuffle=False)
self.datasets[split] = MonolingualDataset(dataset, sizes, self.dictionary, shuffle=False)
@property
def target_dictionary(self):
......
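load_dataset above now supports sharded splits: with combine=True it keeps loading 'train', 'train1', 'train2', ... until a shard is missing and wraps them in a single ConcatDataset. A minimal sketch of that naming loop (the '.bin' existence check is a hypothetical stand-in for IndexedInMemoryDataset.exists):

import itertools
import os

def shard_names(data_dir, split):
    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else '')
        if not os.path.exists(os.path.join(data_dir, split_k + '.bin')):  # hypothetical check
            break
        yield split_k

# e.g. list(shard_names('/data/wiki103', 'train')) -> ['train', 'train1', ...]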
......@@ -5,8 +5,12 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
from torch.utils.data import ConcatDataset
from fairseq import options
from fairseq.data import (
data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
......@@ -65,10 +69,10 @@ class TranslationTask(FairseqTask):
return cls(args, src_dict, tgt_dict)
def load_dataset(self, split):
def load_dataset(self, split, combine=False):
"""Load a dataset split."""
def split_exists(src, tgt, lang):
def split_exists(split, src, tgt, lang):
filename = os.path.join(self.args.data, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
......@@ -76,15 +80,6 @@ class TranslationTask(FairseqTask):
return True
return False
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, src, tgt))
elif split_exists(tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split, tgt, src))
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
......@@ -92,11 +87,48 @@ class TranslationTask(FairseqTask):
return IndexedInMemoryDataset(path, fix_lua_indexing=True)
return None
src_dataset = indexed_dataset(prefix + src, self.src_dict)
tgt_dataset = indexed_dataset(prefix + tgt, self.tgt_dict)
src_datasets = []
tgt_datasets = []
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
# infer langcode
src, tgt = self.args.source_lang, self.args.target_lang
if split_exists(split_k, src, tgt, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, src, tgt))
elif split_exists(split_k, tgt, src, src):
prefix = os.path.join(self.args.data, '{}.{}-{}.'.format(split_k, tgt, src))
else:
if k > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, self.args.data))
src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
print('| {} {} {} examples'.format(self.args.data, split_k, len(src_datasets[-1])))
if not combine:
break
assert len(src_datasets) == len(tgt_datasets)
if len(src_datasets) == 1:
src_dataset, tgt_dataset = src_datasets[0], tgt_datasets[0]
src_sizes = src_dataset.sizes
tgt_sizes = tgt_dataset.sizes
else:
src_dataset = ConcatDataset(src_datasets)
tgt_dataset = ConcatDataset(tgt_datasets)
src_sizes = np.concatenate([ds.sizes for ds in src_datasets])
tgt_sizes = np.concatenate([ds.sizes for ds in tgt_datasets])
self.datasets[split] = LanguagePairDataset(
src_dataset, src_dataset.sizes, self.src_dict,
tgt_dataset, tgt_dataset.sizes, self.tgt_dict,
src_dataset, src_sizes, self.src_dict,
tgt_dataset, tgt_sizes, self.tgt_dict,
left_pad_source=self.args.left_pad_source,
left_pad_target=self.args.left_pad_target,
max_source_positions=self.args.max_source_positions,
......
......@@ -32,7 +32,7 @@ def main(args):
task = tasks.setup_task(args)
# Load dataset splits
load_dataset_splits(args, task, ['train', 'valid'])
load_dataset_splits(task, ['train', 'valid'])
# Build model and criterion
model = task.build_model(args)
......@@ -263,16 +263,16 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss):
checkpoint_conds = collections.OrderedDict()
checkpoint_conds['checkpoint{}.pt'.format(epoch)] = (
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
end_of_epoch and not args.no_epoch_checkpoints and
epoch % args.save_interval == 0
)
checkpoint_conds['checkpoint_{}_{}.pt'.format(epoch, updates)] = (
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
not end_of_epoch and args.save_interval_updates > 0 and
updates % args.save_interval_updates == 0
)
checkpoint_conds['checkpoint_best.pt'] = (
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
val_loss is not None and
(not hasattr(save_checkpoint, 'best') or val_loss < save_checkpoint.best)
)
checkpoint_conds['checkpoint_last.pt'] = True # keep this last so that it's a symlink
......@@ -316,17 +316,19 @@ def load_checkpoint(args, trainer, epoch_itr):
save_checkpoint.best = extra_state['best']
def load_dataset_splits(args, task, splits):
def load_dataset_splits(task, splits):
for split in splits:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k)
print('| {} {} {} examples'.format(args.data, split_k, len(task.dataset(split_k))))
except FileNotFoundError as e:
if k > 0:
break
raise e
if split == 'train':
task.load_dataset(split, combine=True)
else:
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
try:
task.load_dataset(split_k, combine=False)
except FileNotFoundError as e:
if k > 0:
break
raise e
if __name__ == '__main__':
......