Commit 5f78106a authored by Myle Ott, committed by Facebook Github Bot

Default to mmap and infer dataset implementations automatically

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/751

Differential Revision: D16410989

Pulled By: myleott

fbshipit-source-id: ddbbee49756f9ff6c4487977a3f5d2259b7abafe
parent 1f96d284
@@ -5,13 +5,15 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import contextlib
import os
import numpy as np
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
import contextlib
import itertools
import os
import numpy as np
def infer_language_pair(path):
@@ -43,6 +45,51 @@ def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_be
    return res


def load_indexed_dataset(path, dictionary, dataset_impl=None, combine=False):
    """A helper function for loading indexed datasets.

    Args:
        path (str): path to indexed dataset (e.g., 'data-bin/train')
        dictionary (~fairseq.data.Dictionary): data dictionary
        dataset_impl (str, optional): which dataset implementation to use. If
            not provided, it will be inferred automatically. For legacy indexed
            data we use the 'cached' implementation by default.
        combine (bool, optional): automatically load and combine multiple
            datasets. For example, if *path* is 'data-bin/train', then we will
            combine 'data-bin/train', 'data-bin/train1', ... and return a
            single ConcatDataset instance.
    """
    from fairseq.data.concat_dataset import ConcatDataset
    import fairseq.data.indexed_dataset as indexed_dataset

    datasets = []
    for k in itertools.count():
        path_k = path + (str(k) if k > 0 else '')

        dataset_impl_k = dataset_impl
        if dataset_impl_k is None:
            dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)

        dataset = indexed_dataset.make_dataset(
            path_k,
            impl=dataset_impl_k or 'cached',
            fix_lua_indexing=True,
            dictionary=dictionary,
        )
        if dataset is None:
            break
        print('| loaded {} examples from: {}'.format(len(dataset), path_k))
        datasets.append(dataset)
        if not combine:
            break

    if len(datasets) == 0:
        return None
    elif len(datasets) == 1:
        return datasets[0]
    else:
        return ConcatDataset(datasets)
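For illustration only (not part of this diff; the data-bin path and language pair below are hypothetical), a minimal sketch of calling the new helper from user code and letting it infer the on-disk format:

    from fairseq.data import Dictionary, data_utils

    # hypothetical output of preprocess.py
    dictionary = Dictionary.load('data-bin/wmt_de_en/dict.de.txt')
    dataset = data_utils.load_indexed_dataset(
        'data-bin/wmt_de_en/train.de-en.de',  # prefix of the .idx/.bin (or .txt) files
        dictionary,
        dataset_impl=None,  # None -> inferred as 'raw', 'cached', or 'mmap'
        combine=False,      # True would also load and concatenate ...1, ...2 shards
    )
    print(len(dataset), dataset[0])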
@contextlib.contextmanager
def numpy_seed(seed, *addl_seeds):
"""Context manager which seeds the NumPy PRNG with the specified seed and
......
@@ -22,6 +22,26 @@ def __best_fitting_dtype(vocab_size=None):
    return np.int32


def get_available_dataset_impl():
    return ['raw', 'lazy', 'cached', 'mmap']


def infer_dataset_impl(path):
    if IndexedRawTextDataset.exists(path):
        return 'raw'
    elif IndexedDataset.exists(path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            if magic == IndexedDataset._HDR_MAGIC:
                return 'cached'
            elif magic == MMapIndexedDataset.Index._HDR_MAGIC[:8]:
                return 'mmap'
            else:
                return None
    else:
        return None
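A small sketch of exercising the inference directly (illustrative only; the prefix below is hypothetical and not part of this diff):

    from fairseq.data import indexed_dataset

    prefix = 'data-bin/wmt_de_en/train.de-en.de'  # hypothetical prefix from preprocess.py

    # infer_dataset_impl first checks for a raw .txt dataset, then reads the first
    # eight bytes of the .idx header to distinguish the legacy 'cached' format
    # from the new 'mmap' format
    impl = indexed_dataset.infer_dataset_impl(prefix)
    print(impl)  # e.g. 'mmap' for data binarized with the new default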
def make_builder(out_file, impl, vocab_size=None):
    if impl == 'mmap':
        return MMapIndexedDatasetBuilder(out_file, dtype=__best_fitting_dtype(vocab_size))
@@ -39,7 +59,6 @@ def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
        return MMapIndexedDataset(path)
    return None
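For context, a rough sketch of the write/read round trip these two helpers cover (file names are hypothetical; this is not code from the diff):

    import torch
    from fairseq.data import indexed_dataset

    # write two token sequences with the mmap builder; make_builder takes the
    # .bin path, and finalize() takes the matching .idx path
    builder = indexed_dataset.make_builder('/tmp/toy.bin', impl='mmap', vocab_size=100)
    builder.add_item(torch.IntTensor([4, 5, 6, 2]))
    builder.add_item(torch.IntTensor([7, 8, 2]))
    builder.finalize('/tmp/toy.idx')

    # read it back through the same prefix
    ds = indexed_dataset.make_dataset('/tmp/toy', impl='mmap')
    print(len(ds), ds[0])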
@@ -91,6 +110,7 @@ def data_file_path(prefix_path):
class IndexedDataset(FairseqDataset):
    """Loader for TorchNet IndexedDataset"""

    _HDR_MAGIC = b'TNTIDX\x00\x00'

    def __init__(self, path, fix_lua_indexing=False):
        super().__init__()
@@ -102,7 +122,7 @@ class IndexedDataset(FairseqDataset):
    def read_index(self, path):
        with open(index_file_path(path), 'rb') as f:
            magic = f.read(8)
            assert magic == b'TNTIDX\x00\x00', (
            assert magic == self._HDR_MAGIC, (
                'Index file doesn\'t match expected format. '
                'Make sure that --dataset-impl is configured properly.'
            )
@@ -151,7 +171,7 @@ class IndexedDataset(FairseqDataset):
    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )

    @property
@@ -465,7 +485,7 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
    @staticmethod
    def exists(path):
        return (
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
            os.path.exists(index_file_path(path)) and os.path.exists(data_file_path(path))
        )
......
@@ -11,6 +11,7 @@ import torch
import sys
from fairseq import utils
from fairseq.data.indexed_dataset import get_available_dataset_impl
def get_preprocessing_parser(default_task='translation'):
@@ -233,8 +234,9 @@ def add_preprocess_args(parser):
                       help="number of source words to retain")
    group.add_argument("--alignfile", metavar="ALIGN", default=None,
                       help="an alignment file (optional)")
    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
    parser.add_argument('--dataset-impl', metavar='FORMAT', default='mmap',
                        choices=get_available_dataset_impl(),
                        help='output dataset implementation')
    group.add_argument("--joined-dictionary", action="store_true",
                       help="Generate joined dictionary")
    group.add_argument("--only-source", action="store_true",
@@ -260,8 +262,9 @@ def add_dataset_args(parser, train=False, gen=False):
                       help='maximum number of sentences in a batch')
    group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
                       help='batch size will be a multiplier of this value')
    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
    parser.add_argument('--dataset-impl', metavar='FORMAT',
                        choices=get_available_dataset_impl(),
                        help='output dataset implementation')
    if train:
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
......
@@ -17,6 +17,7 @@ from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
from fairseq.data import (
    ConcatDataset,
    data_utils,
    indexed_dataset,
    TokenBlockDataset,
)
@@ -114,10 +115,7 @@ class CrossLingualLMTask(FairseqTask):
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)

            ds = indexed_dataset.make_dataset(
                path, impl=self.args.dataset_impl, fix_lua_indexing=True,
                dictionary=self.dictionary,
            )
            ds = data_utils.load_indexed_dataset(path, self.dictionary, self.args.dataset_impl)

            if ds is None:
                if k > 0:
                    break
......
@@ -14,6 +14,7 @@ import torch
from fairseq import utils
from fairseq.data import (
    ConcatDataset,
    data_utils,
    Dictionary,
    MonolingualDataset,
    TokenBlockDataset,
@@ -152,49 +153,30 @@ class LanguageModelingTask(FairseqTask):
        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        loaded_datasets = []

        paths = self.args.data.split(':')
        assert len(paths) > 0
        data_path = paths[epoch % len(paths)]
        split_path = os.path.join(data_path, split)

        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)

            ds = indexed_dataset.make_dataset(path, impl=self.args.dataset_impl,
                                              fix_lua_indexing=True, dictionary=self.dictionary)

            if ds is None:
                if k > 0:
                    break
                else:
                    raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

            loaded_datasets.append(
                TokenBlockDataset(
                    ds, ds.sizes, self.args.tokens_per_sample,
                    pad=self.dictionary.pad(), eos=self.dictionary.eos(),
                    break_mode=self.args.sample_break_mode, include_targets=True,
                )
            )

            print('| {} {} {} examples'.format(data_path, split_k, len(loaded_datasets[-1])))

            if not combine:
                break

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(split, split_path))

        if len(loaded_datasets) == 1:
            dataset = loaded_datasets[0]
            sizes = dataset.sizes
        else:
            dataset = ConcatDataset(loaded_datasets)
            sizes = np.concatenate([ds.sizes for ds in loaded_datasets])

        dataset = TokenBlockDataset(
            dataset, dataset.sizes, self.args.tokens_per_sample,
            pad=self.dictionary.pad(), eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode, include_targets=True,
        )

        add_eos_for_other_targets = self.args.sample_break_mode is not None and self.args.sample_break_mode != 'none'

        self.datasets[split] = MonolingualDataset(
            dataset, sizes, self.dictionary, self.output_dictionary,
            dataset, dataset.sizes, self.dictionary, self.output_dictionary,
            add_eos_for_other_targets=add_eos_for_other_targets, shuffle=True,
            targets=self.targets, add_bos_token=self.args.add_bos_token,
        )
......
@@ -47,10 +47,12 @@ def load_langpair_dataset(
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

        src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=src_dict))
        tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=dataset_impl,
                                                         fix_lua_indexing=True, dictionary=tgt_dict))
        src_datasets.append(
            data_utils.load_indexed_dataset(prefix + src, src_dict, dataset_impl)
        )
        tgt_datasets.append(
            data_utils.load_indexed_dataset(prefix + tgt, tgt_dict, dataset_impl)
        )

        print('| {} {} {}-{} {} examples'.format(data_path, split_k, src, tgt, len(src_datasets[-1])))
......
@@ -17,7 +17,7 @@ def get_parser():
        description='writes text from binarized file to stdout')
    # fmt: off
    parser.add_argument('--dataset-impl', help='dataset implementation',
                        choices=['raw', 'lazy', 'cached', 'mmap'], default='lazy')
                        choices=indexed_dataset.get_available_dataset_impl())
    parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
    parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
    # fmt: on
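If this parser belongs to fairseq's scripts/read_binarized.py (the file path is not shown in this view), a hypothetical invocation using only the flags visible above would look like the following; --dataset-impl can still be passed explicitly, and with no default it is presumably inferred or resolved by the surrounding script:

    python scripts/read_binarized.py \
        --input data-bin/wmt_de_en/train.de-en.de \
        --dict data-bin/wmt_de_en/dict.de.txt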
......