"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d9ee3879b0ae5a6d1a4eff49fd5febaaa4a03a0a"
Commit a1c997bd authored by Davide Caroselli, committed by Facebook Github Bot

Memory-Mapped IndexedDataset implementation (#589)

Summary:
Following discussion in https://github.com/pytorch/fairseq/issues/574:

- Implemented MMapIndexedDataset and MMapIndexedDatasetBuilder, compatible with IndexedDataset/IndexedDatasetBuilder
- Updated scripts/read_binarized.py to support the new MMapIndexedDataset
- Replaced the options '--raw-text' and '--lazy-load' with '--dataset-impl', and moved the option definition from custom task args to the more high-level options.add_dataset_args() (a more appropriate place)
- Also implemented utility functions in indexed_dataset: make_dataset(), dataset_exists()
Pull Request resolved: https://github.com/pytorch/fairseq/pull/589

Differential Revision: D14597128

Pulled By: myleott

fbshipit-source-id: 4e92d99920cbaa52cfe5a0f1f5d9ae5c92d4268e
parent e4edf27a
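Before the diff itself, a minimal sketch of how the new helpers fit together (the prefix path and token IDs below are purely illustrative; real .bin/.idx pairs come from preprocess.py):

```python
# Minimal sketch of the new indexed_dataset helpers (illustrative paths and token IDs).
import torch
from fairseq.data import indexed_dataset

prefix = '/tmp/train.de-en.de'  # hypothetical dataset prefix

# Write the .bin data file item by item, then write the .idx index file.
builder = indexed_dataset.make_builder(prefix + '.bin', impl='mmap')
for token_ids in [[4, 5, 6, 2], [7, 8, 2]]:
    builder.add_item(torch.IntTensor(token_ids))
builder.finalize(prefix + '.idx')

# Memory-map the dataset back; items come out as LongTensors.
ds = indexed_dataset.make_dataset(prefix, impl='mmap')
print(len(ds), ds[0])
```

make_dataset() returns None when no dataset of the requested implementation exists at the path, which is what the updated tasks use as their "split not found" signal.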
@@ -9,7 +9,7 @@ from .dictionary import Dictionary, TruncatedDictionary
 from .fairseq_dataset import FairseqDataset
 from .backtranslation_dataset import BacktranslationDataset
 from .concat_dataset import ConcatDataset
-from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset
+from .indexed_dataset import IndexedCachedDataset, IndexedDataset, IndexedRawTextDataset, MMapIndexedDataset
 from .language_pair_dataset import LanguagePairDataset
 from .lm_context_window_dataset import LMContextWindowDataset
 from .monolingual_dataset import MonolingualDataset
@@ -39,6 +39,7 @@ __all__ = [
     'IndexedRawTextDataset',
     'LanguagePairDataset',
     'LMContextWindowDataset',
+    'MMapIndexedDataset',
     'MonolingualDataset',
     'NoisingDataset',
     'RoundRobinZipDatasets',
...
@@ -4,14 +4,44 @@
 # This source code is licensed under the license found in the LICENSE file in
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.

 import os
+import shutil
 import struct

 import numpy as np
 import torch


+def make_builder(out_file, impl):
+    if impl == 'mmap':
+        return MMapIndexedDatasetBuilder(out_file)
+    else:
+        return IndexedDatasetBuilder(out_file)
+
+
+def make_dataset(path, impl, fix_lua_indexing=False, dictionary=None):
+    if impl == 'raw' and IndexedRawTextDataset.exists(path):
+        assert dictionary is not None
+        return IndexedRawTextDataset(path, dictionary)
+    elif impl == 'lazy' and IndexedDataset.exists(path):
+        return IndexedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'cached' and IndexedDataset.exists(path):
+        return IndexedCachedDataset(path, fix_lua_indexing=fix_lua_indexing)
+    elif impl == 'mmap' and MMapIndexedDataset.exists(path):
+        return MMapIndexedDataset(path)
+    return None
+
+
+def dataset_exists(path, impl):
+    if impl == 'raw':
+        return IndexedRawTextDataset.exists(path)
+    elif impl == 'mmap':
+        return MMapIndexedDataset.exists(path)
+    else:
+        return IndexedDataset.exists(path)
+
+
 def read_longs(f, n):
     a = np.empty(n, dtype=np.int64)
     f.readinto(a)
@@ -37,6 +67,7 @@ def code(dtype):
     for k in dtypes.keys():
         if dtypes[k] == dtype:
             return k
+    raise ValueError(dtype)


 def index_file_path(prefix_path):
@@ -100,8 +131,8 @@ class IndexedDataset(torch.utils.data.Dataset):
     @staticmethod
     def exists(path):
         return (
             os.path.exists(index_file_path(path)) and
             os.path.exists(data_file_path(path))
         )

     @property
@@ -135,7 +166,7 @@ class IndexedCachedDataset(IndexedDataset):
         for i in indices:
             self.cache_index[i] = ptx
             size = self.data_offsets[i + 1] - self.data_offsets[i]
-            a = self.cache[ptx : ptx + size]
+            a = self.cache[ptx: ptx + size]
             self.data_file.seek(self.data_offsets[i] * self.element_size)
             self.data_file.readinto(a)
             ptx += size
@@ -149,7 +180,7 @@
         tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
         a = np.empty(tensor_size, dtype=self.dtype)
         ptx = self.cache_index[i]
-        np.copyto(a, self.cache[ptx : ptx + a.size])
+        np.copyto(a, self.cache[ptx: ptx + a.size])
         item = torch.from_numpy(a).long()
         if self.fix_lua_indexing:
             item -= 1  # subtract 1 for 0-based indexing
@@ -262,3 +293,169 @@ class IndexedDatasetBuilder(object):
         write_longs(index, self.data_offsets)
         write_longs(index, self.sizes)
         index.close()
+
+
+def _warmup_mmap_file(path):
+    with open(path, 'rb') as stream:
+        while stream.read(100 * 1024 * 1024):
+            pass
+
+
+class MMapIndexedDataset(torch.utils.data.Dataset):
+    class Index(object):
+        _HDR_MAGIC = b'MMIDIDX\x00\x00'
+
+        @classmethod
+        def writer(cls, path, dtype):
+            class _Writer(object):
+                def __enter__(self):
+                    self._file = open(path, 'wb')
+
+                    self._file.write(cls._HDR_MAGIC)
+                    self._file.write(struct.pack('<Q', 1))
+                    self._file.write(struct.pack('<B', code(dtype)))
+
+                    return self
+
+                @staticmethod
+                def _get_pointers(sizes):
+                    dtype_size = dtype().itemsize
+                    address = 0
+                    pointers = []
+
+                    for size in sizes:
+                        pointers.append(address)
+                        address += size * dtype_size
+
+                    return pointers
+
+                def write(self, sizes):
+                    pointers = self._get_pointers(sizes)
+
+                    self._file.write(struct.pack('<Q', len(sizes)))
+
+                    sizes = np.array(sizes, dtype=np.int32)
+                    self._file.write(sizes.tobytes(order='C'))
+                    del sizes
+
+                    pointers = np.array(pointers, dtype=np.int64)
+                    self._file.write(pointers.tobytes(order='C'))
+                    del pointers
+
+                def __exit__(self, exc_type, exc_val, exc_tb):
+                    self._file.close()
+
+            return _Writer()
+
+        def __init__(self, path):
+            with open(path, 'rb') as stream:
+                magic_test = stream.read(9)
+                assert self._HDR_MAGIC == magic_test
+                version = struct.unpack('<Q', stream.read(8))
+                assert (1,) == version
+
+                dtype_code, = struct.unpack('<B', stream.read(1))
+                self._dtype = dtypes[dtype_code]
+                self._dtype_size = self._dtype().itemsize
+
+                self._len = struct.unpack('<Q', stream.read(8))[0]
+                offset = stream.tell()
+
+            _warmup_mmap_file(path)
+
+            self._bin_buffer = memoryview(np.memmap(path, mode='r', order='C'))
+            self._sizes = np.frombuffer(self._bin_buffer, dtype=np.int32, count=self._len, offset=offset)
+            self._pointers = np.frombuffer(self._bin_buffer, dtype=np.int64, count=self._len,
+                                           offset=offset + self._sizes.nbytes)

+        @property
+        def dtype(self):
+            return self._dtype
+
+        @property
+        def sizes(self):
+            return self._sizes
+
+        def __getitem__(self, i):
+            return self._pointers[i], self._sizes[i]
+
+        def __len__(self):
+            return self._len
+
+    def __init__(self, path):
+        super().__init__()
+
+        self._path = None
+        self._index = None
+        self._bin_buffer = None
+
+        self._do_init(path)
+
+    def __getstate__(self):
+        return self._path
+
+    def __setstate__(self, state):
+        self._do_init(state)
+
+    def _do_init(self, path):
+        self._path = path
+        self._index = self.Index(index_file_path(self._path))
+
+        _warmup_mmap_file(data_file_path(self._path))
+        self._bin_buffer = memoryview(np.memmap(data_file_path(self._path), mode='r', order='C'))
+
+    def __len__(self):
+        return len(self._index)
+
+    def __getitem__(self, i):
+        ptr, size = self._index[i]
+        tensor = torch.from_numpy(np.frombuffer(self._bin_buffer, dtype=self._index.dtype, count=size, offset=ptr))
+        if tensor.dtype == torch.int64:
+            return tensor
+        else:
+            return tensor.long()
+
+    @property
+    def sizes(self):
+        return self._index.sizes
+
+    @property
+    def supports_prefetch(self):
+        return False
+
+    @staticmethod
+    def exists(path):
+        return (
+            os.path.exists(index_file_path(path)) and
+            os.path.exists(data_file_path(path))
+        )
+
+
+class MMapIndexedDatasetBuilder(object):
+    def __init__(self, out_file, dtype=np.int64):
+        self._data_file = open(out_file, 'wb')
+        self._dtype = dtype
+        self._sizes = []
+
+    def add_item(self, tensor):
+        np_array = np.array(tensor.numpy(), dtype=self._dtype)
+        self._data_file.write(np_array.tobytes(order='C'))
+        self._sizes.append(np_array.size)
+
+    def merge_file_(self, another_file):
+        # Concatenate index
+        index = MMapIndexedDataset.Index(index_file_path(another_file))
+        assert index.dtype == self._dtype
+
+        for size in index.sizes:
+            self._sizes.append(size)
+
+        # Concatenate data
+        with open(data_file_path(another_file), 'rb') as f:
+            shutil.copyfileobj(f, self._data_file)
+
+    def finalize(self, index_file):
+        self._data_file.close()
+
+        with MMapIndexedDataset.Index.writer(index_file, self._dtype) as index:
+            index.write(self._sizes)
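For reference, the .idx file written by MMapIndexedDataset.Index.writer() consists of the magic bytes b'MMIDIDX\x00\x00', a little-endian uint64 version (always 1), a uint8 dtype code (see the module's dtypes map), a uint64 item count, and then the int32 sizes array followed by the int64 byte offsets into the .bin file. A small standalone sketch that parses such an index header (the path is hypothetical):

```python
# Standalone sketch: parse an .idx file produced by MMapIndexedDataset.Index.writer().
import struct
import numpy as np

def read_mmap_index(path):
    with open(path, 'rb') as stream:
        assert stream.read(9) == b'MMIDIDX\x00\x00'         # magic
        version, = struct.unpack('<Q', stream.read(8))      # index format version
        assert version == 1
        dtype_code, = struct.unpack('<B', stream.read(1))   # dtype code of the .bin items
        length, = struct.unpack('<Q', stream.read(8))       # number of items
        sizes = np.frombuffer(stream.read(4 * length), dtype=np.int32)     # item lengths
        pointers = np.frombuffer(stream.read(8 * length), dtype=np.int64)  # byte offsets into .bin
    return dtype_code, sizes, pointers

# dtype_code, sizes, pointers = read_mmap_index('/tmp/train.de-en.de.idx')  # hypothetical path
```

The .bin file is just the concatenated item arrays, so __getitem__ can wrap each item with np.frombuffer at the stored offset without copying the data.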
@@ -198,9 +198,8 @@ def add_preprocess_args(parser):
                        help="number of source words to retain")
     group.add_argument("--alignfile", metavar="ALIGN", default=None,
                        help="an alignment file (optional)")
-    group.add_argument("--output-format", metavar="FORMAT", default="binary",
-                       choices=["binary", "raw"],
-                       help="output format (optional)")
+    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
+                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
     group.add_argument("--joined-dictionary", action="store_true",
                        help="Generate joined dictionary")
     group.add_argument("--only-source", action="store_true",
@@ -226,6 +225,8 @@ def add_dataset_args(parser, train=False, gen=False):
                        help='maximum number of sentences in a batch')
     group.add_argument('--required-batch-size-multiple', default=8, type=int, metavar='N',
                        help='batch size will be a multiplier of this value')
+    parser.add_argument('--dataset-impl', metavar="FORMAT", help='output dataset implementation',
+                        choices=['raw', 'lazy', 'cached', 'mmap'], default='cached')
     if train:
         group.add_argument('--train-subset', default='train', metavar='SPLIT',
                            choices=['train', 'valid', 'test'],
...
@@ -17,9 +17,7 @@ from fairseq.data.masked_lm_dictionary import MaskedLMDictionary
 from fairseq.data import (
     ConcatDataset,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
+    indexed_dataset,
     TokenBlockDataset,
 )
@@ -118,14 +116,11 @@ class CrossLingualLMTask(FairseqTask):
             split_k = split + (str(k) if k > 0 else '')
             path = os.path.join(data_path, split_k)

-            if self.args.raw_text and IndexedRawTextDataset.exists(path):
-                ds = IndexedRawTextDataset(path, self.dictionary)
-            elif not self.args.raw_text and IndexedDataset.exists(path):
-                if self.args.lazy_load:
-                    ds = IndexedDataset(path, fix_lua_indexing=True)
-                else:
-                    ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-            else:
+            ds = indexed_dataset.make_dataset(
+                path, impl=self.args.dataset_impl, fix_lua_indexing=True,
+                dictionary=self.dictionary,
+            )
+            if ds is None:
                 if k > 0:
                     break
                 else:
...
@@ -8,21 +8,19 @@
 import itertools
 import os

-import torch
 import numpy as np
+import torch

+from fairseq import utils
 from fairseq.data import (
     ConcatDataset,
     Dictionary,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
     MonolingualDataset,
     TokenBlockDataset,
     TransformEosDataset,
     TruncatedDictionary,
+    indexed_dataset
 )

 from . import FairseqTask, register_task
@@ -101,6 +99,13 @@ class LanguageModelingTask(FairseqTask):
        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
+        if getattr(args, 'raw_text', False):
+            utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
+            args.dataset_impl = 'raw'
+        elif getattr(args, 'lazy_load', False):
+            utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
+            args.dataset_impl = 'lazy'
+
        dictionary = None
        output_dictionary = None
        if args.data:
@@ -154,15 +159,10 @@ class LanguageModelingTask(FairseqTask):
        for k in itertools.count():
            split_k = split + (str(k) if k > 0 else '')
            path = os.path.join(data_path, split_k)

-           if self.args.raw_text and IndexedRawTextDataset.exists(path):
-               ds = IndexedRawTextDataset(path, self.dictionary)
-           elif not self.args.raw_text and IndexedDataset.exists(path):
-               if self.args.lazy_load:
-                   ds = IndexedDataset(path, fix_lua_indexing=True)
-               else:
-                   ds = IndexedCachedDataset(path, fix_lua_indexing=True)
-           else:
+           ds = indexed_dataset.make_dataset(path, impl=self.args.dataset_impl,
+                                             fix_lua_indexing=True, dictionary=self.dictionary)
+           if ds is None:
                if k > 0:
                    break
                else:
...
@@ -11,17 +11,15 @@ import os

 import torch

-from fairseq import options
+from fairseq import options, utils
 from fairseq.data import (
     BacktranslationDataset,
     Dictionary,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
     LanguagePairDataset,
     NoisingDataset,
     RoundRobinZipDatasets,
     TransformEosLangPairDataset,
+    indexed_dataset,
 )
 from fairseq.models import FairseqMultiModel
@@ -78,7 +76,7 @@ class MultilingualTranslationTask(FairseqTask):
                            help='target language (only needed for inference)')
        parser.add_argument('--lazy-load', action='store_true',
                            help='load the dataset lazily')
-       parser.add_argument('--raw-text', action='store_true',
+       parser.add_argument('--raw-text', default=False, action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left (default: True)')
@@ -122,6 +120,12 @@ class MultilingualTranslationTask(FairseqTask):
    def prepare(cls, args, **kargs):
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
+       if getattr(args, 'raw_text', False):
+           utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
+           args.dataset_impl = 'raw'
+       elif getattr(args, 'lazy_load', False):
+           utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
+           args.dataset_impl = 'lazy'
        args.lang_pairs = args.lang_pairs.split(',')
        sorted_langs = sorted(list({x for lang_pair in args.lang_pairs for x in lang_pair.split('-')}))
@@ -196,21 +200,7 @@ class MultilingualTranslationTask(FairseqTask):
        def split_exists(split, src, tgt, lang):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
-           if self.args.raw_text and IndexedRawTextDataset.exists(filename):
-               return True
-           elif not self.args.raw_text and IndexedDataset.exists(filename):
-               return True
-           return False
-
-       def indexed_dataset(path, dictionary):
-           if self.args.raw_text:
-               return IndexedRawTextDataset(path, dictionary)
-           elif IndexedDataset.exists(path):
-               if self.args.lazy_load:
-                   return IndexedDataset(path, fix_lua_indexing=True)
-               else:
-                   return IndexedCachedDataset(path, fix_lua_indexing=True)
-           return None
+           return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        src_datasets, tgt_datasets = {}, {}
        for lang_pair in self.args.lang_pairs:
@@ -221,8 +211,10 @@ class MultilingualTranslationTask(FairseqTask):
                prefix = os.path.join(data_path, '{}.{}-{}.'.format(split, tgt, src))
            else:
                continue
-           src_datasets[lang_pair] = indexed_dataset(prefix + src, self.dicts[src])
-           tgt_datasets[lang_pair] = indexed_dataset(prefix + tgt, self.dicts[tgt])
+           src_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
+                                                                  fix_lua_indexing=True, dictionary=self.dicts[src])
+           tgt_datasets[lang_pair] = indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
+                                                                  fix_lua_indexing=True, dictionary=self.dicts[tgt])
            print('| {} {} {} examples'.format(data_path, split, len(src_datasets[lang_pair])))

        if len(src_datasets) == 0:
...
@@ -8,15 +8,13 @@
 import itertools
 import os

-from fairseq import options
+from fairseq import options, utils
 from fairseq.data import (
     ConcatDataset,
     data_utils,
     Dictionary,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
     LanguagePairDataset,
+    indexed_dataset
 )

 from . import FairseqTask, register_task
@@ -56,7 +54,7 @@ class TranslationTask(FairseqTask):
                            help='target language')
        parser.add_argument('--lazy-load', action='store_true',
                            help='load the dataset lazily')
-       parser.add_argument('--raw-text', action='store_true',
+       parser.add_argument('--raw-text', default=False, action='store_true',
                            help='load raw text dataset')
        parser.add_argument('--left-pad-source', default='True', type=str, metavar='BOOL',
                            help='pad the source on the left')
@@ -84,6 +82,12 @@
        """
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)
+       if getattr(args, 'raw_text', False):
+           utils.deprecation_warning('--raw-text is deprecated, please use --dataset-impl=raw')
+           args.dataset_impl = 'raw'
+       elif getattr(args, 'lazy_load', False):
+           utils.deprecation_warning('--lazy-load is deprecated, please use --dataset-impl=lazy')
+           args.dataset_impl = 'lazy'

        paths = args.data.split(':')
        assert len(paths) > 0
@@ -116,21 +120,7 @@ class TranslationTask(FairseqTask):
        def split_exists(split, src, tgt, lang, data_path):
            filename = os.path.join(data_path, '{}.{}-{}.{}'.format(split, src, tgt, lang))
-           if self.args.raw_text and IndexedRawTextDataset.exists(filename):
-               return True
-           elif not self.args.raw_text and IndexedDataset.exists(filename):
-               return True
-           return False
-
-       def indexed_dataset(path, dictionary):
-           if self.args.raw_text:
-               return IndexedRawTextDataset(path, dictionary)
-           elif IndexedDataset.exists(path):
-               if self.args.lazy_load:
-                   return IndexedDataset(path, fix_lua_indexing=True)
-               else:
-                   return IndexedCachedDataset(path, fix_lua_indexing=True)
-           return None
+           return indexed_dataset.dataset_exists(filename, impl=self.args.dataset_impl)

        src_datasets = []
        tgt_datasets = []
@@ -150,8 +140,10 @@ class TranslationTask(FairseqTask):
            else:
                raise FileNotFoundError('Dataset not found: {} ({})'.format(split, data_path))

-           src_datasets.append(indexed_dataset(prefix + src, self.src_dict))
-           tgt_datasets.append(indexed_dataset(prefix + tgt, self.tgt_dict))
+           src_datasets.append(indexed_dataset.make_dataset(prefix + src, impl=self.args.dataset_impl,
+                                                            fix_lua_indexing=True, dictionary=self.src_dict))
+           tgt_datasets.append(indexed_dataset.make_dataset(prefix + tgt, impl=self.args.dataset_impl,
+                                                            fix_lua_indexing=True, dictionary=self.tgt_dict))

            print('| {} {} {} examples'.format(data_path, split_k, len(src_datasets[-1])))
...
@@ -8,17 +8,7 @@ import contextlib

 import torch

-from fairseq import modules, options, utils
-from fairseq.data import (
-    ConcatDataset,
-    data_utils,
-    Dictionary,
-    IndexedCachedDataset,
-    IndexedDataset,
-    IndexedRawTextDataset,
-    LanguagePairDataset,
-)
+from fairseq import modules, utils

 from . import register_task
 from .translation import TranslationTask
@@ -40,8 +30,8 @@ class TranslationMoETask(TranslationTask):
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.

    Args:
-       src_dict (Dictionary): dictionary for the source language
-       tgt_dict (Dictionary): dictionary for the target language
+       src_dict (~fairseq.data.Dictionary): dictionary for the source language
+       tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::
...
@@ -129,9 +129,7 @@ def main(args):
                )
            pool.close()

-       ds = indexed_dataset.IndexedDatasetBuilder(
-           dataset_dest_file(args, output_prefix, lang, "bin")
-       )
+       ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)
        merge_result(
            Binarizer.binarize(
                input_file, vocab, lambda t: ds.add_item(t),
@@ -161,15 +159,15 @@
        )

    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
-       if args.output_format == "binary":
-           make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)
-       elif args.output_format == "raw":
+       if args.dataset_impl == "raw":
            # Copy original text file to destination folder
            output_text_file = dest_path(
                output_prefix + ".{}-{}".format(args.source_lang, args.target_lang),
                lang,
            )
            shutil.copyfile(file_name(input_prefix, lang), output_text_file)
+       else:
+           make_binary_dataset(vocab, input_prefix, output_prefix, lang, num_workers)

    def make_all(lang, vocab):
        if args.trainpref:
@@ -233,9 +231,7 @@ def main(args):

 def binarize(args, filename, vocab, output_prefix, lang, offset, end, append_eos=True):
-    ds = indexed_dataset.IndexedDatasetBuilder(
-        dataset_dest_file(args, output_prefix, lang, "bin")
-    )
+    ds = indexed_dataset.make_builder(dataset_dest_file(args, output_prefix, lang, "bin"), impl=args.dataset_impl)

     def consumer(tensor):
         ds.add_item(tensor)
@@ -263,15 +259,6 @@ def get_offsets(input_file, num_workers):
     return Binarizer.find_offsets(input_file, num_workers)


-def merge_files(files, outpath):
-    ds = indexed_dataset.IndexedDatasetBuilder("{}.bin".format(outpath))
-    for file in files:
-        ds.merge_file_(file)
-        os.remove(indexed_dataset.data_file_path(file))
-        os.remove(indexed_dataset.index_file_path(file))
-    ds.finalize("{}.idx".format(outpath))
-
-
 def cli_main():
     parser = options.get_preprocessing_parser()
     args = parser.parse_args()
...
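Roughly how the new builder API supports multi-worker binarization in preprocess.py; this is a sketch under the assumption that each worker has already written and finalized its own .bin/.idx chunk (all paths are illustrative):

```python
# Sketch: merge per-worker chunks into one dataset with the new builder API.
import os
from fairseq.data import indexed_dataset

ds = indexed_dataset.make_builder('data-bin/train.de-en.de.bin', impl='mmap')

# Hypothetical chunk prefixes written and finalized by workers 1 and 2.
for worker_prefix in ['data-bin/train.de-en.de1', 'data-bin/train.de-en.de2']:
    ds.merge_file_(worker_prefix)                             # appends chunk data and index entries
    os.remove(indexed_dataset.data_file_path(worker_prefix))  # clean up the merged chunk
    os.remove(indexed_dataset.index_file_path(worker_prefix))

ds.finalize('data-bin/train.de-en.de.idx')                    # write the single merged index
```

Because merge_file_() folds both the chunk's raw data and its index entries into the main builder, only one finalize() call is needed for the merged output.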
@@ -8,29 +8,39 @@
 import argparse

-from fairseq.data import dictionary
-from fairseq.data import IndexedDataset
+from fairseq.data import Dictionary
+from fairseq.data import indexed_dataset


 def get_parser():
     parser = argparse.ArgumentParser(
         description='writes text from binarized file to stdout')
     # fmt: off
-    parser.add_argument('--dict', metavar='FP', required=True, help='dictionary containing known words')
+    parser.add_argument('--dataset-impl', help='dataset implementation',
+                        choices=['raw', 'lazy', 'cached', 'mmap'], default='lazy')
+    parser.add_argument('--dict', metavar='FP', help='dictionary containing known words', default=None)
     parser.add_argument('--input', metavar='FP', required=True, help='binarized file to read')
     # fmt: on
     return parser


-def main(args):
-    dict = dictionary.Dictionary.load(args.dict)
-    ds = IndexedDataset(args.input, fix_lua_indexing=True)
-    for tensor_line in ds:
-        print(dict.string(tensor_line))
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    dictionary = Dictionary.load(args.dict) if args.dict is not None else None
+    dataset = indexed_dataset.make_dataset(args.input, impl=args.dataset_impl,
+                                           fix_lua_indexing=True, dictionary=dictionary)
+
+    for tensor_line in dataset:
+        if dictionary is None:
+            line = ' '.join([str(int(x)) for x in tensor_line])
+        else:
+            line = dictionary.string(tensor_line)
+
+        print(line)


 if __name__ == '__main__':
-    parser = get_parser()
-    args = parser.parse_args()
-    main(args)
+    main()
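An illustrative programmatic equivalent of the updated script (all paths here are hypothetical):

```python
# Roughly what `python scripts/read_binarized.py --dataset-impl mmap \
#     --dict data-bin/dict.de.txt --input data-bin/train.de-en.de` does.
from fairseq.data import Dictionary, indexed_dataset

dictionary = Dictionary.load('data-bin/dict.de.txt')
dataset = indexed_dataset.make_dataset('data-bin/train.de-en.de', impl='mmap',
                                       fix_lua_indexing=True, dictionary=dictionary)
for tensor_line in dataset:
    print(dictionary.string(tensor_line))
```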
@@ -38,9 +38,9 @@ class TestTranslation(unittest.TestCase):
        with contextlib.redirect_stdout(StringIO()):
            with tempfile.TemporaryDirectory('test_fconv_raw') as data_dir:
                create_dummy_data(data_dir)
-               preprocess_translation_data(data_dir, ['--output-format', 'raw'])
-               train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--raw-text'])
-               generate_main(data_dir, ['--raw-text'])
+               preprocess_translation_data(data_dir, ['--dataset-impl', 'raw'])
+               train_translation_model(data_dir, 'fconv_iwslt_de_en', ['--dataset-impl', 'raw'])
+               generate_main(data_dir, ['--dataset-impl', 'raw'])

    def test_fp16(self):
        with contextlib.redirect_stdout(StringIO()):
@@ -418,7 +418,8 @@ def train_masked_language_model(data_dir, arch):
            "--no-progress-bar",
            "--distributed-world-size",
            "1",
-           "--raw-text",
+           "--dataset-impl",
+           "raw",
        ],
    )
    train.main(train_args)
...