Commit ff68a9ef authored by Myle Ott

Add FairseqTask

A Task defines the data format, stores shared state (e.g., dictionaries) and provides helpers for building the model/criterion and calculating the loss.

Changes:
- Add TranslationTask and LanguageModelingTask. New tasks can be registered with the @register_task decorator (see the sketch after this list).
- Add EpochBatchIterator to encapsulate batching and saving/restoring dataloader position
- Remove LEFT_PAD_* constants and make them configurable per task
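
A hedged sketch of what registering a custom task could look like, inferred from the @register_task decorator named above and the tasks.setup_task()/task.load_dataset()/task.dataset() calls visible in eval_lm.py below; the class name, argument set, and hook bodies are illustrative, not part of this commit:

```
from fairseq.tasks import FairseqTask, register_task


@register_task('my_translation')
class MyTranslationTask(FairseqTask):

    @staticmethod
    def add_args(parser):
        # task-specific command-line arguments (hypothetical)
        parser.add_argument('data', help='path to data directory')

    @classmethod
    def setup_task(cls, args):
        # load dictionaries (the shared state) once, up front
        return cls(args)

    def load_dataset(self, split):
        # populate the split -> FairseqDataset mapping that
        # self.dataset(split) serves
        raise NotImplementedError

    @property
    def target_dictionary(self):
        raise NotImplementedError
```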
parent 2de93532
@@ -9,45 +9,43 @@
 import numpy as np
 import torch

-from fairseq import options, utils, progress_bar
-from fairseq.data import data_utils, data_loaders
+from fairseq import data, options, progress_bar, tasks, utils
 from fairseq.meters import StopwatchMeter, TimeMeter
 from fairseq.sequence_scorer import SequenceScorer


 def main(args):
     assert args.path is not None, '--path required for evaluation!'

-    print(args)
-    if args.max_target_positions is None:
-        args.max_target_positions = 1024
+    if args.tokens_per_sample is None:
+        args.tokens_per_sample = 1024
+    print(args)

     use_cuda = torch.cuda.is_available() and not args.cpu

-    dataset = data_loaders.load_dataset(args, [args.gen_subset], False)
+    # Load dataset splits
+    task = tasks.setup_task(args)
+    task.load_dataset(args.gen_subset)
+    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

     # Load ensemble
     print('| loading model(s) from {}'.format(args.path))
-    models, _ = utils.load_ensemble_for_inference(args.path.split(','), dataset.src_dict, dataset.dst_dict)
-    print('| Dictionary: {} types'.format(len(dataset.src_dict)))
-    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))
+    models, _ = utils.load_ensemble_for_inference(args.path.split(','), task)

     # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
     for model in models:
         model.make_generation_fast_()
-        model.src_dict = dataset.src_dict
-        model.dst_dict = dataset.dst_dict

-    itr = dataset.eval_dataloader(
-        args.gen_subset,
+    itr = data.EpochBatchIterator(
+        dataset=task.dataset(args.gen_subset),
         max_sentences=args.max_sentences or 4,
-        max_positions=args.max_target_positions or 1024,
-        descending=True,
-    )
-    itr = data_utils.ShardedIterator(itr, args.num_shards, args.shard_id)
+        max_positions=model.max_positions(),
+        num_shards=args.num_shards,
+        shard_id=args.shard_id,
+    ).next_epoch_itr(shuffle=False)

     gen_timer = StopwatchMeter()
-    scorer = SequenceScorer(models)
+    scorer = SequenceScorer(models, task.target_dictionary)
     if use_cuda:
         scorer.cuda()
@@ -62,7 +60,7 @@ def main(args):
             inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
             if inf_scores.any():
                 print('| Skipping tokens with inf scores:',
-                      dataset.src_dict.string(hypo['tokens'][inf_scores.nonzero()]))
+                      task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                 pos_scores = pos_scores[(~inf_scores).nonzero()]
             score_sum += pos_scores.sum()
             count += pos_scores.numel()
...
@@ -22,11 +22,11 @@ $ python preprocess.py --only-source \
 # Train the model:
 # If it runs out of memory, try to reduce max-tokens and max-target-positions
 $ mkdir -p checkpoints/wikitext-103
-$ python train.py data-bin/wikitext-103 --save-dir /checkpoints/wikitext-103 \
-  --max-epoch 35 --arch fconv_lm --optimizer nag --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
-  --decoder-layers '[(850, 6)] * 3 + [(850,1)] + [(850,5)] * 4 + [(850,1)] + [(850,4)] * 3 + [(1024,4)] + [(2048, 4)]' \
-  --decoder-embed-dim 280 --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
-  --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --max-target-positions 1024
+$ python train.py --task language_modeling data-bin/wikitext-103 \
+  --max-epoch 35 --arch fconv_lm_dauphin_wikitext103 --optimizer nag \
+  --lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
+  --clip-norm 0.1 --dropout 0.2 --weight-decay 5e-06 --criterion adaptive_loss \
+  --adaptive-softmax-cutoff 10000,20000,200000 --max-tokens 1024 --tokens-per-sample 1024

 # Evaluate:
 $ python eval_lm.py data-bin/wikitext-103 --path 'checkpoints/wiki103/checkpoint_best.pt'
...
@@ -28,6 +28,3 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-
 # Generate:
 $ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1
 ```
@@ -103,4 +103,3 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
     --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt --beam 5 --remove-bpe
 ```
@@ -15,8 +15,8 @@ CRITERION_REGISTRY = {}
 CRITERION_CLASS_NAMES = set()


-def build_criterion(args, src_dict, dst_dict):
-    return CRITERION_REGISTRY[args.criterion](args, src_dict, dst_dict)
+def build_criterion(args, task):
+    return CRITERION_REGISTRY[args.criterion](args, task)


 def register_criterion(name):
...
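
Criteria now receive the task (and pull dictionaries from it) rather than raw src_dict/dst_dict. A hedged sketch of a custom criterion under the new signature, modeled on the CrossEntropyCriterion changes below; the forward body and the (loss, sample_size, logging_output) return shape follow fairseq's existing convention rather than anything added by this commit:

```
import torch.nn.functional as F

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('simple_cross_entropy')
class SimpleCrossEntropyCriterion(FairseqCriterion):

    def __init__(self, args, task):
        # the base class now reads task.target_dictionary.pad()
        # into self.padding_idx
        super().__init__(args, task)

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample['net_input'])
        lprobs = model.get_normalized_probs(net_output, log_probs=True)
        loss = F.nll_loss(
            lprobs.view(-1, lprobs.size(-1)),
            sample['target'].view(-1),
            size_average=False,
            ignore_index=self.padding_idx,
            reduce=reduce,
        )
        sample_size = sample['ntokens']
        logging_output = {
            'loss': loss.data,
            'ntokens': sample['ntokens'],
            'sample_size': sample_size,
        }
        return loss, sample_size, logging_output
```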
@@ -19,8 +19,8 @@ class AdaptiveLoss(FairseqCriterion):
     graphical processing units (GPU), described in the paper "Efficient softmax approximation for GPUs"
     (http://arxiv.org/abs/1609.04309)."""

-    def __init__(self, args, src_dict, dst_dict):
-        super().__init__(args, src_dict, dst_dict)
+    def __init__(self, args, task):
+        super().__init__(args, task)

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
...
@@ -16,8 +16,8 @@ from . import FairseqCriterion, register_criterion
 @register_criterion('cross_entropy')
 class CrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, args, src_dict, dst_dict):
-        super().__init__(args, src_dict, dst_dict)
+    def __init__(self, args, task):
+        super().__init__(args, task)

     def forward(self, model, sample, reduce=True):
         """Compute the loss for the given sample.
...
@@ -10,10 +10,10 @@ from torch.nn.modules.loss import _Loss

 class FairseqCriterion(_Loss):

-    def __init__(self, args, src_dict, dst_dict):
+    def __init__(self, args, task):
         super().__init__()
         self.args = args
-        self.padding_idx = dst_dict.pad()
+        self.padding_idx = task.target_dictionary.pad()

     @staticmethod
     def add_args(parser):
...
@@ -15,8 +15,8 @@ from . import FairseqCriterion, register_criterion
 @register_criterion('label_smoothed_cross_entropy')
 class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):

-    def __init__(self, args, src_dict, dst_dict):
-        super().__init__(args, src_dict, dst_dict)
+    def __init__(self, args, task):
+        super().__init__(args, task)
         self.eps = args.label_smoothing

     @staticmethod
...
@@ -6,7 +6,10 @@
 # can be found in the PATENTS file in the same directory.

 from .dictionary import Dictionary
-from .token_block_dataset import TokenBlockDataset
-from .language_dataset import LanguageDatasets
+from .fairseq_dataset import FairseqDataset
+from .indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
+from .token_block_dataset import TokenBlockDataset
+from .data_utils import EpochBatchIterator
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import os
-
-from fairseq.data import LanguagePairDataset, MonolingualDataset
-from fairseq.data.data_utils import infer_language_pair
-
-
-def load_dataset(args, splits, is_raw):
-    """ Detect if we have a multi language dataset, or a single language dataset """
-    if args.source_lang is None and args.target_lang is None:
-        # find language pair automatically
-        args.source_lang, args.target_lang = infer_language_pair(args.data, splits)
-
-    if args.source_lang is None and args.target_lang is None and all(
-            os.path.exists(os.path.join(args.data, '{}.bin'.format(split))) for split in splits):
-        cls = MonolingualDataset
-    else:
-        cls = LanguagePairDataset
-
-    return cls.create_dataset(args, splits, is_raw)
@@ -6,51 +6,25 @@
 # can be found in the PATENTS file in the same directory.

 import contextlib
-import glob
 import itertools
-import math
-import numbers
-import numpy as np
 import os

+import numpy as np
 import torch
-from torch.autograd import Variable
-import torch.utils.data

-from fairseq.data.dictionary import Dictionary
-from fairseq.data.indexed_dataset import SizedDataset
+from . import FairseqDataset


-def has_binary_files(data_dir, splits):
-    for split in splits:
-        if len(glob.glob(os.path.join(data_dir, '{}*.bin'.format(split)))) == 0:
-            return False
-    return True
-
-
-def infer_language_pair(path, splits):
+def infer_language_pair(path):
     """Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
     src, dst = None, None
     for filename in os.listdir(path):
         parts = filename.split('.')
-        for split in splits:
-            if len(parts) >= 3 and parts[0] == split and parts[-1] == 'idx':
-                src, dst = parts[1].split('-')
-                break
+        if len(parts) >= 3 and len(parts[1].split('-')) == 2:
+            return parts[1].split('-')
     return src, dst


-def load_dictionaries(path, src_lang, dst_lang):
-    """Load dictionaries for a given language pair."""
-    src_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(src_lang)))
-    dst_dict = Dictionary.load(os.path.join(path, 'dict.{}.txt'.format(dst_lang)))
-    return src_dict, dst_dict
-
-
-def fmt_path(path, fmt, *args):
-    return os.path.join(path, fmt.format(*args))
-
-
 class ShardedIterator(object):
     """A sharded wrapper around an iterable (padded to length)."""
@@ -78,7 +52,35 @@ class ShardedIterator(object):
         return next(self.itr)[1]


+class CountingIterator(object):
+    """Wrapper around an iterable that maintains the iteration count."""
+
+    def __init__(self, iterable):
+        self.iterable = iterable
+        self.count = 0
+        self.itr = iter(self)
+
+    def __len__(self):
+        return len(self.iterable)
+
+    def __iter__(self):
+        for x in self.iterable:
+            self.count += 1
+            yield x
+
+    def __next__(self):
+        return next(self.itr)
+
+    def has_next(self):
+        return self.count < len(self)
+
+    def skip(self, num_to_skip):
+        next(itertools.islice(self.itr, num_to_skip, num_to_skip), None)
+        return self
+
+
 def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
+    """Convert a list of 1d tensors into a padded 2d tensor."""
     size = max(v.size(0) for v in values)
     res = values[0].new(len(values), size).fill_(pad_idx)
@@ -96,114 +98,149 @@ def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
     return res


-def _valid_size(src_size, dst_size, max_positions):
-    if isinstance(max_positions, numbers.Number):
-        max_src_positions, max_dst_positions = max_positions, max_positions
-    else:
-        max_src_positions, max_dst_positions = max_positions
-    if src_size < 1 or src_size > max_src_positions:
-        return False
-    if dst_size is not None and (dst_size < 1 or dst_size > max_dst_positions):
-        return False
-    return True
-
-
-def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
-                  ignore_invalid_inputs=False, allow_different_src_lens=False,
-                  required_batch_size_multiple=1):
-    batch = []
-    mult = required_batch_size_multiple
-
-    def yield_batch(next_idx, num_tokens):
-        if len(batch) == 0:
-            return False
-        if len(batch) == max_sentences:
-            return True
-        if num_tokens > max_tokens:
-            return True
-        if not allow_different_src_lens and \
-                (src.sizes[batch[0]] != src.sizes[next_idx]):
-            return True
-        return False
-
-    sample_len = 0
-    sample_lens = []
-    ignored = []
-    for idx in map(int, indices):
-        src_size = src.sizes[idx]
-        dst_size = dst.sizes[idx] if dst else src_size
-        if not _valid_size(src_size, dst_size, max_positions):
-            if ignore_invalid_inputs:
-                ignored.append(idx)
-                continue
-            raise Exception((
-                "Sample #{} has size (src={}, dst={}) but max size is {}."
-                " Skip this example with --skip-invalid-size-inputs-valid-test"
-            ).format(idx, src_size, dst_size, max_positions))
-
-        sample_lens.append(max(src_size, dst_size))
-        sample_len = max(sample_len, sample_lens[-1])
-        num_tokens = (len(batch) + 1) * sample_len
-        if yield_batch(idx, num_tokens):
-            mod8_len = max(mult * (len(batch) // mult), len(batch) % mult)
-            yield batch[:mod8_len]
-            batch = batch[mod8_len:]
-            sample_lens = sample_lens[mod8_len:]
-            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
-
-        batch.append(idx)
-
-    if len(batch) > 0:
-        yield batch
-
-    if len(ignored) > 0:
-        print("Warning! {} samples are either too short or too long "
-              "and will be ignored, first few sample ids={}".format(len(ignored), ignored[:10]))
-
-
-def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                    max_positions=(1024, 1024), ignore_invalid_inputs=False,
-                    descending=False, required_batch_size_multiple=1, allow_different_src_lens=False):
-    """Returns batches of indices sorted by size. Sequences with different
-    source lengths are not allowed in the same batch."""
-    assert isinstance(src, SizedDataset) and (dst is None or isinstance(dst, SizedDataset))
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    if max_sentences is None:
-        max_sentences = float('Inf')
-    indices = np.argsort(src.sizes, kind='mergesort')
-    if descending:
-        indices = np.flip(indices, 0)
-    return list(_make_batches(
-        src, dst, indices, max_tokens, max_sentences, max_positions,
-        ignore_invalid_inputs, allow_different_src_lens=allow_different_src_lens,
-        required_batch_size_multiple=required_batch_size_multiple,
-    ))
-
-
-def uneven_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
-                           max_positions=(1024, 1024),
-                           required_batch_size_multiple=1):
-    """Returns batches of indices bucketed by size. Batches may contain
-    sequences of different lengths."""
-    assert isinstance(src, SizedDataset) and isinstance(dst, SizedDataset)
-    if max_tokens is None:
-        max_tokens = float('Inf')
-    if max_sentences is None:
-        max_sentences = float('Inf')
-    indices = np.random.permutation(len(src))
-    # sort by sizes
-    indices = indices[np.argsort(dst.sizes[indices], kind='mergesort')]
-    indices = indices[np.argsort(src.sizes[indices], kind='mergesort')]
-    batches = list(_make_batches(
-        src, dst, indices, max_tokens, max_sentences, max_positions,
-        ignore_invalid_inputs=True, allow_different_src_lens=True,
-        required_batch_size_multiple=required_batch_size_multiple,
-    ))
-    return batches
+class EpochBatchIterator(object):
+    """Iterate over a FairseqDataset and yield batches bucketed by size.
+
+    Batches may contain sequences of different lengths. This iterator can be
+    reused across multiple epochs with the next_epoch_itr() method.
+
+    Args:
+        dataset: a FairseqDataset
+        max_tokens: max number of tokens in each batch
+        max_sentences: max number of sentences in each batch
+        max_positions: max sentence length supported by the model
+        ignore_invalid_inputs: don't raise Exception for sentences that are too long
+        required_batch_size_multiple: require batch size to be a multiple of N
+        seed: seed for random number generator for reproducibility
+        num_shards: shard the data iterator into N shards
+        shard_id: which shard of the data iterator to return
+    """
+
+    def __init__(
+        self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
+        ignore_invalid_inputs=False, required_batch_size_multiple=1, seed=1,
+        num_shards=1, shard_id=0,
+    ):
+        assert isinstance(dataset, FairseqDataset)
+        self.dataset = dataset
+        self.max_tokens = max_tokens if max_tokens is not None else float('Inf')
+        self.max_sentences = max_sentences if max_sentences is not None else float('Inf')
+        self.max_positions = max_positions
+        self.ignore_invalid_inputs = ignore_invalid_inputs
+        self.bsz_mult = required_batch_size_multiple
+        self.seed = seed
+        self.num_shards = num_shards
+        self.shard_id = shard_id
+
+        with numpy_seed(self.seed):
+            self.frozen_batches = tuple(self._batch_generator())
+
+        self.epoch = 0
+        self._cur_epoch_itr = None
+        self._next_epoch_itr = None
+
+    def __len__(self):
+        return len(self.frozen_batches)
+
+    def next_epoch_itr(self, shuffle=True):
+        """Shuffle batches and return a new iterator over the dataset."""
+        if self._next_epoch_itr is not None:
+            self._cur_epoch_itr = self._next_epoch_itr
+            self._next_epoch_itr = None
+        else:
+            self.epoch += 1
+            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
+        return self._cur_epoch_itr
+
+    def end_of_epoch(self):
+        return not self._cur_epoch_itr.has_next()
+
+    @property
+    def iterations_in_epoch(self):
+        if self._cur_epoch_itr is not None:
+            return self._cur_epoch_itr.count
+        elif self._next_epoch_itr is not None:
+            return self._next_epoch_itr.count
+        return 0
+
+    def state_dict(self):
+        return {
+            'epoch': self.epoch,
+            'iterations_in_epoch': self.iterations_in_epoch,
+        }
+
+    def load_state_dict(self, state_dict):
+        self.epoch = state_dict['epoch']
+        itr_pos = state_dict.get('iterations_in_epoch', 0)
+        if itr_pos > 0:
+            # fast-forward epoch iterator
+            itr = self._get_iterator_for_epoch(self.epoch, state_dict.get('shuffle', True))
+            if itr_pos < len(itr):
+                self._next_epoch_itr = itr.skip(itr_pos)
+
+    def _get_iterator_for_epoch(self, epoch, shuffle):
+        if shuffle:
+            # set seed based on the seed and epoch number so that we get
+            # reproducible results when resuming from checkpoints
+            with numpy_seed(self.seed + epoch):
+                batches = list(self.frozen_batches)  # copy
+                np.random.shuffle(batches)
+        else:
+            batches = self.frozen_batches
+        return CountingIterator(torch.utils.data.DataLoader(
+            self.dataset,
+            collate_fn=self.dataset.collater,
+            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
+        ))
+
+    def _batch_generator(self):
+        batch = []
+
+        def is_batch_full(num_tokens):
+            if len(batch) == 0:
+                return False
+            if len(batch) == self.max_sentences:
+                return True
+            if num_tokens > self.max_tokens:
+                return True
+            return False
+
+        sample_len = 0
+        sample_lens = []
+        ignored = []
+        for idx in self.dataset.ordered_indices():
+            if not self.dataset.valid_size(idx, self.max_positions):
+                if self.ignore_invalid_inputs:
+                    ignored.append(idx)
+                    continue
+                raise Exception((
+                    'Size of sample #{} is invalid, max_positions={}, skip this '
+                    'example with --skip-invalid-size-inputs-valid-test'
+                ).format(idx, self.max_positions))
+
+            sample_lens.append(self.dataset.num_tokens(idx))
+            sample_len = max(sample_len, sample_lens[-1])
+            num_tokens = (len(batch) + 1) * sample_len
+            if is_batch_full(num_tokens):
+                mod_len = max(
+                    self.bsz_mult * (len(batch) // self.bsz_mult),
+                    len(batch) % self.bsz_mult,
+                )
+                yield batch[:mod_len]
+                batch = batch[mod_len:]
+                sample_lens = sample_lens[mod_len:]
+                sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+
+            batch.append(idx)
+
+        if len(batch) > 0:
+            yield batch
+
+        if len(ignored) > 0:
+            print((
+                '| WARNING: {} samples have invalid sizes and will be skipped, '
+                'max_positions={}, first few sample ids={}'
+            ).format(len(ignored), self.max_positions, ignored[:10]))


 @contextlib.contextmanager
@@ -219,21 +256,3 @@ def numpy_seed(seed):
         yield
     finally:
         np.random.set_state(state)
-
-
-def get_dummy_batch(ntokens, src_dict, dst_dict, src_len=128, tgt_len=128):
-    bsz = int(ntokens / max(src_len, tgt_len))
-    bsz = math.ceil(bsz / 8) * 8
-    assert src_dict.pad() == dst_dict.pad()
-    pad_idx = src_dict.pad()
-    src_vocab, dst_vocab = len(src_dict), len(dst_dict)
-    dummy_batch = {}
-    dummy_batch['id'] = Variable(torch.arange(bsz).long().cuda())
-    dummy_batch['ntokens'] = tgt_len * bsz
-    dummy_batch['target'] = Variable(torch.Tensor(bsz, tgt_len).uniform_(pad_idx + 1, dst_vocab - 1).long().cuda())
-    input = {}
-    input['prev_output_tokens'] = Variable(dummy_batch['target'].data.clone())
-    input['src_lengths'] = Variable(torch.LongTensor(bsz).fill_(src_len).cuda())
-    input['src_tokens'] = Variable(torch.Tensor(bsz, src_len).uniform_(pad_idx + 1, src_vocab - 1).long().cuda())
-    dummy_batch['net_input'] = input
-    return dummy_batch
@@ -188,3 +188,8 @@ class Dictionary(object):
             return self.save(fd, threshold, nwords)
         for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
             print('{} {}'.format(symbol, count), file=f)
+
+    def dummy_sentence(self, length):
+        t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
+        t[-1] = self.eos()
+        return t
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the LICENSE file in
+# the root directory of this source tree. An additional grant of patent rights
+# can be found in the PATENTS file in the same directory.
+
+import torch.utils.data
+
+
+class FairseqDataset(torch.utils.data.Dataset):
+    """A dataset that provides helpers for batching."""
+
+    def __getitem__(self, index):
+        raise NotImplementedError
+
+    def __len__(self):
+        raise NotImplementedError
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch."""
+        raise NotImplementedError
+
+    def get_dummy_batch(self, num_tokens, max_positions):
+        """Return a dummy batch with a given number of tokens."""
+        raise NotImplementedError
+
+    def num_tokens(self, index):
+        """Return an example's length (number of tokens), used for batching."""
+        raise NotImplementedError
+
+    def ordered_indices(self):
+        """Ordered indices for batching."""
+        raise NotImplementedError
+
+    def valid_size(self, index, max_positions):
+        """Check if an example's size is valid according to max_positions."""
+        raise NotImplementedError
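
A minimal concrete FairseqDataset, only to illustrate which hooks EpochBatchIterator calls (ordered_indices and valid_size while building batches, num_tokens for the token budget, collater inside the DataLoader); the torch.stack collater assumes equal-length examples and is a sketch, not code from this commit:

```
import numpy as np
import torch

from fairseq.data import FairseqDataset


class TensorListDataset(FairseqDataset):

    def __init__(self, tensors):
        self.tensors = tensors
        self.sizes = np.array([t.numel() for t in tensors])

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

    def collater(self, samples):
        # assumes equal-length samples; a real dataset would pad here
        return torch.stack(samples)

    def num_tokens(self, index):
        return self.sizes[index]

    def ordered_indices(self):
        # group similarly sized examples together
        return np.argsort(self.sizes, kind='mergesort')

    def valid_size(self, index, max_positions):
        return max_positions is None or self.sizes[index] <= max_positions
```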
@@ -10,7 +10,6 @@ import struct
 import numpy as np
 import torch
-import torch.utils.data

 from fairseq.tokenizer import Tokenizer

@@ -50,16 +49,7 @@ def data_file_path(prefix_path):
     return prefix_path + '.bin'


-class SizedDataset(torch.utils.data.Dataset):
-
-    def __init__(self):
-        self._sizes = None
-
-    @property
-    def sizes(self):
-        return self._sizes
-
-
-class IndexedDataset(SizedDataset):
+class IndexedDataset(torch.utils.data.Dataset):
     """Loader for TorchNet IndexedDataset"""

     def __init__(self, path):

@@ -74,7 +64,7 @@ class IndexedDataset(torch.utils.data.Dataset):
             self.size, self.s = struct.unpack('<QQ', f.read(16))
             self.dim_offsets = read_longs(f, self.size + 1)
             self.data_offsets = read_longs(f, self.size + 1)
-            self._sizes = read_longs(f, self.s)
+            self.sizes = read_longs(f, self.s)
         self.read_data(path)

     def read_data(self, path):

@@ -93,7 +83,7 @@ class IndexedDataset(torch.utils.data.Dataset):
         a = np.empty(tensor_size, dtype=self.dtype)
         self.data_file.seek(self.data_offsets[i] * self.element_size)
         self.data_file.readinto(a)
-        return torch.from_numpy(a)
+        return torch.from_numpy(a).long() - 1  # subtract 1 for 0-based indexing

     def __len__(self):
         return self.size

@@ -114,6 +104,7 @@ class IndexedInMemoryDataset(IndexedDataset):
         self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
         self.data_file.readinto(self.buffer)
         self.data_file.close()
+        self.buffer -= 1  # subtract 1 for 0-based indexing

     def __del__(self):
         pass

@@ -123,7 +114,7 @@ class IndexedInMemoryDataset(IndexedDataset):
         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
-        return torch.from_numpy(a)
+        return torch.from_numpy(a).long()


 class IndexedRawTextDataset(IndexedDataset):

@@ -133,7 +124,7 @@ class IndexedRawTextDataset(IndexedDataset):
     def __init__(self, path, dictionary, append_eos=True, reverse_order=False):
         self.tokens_list = []
         self.lines = []
-        self._sizes = []
+        self.sizes = []
         self.append_eos = append_eos
         self.reverse_order = reverse_order
         self.read_data(path, dictionary)

@@ -146,10 +137,10 @@ class IndexedRawTextDataset(IndexedDataset):
                 tokens = Tokenizer.tokenize(
                     line, dictionary, add_if_not_exist=False,
                     append_eos=self.append_eos, reverse_order=self.reverse_order,
-                ) + 1  # +1 for Lua compatibility
+                ).long()
                 self.tokens_list.append(tokens)
-                self._sizes.append(len(tokens))
-        self._sizes = np.array(self._sizes)
+                self.sizes.append(len(tokens))
+        self.sizes = np.array(self.sizes)

     def __getitem__(self, i):
         self.check_index(i)

@@ -165,6 +156,10 @@ class IndexedRawTextDataset(IndexedDataset):
     def __len__(self):
         return self.size

+    @staticmethod
+    def exists(path):
+        return os.path.exists(path)


 class IndexedDatasetBuilder(object):
     element_sizes = {
...
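
The `- 1` adjustments above move the Lua-style 1-based to 0-based token id conversion out of the datasets' consumers (it was previously done in LanguagePairDataset.__getitem__) and into the loaders themselves. An illustration with hypothetical token ids:

```
import numpy as np
import torch

# binarized files still store Lua-style 1-based token ids on disk
stored = np.array([5, 8, 3], dtype=np.int64)

# what IndexedDataset now returns: 0-based ids, already torch.long
loaded = torch.from_numpy(stored).long() - 1
assert loaded.tolist() == [4, 7, 2]
```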
-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
-
-import itertools
-
-import numpy as np
-import torch
-
-from fairseq.data.data_utils import numpy_seed, uneven_batches_by_size, ShardedIterator, batches_by_size
-
-
-class LanguageDatasets(object):
-    def __init__(self, src, dst, src_dict, dst_dict):
-        self.src = src
-        self.dst = dst
-        self.src_dict = src_dict
-        self.dst_dict = dst_dict
-        self.splits = {}
-        assert self.src_dict.pad() == self.dst_dict.pad()
-        assert self.src_dict.eos() == self.dst_dict.eos()
-        assert self.src_dict.unk() == self.dst_dict.unk()
-
-    def train_dataloader_generator(
-        self, split, max_tokens=None, max_sentences=None,
-        max_positions=(1024, 1024), seed=None, sample_without_replacement=0,
-        shard_id=0, num_shards=1
-    ):
-        dataset = self.splits[split]
-        with numpy_seed(seed):
-            batches = uneven_batches_by_size(
-                dataset.src, dataset.dst, max_tokens=max_tokens,
-                max_sentences=max_sentences, max_positions=max_positions,
-                # FP16: during training keep the batch size a multiple of 8
-                required_batch_size_multiple=8,
-            )
-            frozen_batches = tuple(batches)  # freeze
-
-        def dataloader(b):
-            b = ShardedIterator(b, num_shards, shard_id, fill_value=[])
-            return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collater, batch_sampler=b)
-
-        for epoch in itertools.count(1):
-            # set seed based on the seed and epoch number so that we get
-            # reproducible results when resuming from checkpoints
-            with numpy_seed(seed + epoch):
-                batches = list(frozen_batches)  # copy
-                np.random.shuffle(batches)
-
-            if sample_without_replacement > 0:
-                # emit sub-epoch dataloaders
-                while len(batches) >= sample_without_replacement:
-                    sampled_batches = batches[:sample_without_replacement]
-                    remaining_batches = batches[sample_without_replacement:]
-                    yield dataloader(sampled_batches)
-                    batches = remaining_batches
-                if len(batches) > 0:
-                    yield dataloader(batches)
-            else:
-                # emit full dataloader
-                yield dataloader(batches)
-
-    def eval_dataloader(self, split, num_workers=0, max_tokens=None,
-                        max_sentences=None, max_positions=(1024, 1024),
-                        skip_invalid_size_inputs_valid_test=False,
-                        descending=False, shard_id=0, num_shards=1):
-        dataset = self.splits[split]
-        batch_sampler = batches_by_size(
-            dataset.src, dataset.dst, max_tokens, max_sentences,
-            max_positions=max_positions,
-            ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
-            descending=descending,
-            allow_different_src_lens=True)
-        batch_sampler = ShardedIterator(batch_sampler, num_shards, shard_id, fill_value=[])
-        return torch.utils.data.DataLoader(
-            dataset, num_workers=num_workers, collate_fn=dataset.collater,
-            batch_sampler=batch_sampler)
@@ -5,29 +5,24 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import itertools
-import os
-
+import numpy as np
 import torch
-import torch.utils

-from fairseq.data import LanguageDatasets
-from fairseq.data.consts import LEFT_PAD_TARGET, LEFT_PAD_SOURCE
-from fairseq.data.data_utils import fmt_path, load_dictionaries, collate_tokens
-from fairseq.data.indexed_dataset import IndexedInMemoryDataset, IndexedRawTextDataset
+from . import data_utils, FairseqDataset


-def collate(samples, pad_idx, eos_idx, has_target):
+def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
     if len(samples) == 0:
         return {}

     def merge(key, left_pad, move_eos_to_beginning=False):
-        return collate_tokens(
+        return data_utils.collate_tokens(
             [s[key] for s in samples],
             pad_idx, eos_idx, left_pad, move_eos_to_beginning,
         )

     id = torch.LongTensor([s['id'] for s in samples])
-    src_tokens = merge('source', left_pad=LEFT_PAD_SOURCE)
+    src_tokens = merge('source', left_pad=left_pad_source)
     # sort by descending source length
     src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
     src_lengths, sort_order = src_lengths.sort(descending=True)

@@ -36,19 +31,20 @@ def collate(samples, pad_idx, eos_idx, has_target):
     prev_output_tokens = None
     target = None
-    ntokens = None
-    if has_target:
-        target = merge('target', left_pad=LEFT_PAD_TARGET)
+    if samples[0].get('target', None) is not None:
+        target = merge('target', left_pad=left_pad_target)
         # we create a shifted version of targets for feeding the
         # previous output token(s) into the next decoder step
         prev_output_tokens = merge(
             'target',
-            left_pad=LEFT_PAD_TARGET,
+            left_pad=left_pad_target,
             move_eos_to_beginning=True,
         )
         prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
         target = target.index_select(0, sort_order)
         ntokens = sum(len(s['target']) for s in samples)
+    else:
+        ntokens = sum(len(s['source']) for s in samples)

     return {
         'id': id,

@@ -62,93 +58,87 @@ def collate(samples, pad_idx, eos_idx, has_target):
     }


-class LanguagePairDataset(torch.utils.data.Dataset):
-    """A pair of torch.utils.data.Datasets."""
-
-    def __init__(self, src, dst, pad_idx, eos_idx):
-        self.src = src
-        self.dst = dst
-        self.pad_idx = pad_idx
-        self.eos_idx = eos_idx
-
-    def __getitem__(self, i):
-        # subtract 1 for 0-based indexing
-        source = self.src[i].long() - 1
-        res = {'id': i, 'source': source}
-        if self.dst:
-            res['target'] = self.dst[i].long() - 1
-        return res
+class LanguagePairDataset(FairseqDataset):
+
+    def __init__(
+        self, src, src_sizes, src_dict,
+        tgt=None, tgt_sizes=None, tgt_dict=None,
+        left_pad_source=True, left_pad_target=False,
+        max_source_positions=1024, max_target_positions=1024,
+        shuffle=True,
+    ):
+        if tgt_dict is not None:
+            assert src_dict.pad() == tgt_dict.pad()
+            assert src_dict.eos() == tgt_dict.eos()
+            assert src_dict.unk() == tgt_dict.unk()
+        self.src = src
+        self.tgt = tgt
+        self.src_sizes = np.array(src_sizes)
+        self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
+        self.src_dict = src_dict
+        self.tgt_dict = tgt_dict
+        self.left_pad_source = left_pad_source
+        self.left_pad_target = left_pad_target
+        self.max_source_positions = max_source_positions
+        self.max_target_positions = max_target_positions
+        self.shuffle = shuffle
+
+    def __getitem__(self, index):
+        return {
+            'id': index,
+            'source': self.src[index],
+            'target': self.tgt[index] if self.tgt is not None else None,
+        }

     def __len__(self):
         return len(self.src)

     def collater(self, samples):
-        return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
-
-    @staticmethod
-    def create_dataset(args, splits, is_raw):
-        src, dst = args.source_lang, args.target_lang
-        assert src is not None and dst is not None, 'Source and target languages should be provided'
-
-        src_dict, dst_dict = load_dictionaries(args.data, src, dst)
-        dataset = LanguageDatasets(src, dst, src_dict, dst_dict)
-
-        def create_raw_dataset():
-            """Loads specified data splits (e.g., test, train or valid) from raw text
-            files in the specified folder."""
-
-            # Load dataset from raw text files
-            for split in splits:
-                src_path = os.path.join(args.data, '{}.{}'.format(split, src))
-                dst_path = os.path.join(args.data, '{}.{}'.format(split, dst))
-                dataset.splits[split] = LanguagePairDataset(
-                    IndexedRawTextDataset(src_path, src_dict),
-                    IndexedRawTextDataset(dst_path, dst_dict),
-                    pad_idx=dataset.src_dict.pad(),
-                    eos_idx=dataset.src_dict.eos(),
-                )
-            return dataset
-
-        def create_binary_dataset():
-            """Loads specified data splits (e.g., test, train or valid) from the
-            specified folder and check that files exist."""
-
-            # Load dataset from binary files
-            def all_splits_exist(src, dst, lang):
-                for split in splits:
-                    filename = '{0}.{1}-{2}.{3}.idx'.format(split, src, dst, lang)
-                    if not os.path.exists(os.path.join(args.data, filename)):
-                        return False
-                return True
-
-            # infer langcode
-            if all_splits_exist(src, dst, src):
-                langcode = '{}-{}'.format(src, dst)
-            elif all_splits_exist(dst, src, src):
-                langcode = '{}-{}'.format(dst, src)
-            else:
-                raise Exception('Dataset cannot be loaded from path: ' + args.data)
-
-            for split in splits:
-                for k in itertools.count():
-                    prefix = "{}{}".format(split, k if k > 0 else '')
-                    src_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, src)
-                    dst_path = fmt_path(args.data, '{}.{}.{}', prefix, langcode, dst)
-                    if not IndexedInMemoryDataset.exists(src_path):
-                        break
-                    target_dataset = None
-                    if IndexedInMemoryDataset.exists(dst_path):
-                        target_dataset = IndexedInMemoryDataset(dst_path)
-                    dataset.splits[prefix] = LanguagePairDataset(
-                        IndexedInMemoryDataset(src_path),
-                        target_dataset,
-                        pad_idx=dataset.src_dict.pad(),
-                        eos_idx=dataset.src_dict.eos(),
-                    )
-            return dataset
-
-        return create_raw_dataset() if is_raw else create_binary_dataset()
+        """Merge a list of samples to form a mini-batch."""
+        return collate(
+            samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
+            left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
+        )
+
+    def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
+        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
+        src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
+        bsz = num_tokens // max(src_len, tgt_len)
+        return self.collater([
+            {
+                'id': i,
+                'source': self.src_dict.dummy_sentence(src_len),
+                'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
+            }
+            for i in range(bsz)
+        ])
+
+    def num_tokens(self, index):
+        """Return an example's length (number of tokens), used for batching."""
+        return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
+
+    def ordered_indices(self):
+        """Ordered indices for batching."""
+        if self.shuffle:
+            indices = np.random.permutation(len(self))
+        else:
+            indices = np.arange(len(self))
+        if self.tgt_sizes is not None:
+            indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
+        return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
+
+    def valid_size(self, index, max_positions):
+        """Check if an example's size is valid according to max_positions."""
+        max_source_positions, max_target_positions = self._get_max_positions(max_positions)
+        return (
+            self.src_sizes[index] <= max_source_positions
+            and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
+        )
+
+    def _get_max_positions(self, max_positions):
+        if max_positions is None:
+            return self.max_source_positions, self.max_target_positions
+        assert len(max_positions) == 2
+        max_src_pos, max_tgt_pos = max_positions
+        return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
@@ -5,117 +5,77 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

-import itertools
-import os
-
-import torch
-from torch.utils.data import Dataset
-
-from fairseq.data import TokenBlockDataset, Dictionary, LanguageDatasets
-from fairseq.data.indexed_dataset import IndexedInMemoryDataset
-from fairseq.data.data_utils import fmt_path, collate_tokens
-
-
-def collate(samples, pad_idx, eos_idx, has_target):
-    if len(samples) == 0:
-        return {}
-
-    def merge(key):
-        return collate_tokens(
-            [s[key] for s in samples],
-            pad_idx, eos_idx, left_pad=False,
-        )
-
-    id = torch.LongTensor([s['id'] for s in samples])
-    # language models only have a decoder which is not padding-aware, so don't left pad for them
-    src_tokens = merge('source')
-    # sort by descending source length
-    src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
-    _, sort_order = src_lengths.sort(descending=True)
-    id = id.index_select(0, sort_order)
-    src_tokens = src_tokens.index_select(0, sort_order)
-
-    target = None
-    ntokens = None
-    if has_target:
-        target = merge('target')
-        target = target.index_select(0, sort_order)
-        ntokens = sum(len(s['target']) for s in samples)
-
-    return {
-        'id': id,
-        'ntokens': ntokens,
-        'net_input': {
-            'src_tokens': src_tokens,
-        },
-        'target': target,
-    }
-
-
-class MonolingualDataset(Dataset):
-
-    def __init__(self, tokens, sizes, token_block_size, break_mode, pad_idx, eos_idx, next_token_is_target):
-        if next_token_is_target:
-            self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=1, break_mode=break_mode)
-            self.dst = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
-        else:
-            self.src = TokenBlockDataset(tokens, token_block_size, sizes, offset=0, break_mode=break_mode)
-            self.dst = None
-
-        self.pad_idx = pad_idx
-        self.eos_idx = eos_idx
-
-    def __getitem__(self, i):
-        # subtract 1 for 0-based indexing
-        source = self.src[i].long() - 1
-        res = {'id': i, 'source': source}
-        if self.dst:
-            res['target'] = self.dst[i].long() - 1
-        return res
-
-    def __len__(self):
-        return len(self.src)
-
-    def collater(self, samples):
-        return collate(samples, self.pad_idx, self.eos_idx, self.dst is not None)
-
-    @staticmethod
-    def create_dataset(args, splits, is_raw):
-        """Loads specified data splits (e.g., test, train or valid) from the
-        specified folder and check that files exist."""
-
-        if is_raw:
-            raise Exception('raw text single language data sets are currently not supported')
-
-        assert args.sample_break_mode == 'eos' or args.max_target_positions is not None
-
-        path = args.data
-        dict = Dictionary.load(os.path.join(path, 'dict.txt'))
-        dataset = LanguageDatasets(None, None, dict, dict)
-
-        assert all(os.path.exists(os.path.join(path, '{}.bin'.format(split))) for split in splits)
-
-        for split in splits:
-            for k in itertools.count():
-                prefix = "{}{}".format(split, k if k > 0 else '')
-                split_path = fmt_path(path, '{}', prefix)
-
-                if not IndexedInMemoryDataset.exists(split_path):
-                    break
-
-                ds = IndexedInMemoryDataset(split_path)
-                tokens = torch.from_numpy(ds.buffer)
-
-                dataset.splits[prefix] = MonolingualDataset(
-                    tokens,
-                    ds.sizes,
-                    args.max_target_positions,
-                    args.sample_break_mode,
-                    pad_idx=dataset.src_dict.pad(),
-                    eos_idx=dataset.src_dict.eos(),
-                    next_token_is_target=True,
-                )
-
-        return dataset
+import numpy as np
+import torch
+
+from . import data_utils, FairseqDataset
+
+
+def collate(samples, pad_idx, eos_idx):
+    if len(samples) == 0:
+        return {}
+
+    def merge(key):
+        return data_utils.collate_tokens(
+            [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
+            move_eos_to_beginning=False,
+        )
+
+    return {
+        'id': torch.LongTensor([s['id'] for s in samples]),
+        'ntokens': sum(len(s['target']) for s in samples),
+        'net_input': {
+            'src_tokens': merge('source'),
+        },
+        'target': merge('target'),
+    }
+
+
+class MonolingualDataset(FairseqDataset):
+    """A wrapper around torch.utils.data.Dataset for monolingual data."""
+
+    def __init__(self, dataset, sizes, vocab, shuffle):
+        self.dataset = dataset
+        self.sizes = np.array(sizes)
+        self.vocab = vocab
+        self.shuffle = shuffle
+
+    def __getitem__(self, index):
+        source, target = self.dataset[index]
+        return {'id': index, 'source': source, 'target': target}
+
+    def __len__(self):
+        return len(self.dataset)
+
+    def collater(self, samples):
+        """Merge a list of samples to form a mini-batch."""
+        return collate(samples, self.vocab.pad(), self.vocab.eos())
+
+    def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
+        assert isinstance(max_positions, float) or isinstance(max_positions, int)
+        tgt_len = min(tgt_len, max_positions)
+        bsz = num_tokens // tgt_len
+        target = self.vocab.dummy_sentence(tgt_len + 1)
+        source, target = target[:-1], target[1:]
+        return self.collater([
+            {'id': i, 'source': source, 'target': target}
+            for i in range(bsz)
+        ])
+
+    def num_tokens(self, index):
+        """Return an example's length (number of tokens), used for batching."""
+        source, target = self.dataset[index]
+        return len(source)
+
+    def ordered_indices(self):
+        """Ordered indices for batching."""
+        if self.shuffle:
+            order = [np.random.permutation(len(self))]
+        else:
+            order = [np.arange(len(self))]
+        order.append(self.sizes)
+        return np.lexsort(order)
+
+    def valid_size(self, index, max_positions):
+        """Check if an example's size is valid according to max_positions."""
+        assert isinstance(max_positions, float) or isinstance(max_positions, int)
+        return self.sizes[index] <= max_positions
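
The ordered_indices() above relies on np.lexsort treating the last key as the primary sort key, so examples come out ordered by size with ties broken by the shuffled (or sequential) first key. A small illustration with made-up sizes:

```
import numpy as np

sizes = np.array([3, 1, 3, 2])
tiebreak = np.array([1, 0, 0, 2])      # e.g. a random permutation
order = np.lexsort([tiebreak, sizes])  # sizes is the primary key
assert order.tolist() == [1, 3, 2, 0]  # size 1, size 2, then the two 3s
```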
@@ -5,47 +5,49 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

 import math

 import numpy as np
 import torch

-from fairseq.data.indexed_dataset import SizedDataset
-

-class TokenBlockDataset(SizedDataset):
-    """Given a 1d tensor of tokens, this dataset will break tokens into blocks based on parameters. The blocks are
-    fetched from the original tensor so no additional memory is allocated"""
-
-    def __init__(self, tokens, block_size, sizes, offset=0, break_mode=None):
-        """
-        Args:
-            tokens: torch tensor of tokens to break into blocks
-            block_size: An integer. the maximum size of each block (note this has no effect in 'eos' break mode)
-            sizes: A list of integers. sizes of sentences in the block. the sum of the sizes should add up to the
-                length of tokens
-            offset: An integer. rotates the tokens by this much before computing blocks. useful for language model targets
-            break_mode: A boolean if None/'none' then breaks tokens into equally sized blocks of size block_size
-                if 'complete' then breaks tokens into block sizes of up to block_size such that each block
-                contains complete sentences. block_size may be exceeded if some sentences exceed block_size
-                if 'eos' then each block contains a single sentence. does not respect block_size"""
+
+class TokenBlockDataset(torch.utils.data.Dataset):
+    """Break a 1d tensor of tokens into blocks.
+
+    The blocks are fetched from the original tensor so no additional memory is allocated.
+
+    Args:
+        tokens: 1d tensor of tokens to break into blocks
+        sizes: sentence lengths (required for 'complete' and 'eos')
+        block_size: maximum block size (ignored in 'eos' break mode)
+        break_mode: Mode used for breaking tokens. Values can be one of:
+            - 'none': break tokens into equally sized blocks (up to block_size)
+            - 'complete': break tokens into blocks (up to block_size) such that
+              blocks contains complete sentences, although block_size may be
+              exceeded if some sentences exceed block_size
+            - 'eos': each block contains one sentence (block_size is ignored)
+        include_targets: return next tokens as targets
+    """
+
+    def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
         super().__init__()
         self.tokens = tokens
-        self.offset = offset
+        self.total_size = len(tokens)
+        self.include_targets = include_targets
         self.slice_indices = []

         if break_mode is None or break_mode == 'none':
-            length = math.ceil(tokens.numel() / block_size)
+            length = math.ceil(len(tokens) / block_size)

             def block_at(i):
                 start = i * block_size
                 end = min(start + block_size, len(tokens))
                 return (start, end)

-            self.slice_indices = [block_at(i) for i in np.arange(length)]
+            self.slice_indices = [block_at(i) for i in range(length)]
         elif break_mode == 'complete':
+            assert sizes is not None and sum(sizes) == len(tokens)
             tok_idx = 0
             sz_idx = 0
             curr_size = 0

@@ -60,6 +62,7 @@ class TokenBlockDataset(torch.utils.data.Dataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
+            assert sizes is not None and sum(sizes) == len(tokens)
             curr = 0
             for sz in sizes:
                 # skip samples with just 1 example (which would be just the eos token)

@@ -67,19 +70,20 @@ class TokenBlockDataset(torch.utils.data.Dataset):
                     self.slice_indices.append((curr, curr + sz))
                 curr += sz
         else:
-            raise Exception('invalid break_mode. Supported values: none, complete, eos')
+            raise ValueError('Invalid break_mode: ' + break_mode)

-        self._sizes = np.array([e - s for s, e in self.slice_indices])
+        self.sizes = np.array([e - s for s, e in self.slice_indices])

-    def _slice(self, s, e):
-        # this will copy only the first block if offset > 0, instead of all blocks if we just rotated
-        # the tensor with torch.cat()
-        if s < self.offset:
-            return torch.cat([self.tokens[s - self.offset:], self.tokens[s:e - self.offset]])
-        return self.tokens[s - self.offset:e - self.offset]
-
-    def __getitem__(self, i):
-        return self._slice(*self.slice_indices[i])
+    def __getitem__(self, index):
+        s, e = self.slice_indices[index]
+        item = torch.LongTensor(self.tokens[s:e])
+        if self.include_targets:
+            if e == self.total_size:
+                return item[:-1], item[1:]
+            else:
+                return item, torch.LongTensor(self.tokens[s + 1:e + 1])
+        else:
+            return item

     def __len__(self):
         return len(self.slice_indices)
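
A worked example of the new TokenBlockDataset interface on a toy token stream (token ids are hypothetical, eos id 2); the asserted block contents follow from the slicing logic above:

```
from fairseq.data import TokenBlockDataset

tokens = [10, 11, 2, 12, 2]  # two sentences, eos id = 2
sizes = [3, 2]

# 'none': fixed-size blocks of up to block_size tokens
ds = TokenBlockDataset(tokens, sizes, block_size=2, break_mode='none')
assert [ds[i].tolist() for i in range(len(ds))] == [[10, 11], [2, 12], [2]]

# 'eos': one block per sentence, block_size is ignored
ds = TokenBlockDataset(tokens, sizes, block_size=2, break_mode='eos')
assert [ds[i].tolist() for i in range(len(ds))] == [[10, 11, 2], [12, 2]]

# include_targets=True also returns the next-token stream, which is how a
# language modeling task can build (source, target) pairs from one tensor,
# replacing the old offset=1/offset=0 pair of datasets
ds = TokenBlockDataset(tokens, sizes, block_size=2, break_mode='none',
                       include_targets=True)
source, target = ds[0]
assert source.tolist() == [10, 11] and target.tolist() == [11, 2]
```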
@@ -49,8 +49,8 @@ class FP16Trainer(Trainer):
     We do forward/backward with FP16 and compute the loss + optimize with FP32.
     """

-    def __init__(self, args, model, criterion):
-        super().__init__(args, model, criterion)
+    def __init__(self, args, task, model, criterion):
+        super().__init__(args, task, model, criterion)

         # convert model to FP16 (but keep criterion FP32)
         self.model.half()
...