Commit b87c5366 authored by Myle Ott, committed by Facebook GitHub Bot

Merge internal changes (#295)

Summary:
Changelog:
- `90f52a1`: Support loading subsets of the data on each worker with the `--fix-batches-to-gpus` flag. This should fix #217 and #266.
- `6eda0a9`: Update README for replicating the "Scaling Neural Machine Translation" paper
- `b14c7cf`: Fall back to the no_c10d backend for PyTorch 0.4.1 (fixes #294)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/295

Differential Revision: D10121559

Pulled By: myleott

fbshipit-source-id: 41c84d0ee4cdd113544b5d3aa38ae8b23acc2c27
parent 0bc5c2e9
@@ -134,25 +134,28 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
 ## Replicating results from "Scaling Neural Machine Translation"
 
-To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187):
+To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187),
+please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
 
-1. Prepare the WMT'14 En-De data with a BPE vocab of 32k:
+1. Extract the WMT'16 En-De data:
 ```
-$ bash prepare-wmt14en2de.sh --scaling18
-$ cd ../..
+$ TEXT=wmt16_en_de_bpe32k
+$ mkdir $TEXT
+$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT
 ```
 
 2. Preprocess the dataset with a joined dictionary:
 ```
-$ TEXT=examples/translation/wmt14_en_de
 $ python preprocess.py --source-lang en --target-lang de \
-  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
-  --destdir data-bin/wmt14_en_de_joined_dict \
+  --trainpref $TEXT/train.tok.clean.bpe.32000 \
+  --validpref $TEXT/newstest2013.tok.bpe.32000 \
+  --testpref $TEXT/newstest2014.tok.bpe.32000 \
+  --destdir data-bin/wmt16_en_de_bpe32k \
   --nwordssrc 32768 --nwordstgt 32768 \
   --joined-dictionary
 ```
 
 3. Train a model:
 ```
-$ python train.py data-bin/wmt14_en_de_joined_dict \
+$ python train.py data-bin/wmt16_en_de_bpe32k \
   --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
   --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
   --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
...
@@ -43,12 +43,6 @@ if [ "$1" == "--icml17" ]; then
     CORPORA[2]="training/news-commentary-v9.de-en"
 fi
 
-# This will make the dataset comparable to the one used in "Scaling Neural Machine Translation"
-# https://arxiv.org/abs/1806.00187
-if [ "$1" == "--scaling18" ]; then
-    BPE_TOKENS=32764
-fi
-
 if [ ! -d "$SCRIPTS" ]; then
     echo "Please set SCRIPTS variable correctly to point to Moses scripts."
     exit
@@ -114,26 +108,11 @@ for l in $src $tgt; do
     echo ""
 done
 
-if [ "$1" == "--scaling18" ]; then
-    # apply length filtering before BPE for --scaling18
-    perl $CLEAN $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 80
-
-    # use newstest2013 for valid
-    echo "pre-processing valid data..."
-    for l in $src $tgt; do
-        rm $tmp/valid.$l
-        cat $orig/$dev.$l | \
-            perl $NORM_PUNC $l | \
-            perl $REM_NON_PRINT_CHAR | \
-            perl $TOKENIZER -threads 8 -a -l $l >> $tmp/valid.$l
-    done
-else
-    echo "splitting train and valid..."
-    for l in $src $tgt; do
-        awk '{if (NR%100 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
-        awk '{if (NR%100 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
-    done
-fi
+echo "splitting train and valid..."
+for l in $src $tgt; do
+    awk '{if (NR%100 == 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
+    awk '{if (NR%100 != 0)  print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
+done
 
 TRAIN=$tmp/train.de-en
 BPE_CODE=$prep/code
@@ -152,15 +131,8 @@ for L in $src $tgt; do
     done
 done
 
-if [ "$1" == "--scaling18" ]; then
-    for L in $src $tgt; do
-        cp $tmp/bpe.train.$L $prep/train.$L
-        cp $tmp/bpe.valid.$L $prep/valid.$L
-    done
-else
-    perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
-    perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
-fi
+perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
+perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
 
 for L in $src $tgt; do
     cp $tmp/bpe.test.$L $prep/test.$L
...
@@ -7,7 +7,8 @@
 from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
-from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
+from .concat_dataset import ConcatDataset
+from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
 from .language_pair_dataset import LanguagePairDataset
 from .monolingual_dataset import MonolingualDataset
 from .token_block_dataset import TokenBlockDataset
@@ -20,11 +21,13 @@ from .iterators import (
 )
 
 __all__ = [
+    'ConcatDataset',
     'CountingIterator',
     'Dictionary',
     'EpochBatchIterator',
     'FairseqDataset',
     'GroupedIterator',
+    'IndexedCachedDataset',
     'IndexedDataset',
     'IndexedInMemoryDataset',
     'IndexedRawTextDataset',
...
import bisect

from . import FairseqDataset


class ConcatDataset(FairseqDataset):

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cummulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cummulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def supports_prefetch(self):
        return all([d.supports_prefetch for d in self.datasets])

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cummulative_sizes, self.datasets):
            ds.prefetch([i - frm for i in indices if frm <= i < to])
            frm = to
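As a quick illustration of the lookup above, here is a minimal standalone sketch of the same cumulative-size + `bisect_right` mapping, using plain Python lists as stand-ins for datasets (the helper names are ours for the example, not fairseq APIs):

```
import bisect

def make_index(datasets):
    """Cumulative sizes, e.g. lengths [3, 2] -> [3, 5]."""
    sizes, total = [], 0
    for d in datasets:
        total += len(d)
        sizes.append(total)
    return sizes

def lookup(cumulative_sizes, datasets, idx):
    """Map a global index to a sample, the same way ConcatDataset.__getitem__ does."""
    dataset_idx = bisect.bisect_right(cumulative_sizes, idx)
    offset = cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
    return datasets[dataset_idx][idx - offset]

# Toy "datasets": plain lists standing in for FairseqDataset instances.
a = ['a0', 'a1', 'a2']
b = ['b0', 'b1']
cum = make_index([a, b])               # [3, 5]
assert lookup(cum, [a, b], 0) == 'a0'
assert lookup(cum, [a, b], 3) == 'b0'  # global index 3 falls in the second dataset
assert lookup(cum, [a, b], 4) == 'b1'
```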
@@ -144,7 +144,6 @@ def batch_by_size(
     sample_len = 0
     sample_lens = []
-    ignored = []
     for idx in indices:
         sample_lens.append(num_tokens_fn(idx))
         sample_len = max(sample_len, sample_lens[-1])
...
@@ -48,3 +48,10 @@ class FairseqDataset(torch.utils.data.Dataset):
         """Return an ordered list of indices. Batches will be constructed based
         on this order."""
         raise NotImplementedError
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    def prefetch(self, indices):
+        raise NotImplementedError
@@ -106,6 +106,47 @@ class IndexedDataset(torch.utils.data.Dataset):
         )
 
 
+class IndexedCachedDataset(IndexedDataset):
+
+    def __init__(self, path, fix_lua_indexing=False):
+        super().__init__(path, fix_lua_indexing, True)
+        self.cache = None
+        self.cache_index = {}
+
+    @property
+    def supports_prefetch(self):
+        return True
+
+    def prefetch(self, indices):
+        if all(i in self.cache_index for i in indices):
+            return
+        indices.sort()
+        total_size = 0
+        for i in indices:
+            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
+        self.cache = np.empty(total_size, dtype=self.dtype)
+        ptx = 0
+        self.cache_index.clear()
+        for i in indices:
+            self.cache_index[i] = ptx
+            size = self.data_offsets[i + 1] - self.data_offsets[i]
+            a = self.cache[ptx : ptx + size]
+            self.data_file.seek(self.data_offsets[i] * self.element_size)
+            self.data_file.readinto(a)
+            ptx += size
+
+    def __getitem__(self, i):
+        self.check_index(i)
+        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+        a = np.empty(tensor_size, dtype=self.dtype)
+        ptx = self.cache_index[i]
+        np.copyto(a, self.cache[ptx : ptx + a.size])
+        item = torch.from_numpy(a).long()
+        if self.fix_lua_indexing:
+            item -= 1  # subtract 1 for 0-based indexing
+        return item
+
+
 class IndexedInMemoryDataset(IndexedDataset):
     """Loader for TorchNet IndexedDataset, keeps all the data in memory"""
...
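The caching scheme in `IndexedCachedDataset` — copy only the requested items into one contiguous buffer and remember each item's start offset — can be sketched standalone with NumPy in place of the on-disk `.bin` file; the array and helper names below are illustrative, not fairseq APIs:

```
import numpy as np

# Toy "on-disk" data: 5 variable-length items flattened into one array,
# with offsets[i]..offsets[i+1] delimiting item i (mirrors data_offsets).
data = np.arange(20, dtype=np.int64)
offsets = np.array([0, 3, 7, 10, 16, 20])

def pack_cache(indices):
    """Pack the requested items into one contiguous buffer (mirrors prefetch())."""
    indices = sorted(indices)
    total = sum(offsets[i + 1] - offsets[i] for i in indices)
    cache = np.empty(total, dtype=data.dtype)
    cache_index, ptx = {}, 0
    for i in indices:
        size = offsets[i + 1] - offsets[i]
        cache[ptx:ptx + size] = data[offsets[i]:offsets[i + 1]]
        cache_index[i] = ptx
        ptx += size
    return cache, cache_index

def get_item(cache, cache_index, i):
    """Read item i back out of the packed cache (mirrors __getitem__)."""
    size = offsets[i + 1] - offsets[i]
    ptx = cache_index[i]
    return cache[ptx:ptx + size]

cache, idx = pack_cache([3, 1])
assert np.array_equal(get_item(cache, idx, 1), data[3:7])
assert np.array_equal(get_item(cache, idx, 3), data[10:16])
```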
@@ -86,23 +86,31 @@ class EpochBatchIterator(object):
         self.epoch = 0
         self._cur_epoch_itr = None
         self._next_epoch_itr = None
+        self._supports_prefetch = (
+            hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
+        )
 
     def __len__(self):
         return len(self.frozen_batches)
 
-    def next_epoch_itr(self, shuffle=True):
+    def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
         """Return a new iterator over the dataset.
 
         Args:
             shuffle (bool, optional): shuffle batches before returning the
                 iterator. Default: ``True``
+            fix_batches_to_gpus: ensure that batches are always
+                allocated to the same shards across epochs. Requires
+                that :attr:`dataset` supports prefetching. Default:
+                ``False``
         """
         if self._next_epoch_itr is not None:
             self._cur_epoch_itr = self._next_epoch_itr
             self._next_epoch_itr = None
         else:
             self.epoch += 1
-            self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
+            self._cur_epoch_itr = self._get_iterator_for_epoch(
+                self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
         return self._cur_epoch_itr
 
     def end_of_epoch(self):
@@ -135,19 +143,39 @@ class EpochBatchIterator(object):
         if itr_pos < len(itr):
             self._next_epoch_itr = itr.skip(itr_pos)
 
-    def _get_iterator_for_epoch(self, epoch, shuffle):
-        if shuffle:
+    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False):
+        def shuffle_batches(batches, seed):
             # set seed based on the seed and epoch number so that we get
             # reproducible results when resuming from checkpoints
-            with data_utils.numpy_seed(self.seed + epoch):
-                batches = list(self.frozen_batches)  # copy
+            with data_utils.numpy_seed(seed):
                 np.random.shuffle(batches)
-        else:
+            return batches
+
+        if self._supports_prefetch:
             batches = self.frozen_batches
+
+            if shuffle and not fix_batches_to_gpus:
+                batches = shuffle_batches(list(batches), self.seed + epoch)
+
+            batches = list(ShardedIterator(
+                batches, self.num_shards, self.shard_id, fill_value=[]))
+
+            self.dataset.prefetch([i for s in batches for i in s])
+
+            if shuffle and fix_batches_to_gpus:
+                batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
+        else:
+            if shuffle:
+                batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
+            else:
+                batches = self.frozen_batches
+            batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+
         return CountingIterator(torch.utils.data.DataLoader(
             self.dataset,
             collate_fn=self.collate_fn,
-            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
+            batch_sampler=batches,
         ))
...
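The intent of `fix_batches_to_gpus` — shard first so each worker keeps the same subset of batches every epoch, then shuffle only within the shard using a shard-dependent seed — can be sketched with a toy batch list; the round-robin slice below stands in for `ShardedIterator`, and the rest is illustrative rather than fairseq code:

```
import numpy as np

batches = [[i] for i in range(8)]   # toy batches of sentence indices
num_shards, shard_id = 2, 0         # pretend this worker is GPU 0 of 2

def epoch_batches(epoch, fix_batches_to_gpus, seed=1):
    if fix_batches_to_gpus:
        # Shard first: this worker always keeps the same subset of batches,
        # then shuffle only within the shard (seed also depends on shard_id).
        shard = batches[shard_id::num_shards]
        np.random.RandomState(seed + epoch + shard_id).shuffle(shard)
        return shard
    # Default: shuffle globally first, so the shard's contents change each epoch.
    shuffled = list(batches)
    np.random.RandomState(seed + epoch).shuffle(shuffled)
    return shuffled[shard_id::num_shards]

# With the flag, the *set* of batches seen by this worker is identical across
# epochs (only their order changes), so per-worker prefetching stays useful.
fixed = [sorted(map(tuple, epoch_batches(e, True))) for e in (1, 2)]
assert fixed[0] == fixed[1]
```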
@@ -200,3 +200,16 @@ class LanguagePairDataset(FairseqDataset):
         if self.tgt_sizes is not None:
             indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
         return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
+
+    def prefetch(self, indices):
+        self.src.prefetch(indices)
+        self.tgt.prefetch(indices)
+
+    @property
+    def supports_prefetch(self):
+        return (
+            hasattr(self.src, 'supports_prefetch')
+            and self.src.supports_prefetch
+            and hasattr(self.tgt, 'supports_prefetch')
+            and self.tgt.supports_prefetch
+        )
@@ -180,5 +180,5 @@ class MonolingualDataset(FairseqDataset):
             order = [np.random.permutation(len(self))]
         else:
             order = [np.arange(len(self))]
-        order.append(np.flip(self.sizes, 0))
+        order.append(self.sizes)
         return np.lexsort(order)
@@ -26,7 +26,7 @@ C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
 if hasattr(nn.parallel, 'deprecated'):
     c10d_status = C10dStatus(has_c10d=True, is_default=True)
-elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
+elif hasattr(torch.distributed, 'c10d') and hasattr(torch.distributed.c10d, 'init_process_group'):
     c10d_status = C10dStatus(has_c10d=True, is_default=False)
 else:
     c10d_status = C10dStatus(has_c10d=False, is_default=False)
@@ -46,7 +46,8 @@ def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')
 
-    if args.ddp_backend == 'no_c10d':
+    if args.ddp_backend == 'no_c10d' or not c10d_status.has_c10d:
+        args.ddp_backend = 'no_c10d'
         _use_c10d[0] = False
 
     print('| distributed init (rank {}): {}'.format(
...
@@ -190,6 +190,10 @@ def add_distributed_training_args(parser):
                        help='DistributedDataParallel backend')
     group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
                        help='bucket size for reduction')
+    group.add_argument('--fix-batches-to-gpus', action='store_true',
+                       help='Don\'t shuffle batches between GPUs, this reduces overall '
+                            'randomness and may affect precision but avoids the cost of'
+                            're-reading the data')
     return group
...
@@ -9,12 +9,10 @@ import itertools
 import numpy as np
 import os
 
-from torch.utils.data import ConcatDataset
-
 from fairseq import options
 from fairseq.data import (
-    data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
-    IndexedRawTextDataset,
+    data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
+    IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
 )
 
 from . import FairseqTask, register_task
@@ -106,15 +104,15 @@ class TranslationTask(FairseqTask):
             filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
             if self.args.raw_text and IndexedRawTextDataset.exists(filename):
                 return True
-            elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
+            elif not self.args.raw_text and IndexedDataset.exists(filename):
                 return True
             return False
 
         def indexed_dataset(path, dictionary):
             if self.args.raw_text:
                 return IndexedRawTextDataset(path, dictionary)
-            elif IndexedInMemoryDataset.exists(path):
-                return IndexedInMemoryDataset(path, fix_lua_indexing=True)
+            elif IndexedDataset.exists(path):
+                return IndexedCachedDataset(path, fix_lua_indexing=True)
             return None
 
         src_datasets = []
@@ -122,7 +120,7 @@ class TranslationTask(FairseqTask):
         data_paths = self.args.data
 
-        for data_path in data_paths:
+        for dk, data_path in enumerate(data_paths):
             for k in itertools.count():
                 split_k = split + (str(k) if k > 0 else '')
@@ -133,7 +131,7 @@ class TranslationTask(FairseqTask):
             elif split_exists(split_k, tgt, src, src, data_path):
                 prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
             else:
-                if k > 0:
+                if k > 0 or dk > 0:
                     break
                 else:
                     raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
...
@@ -112,7 +112,7 @@ def train(args, trainer, task, epoch_itr):
     update_freq = args.update_freq[-1]
 
     # Initialize data iterator
-    itr = epoch_itr.next_epoch_itr()
+    itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
     itr = iterators.GroupedIterator(itr, update_freq)
     progress = progress_bar.build_progress_bar(
         args, itr, epoch_itr.epoch, no_progress_bar='simple',
...