Commit b87c5366 authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#295)

Summary:
Changelog:
- `90f52a1`: Support loading subsets of the data on each worker with the `--fix-batches-to-gpus` flag. This should fix #217 and #266.
- `6eda0a9`: Update README for replicating the "Scaling Neural Machine Translation" paper
- `b14c7cf`: Fall back to the no_c10d backend for PyTorch 0.4.1 (fixes #294)
Pull Request resolved: https://github.com/pytorch/fairseq/pull/295

Differential Revision: D10121559

Pulled By: myleott

fbshipit-source-id: 41c84d0ee4cdd113544b5d3aa38ae8b23acc2c27
parent 0bc5c2e9
README.md
@@ -134,25 +134,28 @@ $ python generate.py data-bin/fconv_wmt_en_fr \
## Replicating results from "Scaling Neural Machine Translation"
-To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187):
+To replicate results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187),
+please first download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
-1. Prepare the WMT'14 En-De data with a BPE vocab of 32k:
+1. Extract the WMT'16 En-De data:
```
-$ bash prepare-wmt14en2de.sh --scaling18
-$ cd ../..
+$ TEXT=wmt16_en_de_bpe32k
+$ mkdir $TEXT
+$ tar -xzvf wmt16_en_de.tar.gz -C $TEXT
```
2. Preprocess the dataset with a joined dictionary:
```
-$ TEXT=examples/translation/wmt14_en_de
$ python preprocess.py --source-lang en --target-lang de \
-  --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
-  --destdir data-bin/wmt14_en_de_joined_dict \
+  --trainpref $TEXT/train.tok.clean.bpe.32000 \
+  --validpref $TEXT/newstest2013.tok.bpe.32000 \
+  --testpref $TEXT/newstest2014.tok.bpe.32000 \
+  --destdir data-bin/wmt16_en_de_bpe32k \
--nwordssrc 32768 --nwordstgt 32768 \
--joined-dictionary
```
3. Train a model:
```
-$ python train.py data-bin/wmt14_en_de_joined_dict \
+$ python train.py data-bin/wmt16_en_de_bpe32k \
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
--lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \
......
examples/translation/prepare-wmt14en2de.sh
@@ -43,12 +43,6 @@ if [ "$1" == "--icml17" ]; then
CORPORA[2]="training/news-commentary-v9.de-en"
fi
-# This will make the dataset comparable to the one used in "Scaling Neural Machine Translation"
-# https://arxiv.org/abs/1806.00187
-if [ "$1" == "--scaling18" ]; then
-    BPE_TOKENS=32764
-fi
if [ ! -d "$SCRIPTS" ]; then
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
exit
@@ -114,26 +108,11 @@ for l in $src $tgt; do
echo ""
done
if [ "$1" == "--scaling18" ]; then
# apply length filtering before BPE for --scaling18
perl $CLEAN $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 80
# use newstest2013 for valid
echo "pre-processing valid data..."
for l in $src $tgt; do
rm $tmp/valid.$l
cat $orig/$dev.$l | \
perl $NORM_PUNC $l | \
perl $REM_NON_PRINT_CHAR | \
perl $TOKENIZER -threads 8 -a -l $l >> $tmp/valid.$l
done
else
echo "splitting train and valid..."
for l in $src $tgt; do
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
fi
echo "splitting train and valid..."
for l in $src $tgt; do
awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l
awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l
done
TRAIN=$tmp/train.de-en
BPE_CODE=$prep/code
@@ -152,15 +131,8 @@ for L in $src $tgt; do
done
done
if [ "$1" == "--scaling18" ]; then
for L in $src $tgt; do
cp $tmp/bpe.train.$L $prep/train.$L
cp $tmp/bpe.valid.$L $prep/valid.$L
done
else
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
fi
perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250
perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250
for L in $src $tgt; do
cp $tmp/bpe.test.$L $prep/test.$L
......
fairseq/data/__init__.py
@@ -7,7 +7,8 @@
from .dictionary import Dictionary, TruncatedDictionary
from .fairseq_dataset import FairseqDataset
-from .indexed_dataset import IndexedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
+from .concat_dataset import ConcatDataset
+from .indexed_dataset import IndexedDataset, IndexedCachedDataset, IndexedInMemoryDataset, IndexedRawTextDataset
from .language_pair_dataset import LanguagePairDataset
from .monolingual_dataset import MonolingualDataset
from .token_block_dataset import TokenBlockDataset
@@ -20,11 +21,13 @@ from .iterators import (
)
__all__ = [
+'ConcatDataset',
'CountingIterator',
'Dictionary',
'EpochBatchIterator',
'FairseqDataset',
'GroupedIterator',
+'IndexedCachedDataset',
'IndexedDataset',
'IndexedInMemoryDataset',
'IndexedRawTextDataset',
......
fairseq/data/concat_dataset.py (new file):

import bisect

from . import FairseqDataset


class ConcatDataset(FairseqDataset):

    @staticmethod
    def cumsum(sequence):
        r, s = [], 0
        for e in sequence:
            l = len(e)
            r.append(l + s)
            s += l
        return r

    def __init__(self, datasets):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, 'datasets should not be an empty iterable'
        self.datasets = list(datasets)
        self.cummulative_sizes = self.cumsum(self.datasets)

    def __len__(self):
        return self.cummulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx = bisect.bisect_right(self.cummulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cummulative_sizes[dataset_idx - 1]
        return self.datasets[dataset_idx][sample_idx]

    @property
    def supports_prefetch(self):
        return all([d.supports_prefetch for d in self.datasets])

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cummulative_sizes, self.datasets):
            ds.prefetch([i - frm for i in indices if frm <= i < to])
            frm = to
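
For readers skimming the diff, here is a minimal standalone sketch of the index arithmetic `ConcatDataset` uses to map a flat index to a (dataset, sample) pair. The plain lists and the `lookup` helper below are illustrative stand-ins, not part of the commit:

```
import bisect

# Two toy "datasets"; the real class receives FairseqDataset instances.
datasets = [['a', 'b', 'c'], ['d', 'e']]

# Mirrors ConcatDataset.cumsum: running totals of dataset lengths.
cumulative = []
total = 0
for d in datasets:
    total += len(d)
    cumulative.append(total)  # -> [3, 5]

def lookup(idx):
    # bisect_right finds the first dataset whose cumulative size exceeds idx
    dataset_idx = bisect.bisect_right(cumulative, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative[dataset_idx - 1]
    return datasets[dataset_idx][sample_idx]

assert [lookup(i) for i in range(5)] == ['a', 'b', 'c', 'd', 'e']
```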
fairseq/data/data_utils.py
@@ -144,7 +144,6 @@ def batch_by_size(
sample_len = 0
sample_lens = []
ignored = []
for idx in indices:
sample_lens.append(num_tokens_fn(idx))
sample_len = max(sample_len, sample_lens[-1])
......
fairseq/data/fairseq_dataset.py
@@ -48,3 +48,10 @@ class FairseqDataset(torch.utils.data.Dataset):
"""Return an ordered list of indices. Batches will be constructed based
on this order."""
raise NotImplementedError
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    def prefetch(self, indices):
+        raise NotImplementedError
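
To illustrate how these new hooks are meant to be used, here is a toy subclass (invented for this note, not part of the commit) that opts into prefetching:

```
from fairseq.data import FairseqDataset

class ToyPrefetchDataset(FairseqDataset):
    """Illustrative only: caches samples from a slow backing store."""

    def __init__(self, backing_store):
        self.backing_store = backing_store  # list-like, slow random access
        self._cache = {}

    @property
    def supports_prefetch(self):
        return True

    def prefetch(self, indices):
        # bulk-load the requested samples before iteration begins
        self._cache = {i: self.backing_store[i] for i in indices}

    def __getitem__(self, index):
        return self._cache[index]

    def __len__(self):
        return len(self.backing_store)
```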
fairseq/data/indexed_dataset.py
@@ -106,6 +106,47 @@ class IndexedDataset(torch.utils.data.Dataset):
)
+
+class IndexedCachedDataset(IndexedDataset):
+
+    def __init__(self, path, fix_lua_indexing=False):
+        super().__init__(path, fix_lua_indexing, True)
+        self.cache = None
+        self.cache_index = {}
+
+    @property
+    def supports_prefetch(self):
+        return True
+
+    def prefetch(self, indices):
+        if all(i in self.cache_index for i in indices):
+            return
+        indices.sort()
+        total_size = 0
+        for i in indices:
+            total_size += self.data_offsets[i + 1] - self.data_offsets[i]
+        self.cache = np.empty(total_size, dtype=self.dtype)
+        ptx = 0
+        self.cache_index.clear()
+        for i in indices:
+            self.cache_index[i] = ptx
+            size = self.data_offsets[i + 1] - self.data_offsets[i]
+            a = self.cache[ptx : ptx + size]
+            self.data_file.seek(self.data_offsets[i] * self.element_size)
+            self.data_file.readinto(a)
+            ptx += size
+
+    def __getitem__(self, i):
+        self.check_index(i)
+        tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
+        a = np.empty(tensor_size, dtype=self.dtype)
+        ptx = self.cache_index[i]
+        np.copyto(a, self.cache[ptx : ptx + a.size])
+        item = torch.from_numpy(a).long()
+        if self.fix_lua_indexing:
+            item -= 1  # subtract 1 for 0-based indexing
+        return item
+
class IndexedInMemoryDataset(IndexedDataset):
"""Loader for TorchNet IndexedDataset, keeps all the data in memory"""
......
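
A schematic of the cache layout `IndexedCachedDataset.prefetch` builds: all requested records are packed into one contiguous buffer, and `cache_index` maps each sample index to its offset in that buffer. The offsets and data below are made up; only the attribute names mirror the code above:

```
import numpy as np

data_offsets = np.array([0, 2, 5, 9])  # boundaries of three variable-length items
indices = [0, 2]                       # items requested for this shard

total_size = sum(int(data_offsets[i + 1] - data_offsets[i]) for i in indices)  # 2 + 4
cache = np.empty(total_size, dtype=np.int64)
cache_index = {}

ptx = 0
for i in indices:
    size = int(data_offsets[i + 1] - data_offsets[i])
    cache_index[i] = ptx          # item i starts at offset ptx in the cache
    cache[ptx:ptx + size] = i     # stand-in for data_file.readinto(...)
    ptx += size

assert cache_index == {0: 0, 2: 2}
```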
fairseq/data/iterators.py
@@ -86,23 +86,31 @@ class EpochBatchIterator(object):
self.epoch = 0
self._cur_epoch_itr = None
self._next_epoch_itr = None
+self._supports_prefetch = (
+    hasattr(dataset, 'supports_prefetch') and dataset.supports_prefetch
+)
def __len__(self):
return len(self.frozen_batches)
-def next_epoch_itr(self, shuffle=True):
+def next_epoch_itr(self, shuffle=True, fix_batches_to_gpus=False):
"""Return a new iterator over the dataset.
Args:
shuffle (bool, optional): shuffle batches before returning the
iterator. Default: ``True``
+fix_batches_to_gpus: ensure that batches are always
+    allocated to the same shards across epochs. Requires
+    that :attr:`dataset` supports prefetching. Default:
+    ``False``
"""
if self._next_epoch_itr is not None:
self._cur_epoch_itr = self._next_epoch_itr
self._next_epoch_itr = None
else:
self.epoch += 1
-self._cur_epoch_itr = self._get_iterator_for_epoch(self.epoch, shuffle)
+self._cur_epoch_itr = self._get_iterator_for_epoch(
+    self.epoch, shuffle, fix_batches_to_gpus=fix_batches_to_gpus)
return self._cur_epoch_itr
def end_of_epoch(self):
@@ -135,19 +143,39 @@ class EpochBatchIterator(object):
if itr_pos < len(itr):
self._next_epoch_itr = itr.skip(itr_pos)
-    def _get_iterator_for_epoch(self, epoch, shuffle):
-        if shuffle:
+    def _get_iterator_for_epoch(self, epoch, shuffle, fix_batches_to_gpus=False):
+
+        def shuffle_batches(batches, seed):
            # set seed based on the seed and epoch number so that we get
            # reproducible results when resuming from checkpoints
-            with data_utils.numpy_seed(self.seed + epoch):
-                batches = list(self.frozen_batches)  # copy
+            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
-        else:
-            batches = self.frozen_batches
+            return batches
+
+        if self._supports_prefetch:
+            batches = self.frozen_batches
+
+            if shuffle and not fix_batches_to_gpus:
+                batches = shuffle_batches(list(batches), self.seed + epoch)
+
+            batches = list(ShardedIterator(
+                batches, self.num_shards, self.shard_id, fill_value=[]))
+            self.dataset.prefetch([i for s in batches for i in s])
+
+            if shuffle and fix_batches_to_gpus:
+                batches = shuffle_batches(batches, self.seed + epoch + self.shard_id)
+        else:
+            if shuffle:
+                batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch)
+            else:
+                batches = self.frozen_batches
+            batches = ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[])
+
        return CountingIterator(torch.utils.data.DataLoader(
            self.dataset,
            collate_fn=self.collate_fn,
-            batch_sampler=ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]),
+            batch_sampler=batches,
        ))
......
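
The behavioral difference between the two code paths above can be sketched in a few lines. This assumes `ShardedIterator` assigns batches round-robin (equivalent to `batches[shard_id::num_shards]`); everything else here is illustrative:

```
import numpy as np

def shuffle_batches(batches, seed):
    np.random.RandomState(seed).shuffle(batches)
    return batches

batches = list(range(8))
num_shards, shard_id, seed, epoch = 2, 0, 1, 3

# Default: shuffle globally, then shard -> a GPU's subset changes every epoch.
default_shard = shuffle_batches(list(batches), seed + epoch)[shard_id::num_shards]

# --fix-batches-to-gpus: shard the fixed order first, then shuffle locally with
# a shard-specific seed -> each GPU sees the same subset every epoch, so its
# prefetched cache stays valid across epochs.
fixed_shard = shuffle_batches(batches[shard_id::num_shards], seed + epoch + shard_id)

print(sorted(fixed_shard))  # always [0, 2, 4, 6] for shard 0, whatever the epoch
```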
fairseq/data/language_pair_dataset.py
@@ -200,3 +200,16 @@ class LanguagePairDataset(FairseqDataset):
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
+
+    def prefetch(self, indices):
+        self.src.prefetch(indices)
+        self.tgt.prefetch(indices)
+
+    @property
+    def supports_prefetch(self):
+        return (
+            hasattr(self.src, 'supports_prefetch')
+            and self.src.supports_prefetch
+            and hasattr(self.tgt, 'supports_prefetch')
+            and self.tgt.supports_prefetch
+        )
fairseq/data/monolingual_dataset.py
@@ -180,5 +180,5 @@ class MonolingualDataset(FairseqDataset):
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
-order.append(np.flip(self.sizes, 0))
+order.append(self.sizes)
return np.lexsort(order)
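
Why this one-liner matters: `np.lexsort` treats the *last* key as the primary sort key, so `self.sizes` must line up index-for-index with the dataset; flipping it paired each index with some other sample's size. A small example of the fixed ordering:

```
import numpy as np

sizes = np.array([5, 3, 5, 1])
tie_break = np.arange(4)          # the np.arange(len(self)) case above

order = np.lexsort([tie_break, sizes])  # primary key: sizes; ties broken by index
print(order)                            # [3 1 0 2] -- indices sorted by size
```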
fairseq/distributed_utils.py
@@ -26,7 +26,7 @@ C10dStatus = namedtuple('C10dStatus', ['has_c10d', 'is_default'])
if hasattr(nn.parallel, 'deprecated'):
c10d_status = C10dStatus(has_c10d=True, is_default=True)
-elif hasattr(nn.parallel, '_DistributedDataParallelC10d'):
+elif hasattr(torch.distributed, 'c10d') and hasattr(torch.distributed.c10d, 'init_process_group'):
c10d_status = C10dStatus(has_c10d=True, is_default=False)
else:
c10d_status = C10dStatus(has_c10d=False, is_default=False)
@@ -46,7 +46,8 @@ def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
-if args.ddp_backend == 'no_c10d':
+if args.ddp_backend == 'no_c10d' or not c10d_status.has_c10d:
+    args.ddp_backend = 'no_c10d'
_use_c10d[0] = False
print('| distributed init (rank {}): {}'.format(
......
fairseq/options.py
@@ -190,6 +190,10 @@ def add_distributed_training_args(parser):
help='DistributedDataParallel backend')
group.add_argument('--bucket-cap-mb', default=150, type=int, metavar='MB',
help='bucket size for reduction')
+    group.add_argument('--fix-batches-to-gpus', action='store_true',
+                       help='Don\'t shuffle batches between GPUs, this reduces overall '
+                            'randomness and may affect precision but avoids the cost of '
+                            're-reading the data')
return group
......
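
For context, a minimal sketch of how a `store_true` flag like this surfaces downstream; the parser here is a stand-in for the one fairseq assembles via `add_distributed_training_args`:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--fix-batches-to-gpus', action='store_true')

args = parser.parse_args(['--fix-batches-to-gpus'])
assert args.fix_batches_to_gpus is True  # dashes become underscores

# train.py (below) then forwards it:
#   epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
```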
fairseq/tasks/translation.py
@@ -9,12 +9,10 @@ import itertools
import numpy as np
import os
-from torch.utils.data import ConcatDataset
from fairseq import options
from fairseq.data import (
-    data_utils, Dictionary, LanguagePairDataset, IndexedInMemoryDataset,
-    IndexedRawTextDataset,
+    data_utils, Dictionary, LanguagePairDataset, ConcatDataset,
+    IndexedRawTextDataset, IndexedCachedDataset, IndexedDataset
)
from . import FairseqTask, register_task
@@ -106,15 +104,15 @@ class TranslationTask(FairseqTask):
filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
if self.args.raw_text and IndexedRawTextDataset.exists(filename):
return True
-elif not self.args.raw_text and IndexedInMemoryDataset.exists(filename):
+elif not self.args.raw_text and IndexedDataset.exists(filename):
return True
return False
def indexed_dataset(path, dictionary):
if self.args.raw_text:
return IndexedRawTextDataset(path, dictionary)
-elif IndexedInMemoryDataset.exists(path):
-    return IndexedInMemoryDataset(path, fix_lua_indexing=True)
+elif IndexedDataset.exists(path):
+    return IndexedCachedDataset(path, fix_lua_indexing=True)
return None
src_datasets = []
@@ -122,7 +120,7 @@ class TranslationTask(FairseqTask):
data_paths = self.args.data
-for data_path in data_paths:
+for dk, data_path in enumerate(data_paths):
for k in itertools.count():
split_k = split + (str(k) if k > 0 else '')
@@ -133,7 +131,7 @@ class TranslationTask(FairseqTask):
elif split_exists(split_k, tgt, src, src, data_path):
prefix = os.path.join(data_path, '{}.{}-{}.'.format(split_k, tgt, src))
else:
-if k > 0:
+if k > 0 or dk > 0:
break
else:
raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))
......
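
A schematic of the new discovery loop: with multiple data directories, each may contribute extra shards named `train`, `train1`, `train2`, and so on, and only the very first lookup is allowed to fail hard. The paths and the `existing` set below are invented for illustration; the real code checks files via `split_exists`:

```
import itertools

data_paths = ['data-bin/part0', 'data-bin/part1']
existing = {'data-bin/part0/train', 'data-bin/part0/train1', 'data-bin/part1/train'}

found = []
for dk, data_path in enumerate(data_paths):
    for k in itertools.count():
        split_k = 'train' + (str(k) if k > 0 else '')
        if '{}/{}'.format(data_path, split_k) in existing:
            found.append((data_path, split_k))
        elif k > 0 or dk > 0:
            break  # later shards/paths may legitimately be absent
        else:
            raise FileNotFoundError('Dataset not found: train')

print(found)
# [('data-bin/part0', 'train'), ('data-bin/part0', 'train1'), ('data-bin/part1', 'train')]
```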
train.py
@@ -112,7 +112,7 @@ def train(args, trainer, task, epoch_itr):
update_freq = args.update_freq[-1]
# Initialize data iterator
-itr = epoch_itr.next_epoch_itr()
+itr = epoch_itr.next_epoch_itr(fix_batches_to_gpus=args.fix_batches_to_gpus)
itr = iterators.GroupedIterator(itr, update_freq)
progress = progress_bar.build_progress_bar(
args, itr, epoch_itr.epoch, no_progress_bar='simple',
......