Commit 66719e97 authored by Raul Puri's avatar Raul Puri Committed by Mohammad Shoeybi
Browse files

Faster dataloader merge (#1)

* threaded tf_dl+presplit sentences+shuffled dataset with resume

* elaborate in readme
parent fb4cbdc2
......@@ -33,6 +33,7 @@ python pretrain_bert.py \
--tokenizer-model-type bert-large-uncased \
--vocab-size 30522 \
--train-data wikipedia \
--presplit-sentences \
--loose-json \
--text-key text \
--split 1000,1,1 \
......@@ -79,7 +80,7 @@ This script runs BERT pretraining with a `sentencepiece` tokenizer. If no senten
# Collecting Wikipedia Training Data
We recommend following the wikipedia data extraction process specified by google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."
We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase.
We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase. We recommend further preprocessing this json dataset by preprocessing the dataset with nltk punctuation standardization, and presplitting each document into newline separated sentences. This can be done with the provided script `./scripts/presplit_sentences_json.py` and will allow for faster data processing during training time. Pretraining with presplit data should be run with the `--presplit-sentences` flag as shown above.
Once the json dataset is ready make sure to set the path in line 27 of `data_utils/corpora.py`.
......
......@@ -184,6 +184,9 @@ def add_data_args(parser):
group = parser.add_argument_group('data', 'data configurations')
group.add_argument('--shuffle', action='store_true',
help='Shuffle data. Shuffling is deterministic '
'based on seed and current epoch.')
group.add_argument('--train-data', nargs='+', required=True,
help='Filename (or whitespace separated filenames) '
'for training.')
......@@ -208,6 +211,9 @@ def add_data_args(parser):
help='Use loose json (one json-formatted string per '
'newline), instead of tight json (data file is one '
'json string)')
group.add_argument('--presplit-sentences', action='store_true',
help='Dataset content consists of documents where '
'each document consists of newline separated sentences')
group.add_argument('--num-workers', type=int, default=2,
help="""Number of workers to use for dataloading""")
group.add_argument('--tokenizer-model-type', type=str,
......
......@@ -46,7 +46,7 @@ def make_data_loader(dataset, batch_size, args):
shuffle = args.shuffle
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
else:
sampler = torch.utils.data.SequentialSampler(dataset)
world_size = args.world_size
......@@ -81,8 +81,10 @@ def make_tfrecord_loaders(args):
'max_seq_len': args.seq_length,
'max_preds_per_seq': args.max_preds_per_seq,
'train': True,
'num_workers': args.num_workers,
'seed': args.seed+args.rank+1}
'num_workers': max(args.num_workers, 1),
'seed': args.seed + args.rank + 1,
'threaded_dl': args.num_workers > 0
}
train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
**data_set_args)
data_set_args['train'] = False
......@@ -140,7 +142,8 @@ def make_loaders(args):
'vocab_size': args.vocab_size,
'model_type': args.tokenizer_model_type,
'cache_dir': args.cache_dir,
'max_preds_per_seq': args.max_preds_per_seq}
'max_preds_per_seq': args.max_preds_per_seq,
'presplit_sentences': args.presplit_sentences}
eval_set_args = copy.copy(data_set_args)
eval_set_args['split'] = [1.]
......@@ -218,7 +221,6 @@ def configure_data():
'rank': -1,
'persist_state': 0,
'lazy': False,
'shuffle': False,
'transpose': False,
'data_set_type': 'supervised',
'seq_length': 256,
......
......@@ -46,7 +46,7 @@ def get_dataset(path, **kwargs):
if supported_corpus(path):
return corpora.NAMED_CORPORA[path](**kwargs)
ext = get_ext(path)
if ext =='.json':
if '.json' in ext:
text = json_dataset(path, **kwargs)
elif ext in ['.csv', '.tsv']:
text = csv_dataset(path, **kwargs)
......@@ -108,8 +108,10 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
if should_split(split):
ds = split_ds(ds, split)
if ds_type.lower() == 'bert':
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length) for d in ds]
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) for d in ds]
else:
if ds_type.lower() == 'bert':
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length)
presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
return ds, tokenizer
......@@ -449,7 +449,7 @@ class bert_sentencepair_dataset(data.Dataset):
dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
"""
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, **kwargs):
def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, **kwargs):
self.ds = ds
self.ds_len = len(self.ds)
self.tokenizer = self.ds.GetTokenizer()
......@@ -464,6 +464,7 @@ class bert_sentencepair_dataset(data.Dataset):
self.dataset_size = dataset_size
if self.dataset_size is None:
self.dataset_size = self.ds_len * (self.ds_len-1)
self.presplit_sentences = presplit_sentences
def __len__(self):
return self.dataset_size
......@@ -494,7 +495,14 @@ class bert_sentencepair_dataset(data.Dataset):
def sentence_split(self, document):
"""split document into sentences"""
return tokenize.sent_tokenize(document)
lines = document.split('\n')
if self.presplit_sentences:
return [line for line in lines if line]
rtn = []
for line in lines:
if line != '':
rtn.extend(tokenize.sent_tokenize(line))
return rtn
def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
"""tokenize sentence and get token types"""
......
......@@ -21,6 +21,57 @@ import torch
from torch.utils import data
import numpy as np
class RandomSampler(data.sampler.Sampler):
r"""
Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
but this class lets the user set an epoch like DistributedSampler
Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify ``num_samples`` to draw.
Arguments:
data_source (Dataset): dataset to sample from
num_samples (int): number of samples to draw, default=len(dataset)
replacement (bool): samples are drawn with replacement if ``True``, default=False
"""
def __init__(self, data_source, replacement=False, num_samples=None):
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.epoch = -1
if self._num_samples is not None and replacement is False:
raise ValueError("With replacement=False, num_samples should not be specified, "
"since a random permute will be performed.")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError("num_samples should be a positive integer "
"value, but got num_samples={}".format(self.num_samples))
if not isinstance(self.replacement, bool):
raise ValueError("replacement should be a boolean value, but got "
"replacement={}".format(self.replacement))
@property
def num_samples(self):
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self):
n = len(self.data_source)
g = torch.Generator()
if self.epoch >= 0:
g.manual_seed(self.epoch)
if self.replacement:
return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
return iter(torch.randperm(n, generator=g).tolist())
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch
class DistributedBatchSampler(data.sampler.BatchSampler):
"""
similar to normal implementation of distributed sampler, except implementation is at the
......
......@@ -14,12 +14,16 @@
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import queue
import threading
import tensorflow as tf
tf.enable_eager_execution()
import torch
import numpy as np
class TFRecordDataLoader(object):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1):
def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
tf.set_random_seed(seed)
if isinstance(records, str):
......@@ -55,11 +59,18 @@ class TFRecordDataLoader(object):
'num_parallel_batches': num_workers,
'drop_remainder': train}
self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
self.threaded_dl = threaded_dl
self.num_workers = num_workers
def __iter__(self):
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
if self.threaded_dl:
data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
for item in data_iter:
yield item
else:
data_iter = iter(self.dataloader)
for item in data_iter:
yield convert_tf_example_to_torch_tensors(item)
class Record2Example(object):
def __init__(self, feature_map):
......@@ -74,14 +85,37 @@ class Record2Example(object):
return example
def convert_tf_example_to_torch_tensors(example):
item = {k: torch.from_numpy(v.numpy()) for k,v in example.items()}
mask = torch.zeros_like(item['input_ids'])
mask_labels = torch.ones_like(item['input_ids'])*-1
for b, row in enumerate(item['masked_lm_positions'].long()):
item = {k: (v.numpy()) for k,v in example.items()}
mask = np.zeros_like(item['input_ids'])
mask_labels = np.ones_like(item['input_ids'])*-1
for b, row in enumerate(item['masked_lm_positions'].astype(int)):
for i, idx in enumerate(row):
if item['masked_lm_weights'][b, i] != 0:
mask[b, idx] = 1
mask_labels[b, idx] = item['masked_lm_ids'][b, i]
return {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
return {k: torch.from_numpy(v) for k,v in output.items()}
class MultiprocessLoader(object):
def __init__(self, dataloader, num_workers=2):
self.dl = dataloader
self.queue_size = 2*num_workers
def __iter__(self):
output_queue = queue.Queue(self.queue_size)
output_thread = threading.Thread(target=_multiproc_iter,
args=(self.dl, output_queue))
output_thread.daemon = True
output_thread.start()
while output_thread.is_alive():
yield output_queue.get(block=True)
else:
print(RuntimeError('TF record data loader thread exited unexpectedly'))
def _multiproc_iter(dl, output_queue):
data_iter = iter(dl)
for item in data_iter:
tensors = convert_tf_example_to_torch_tensors(item)
output_queue.put(tensors, block=True)
\ No newline at end of file
......@@ -434,6 +434,8 @@ def main():
train_data.batch_sampler.start_iter = total_iters % len(train_data)
# For all epochs.
for epoch in range(start_epoch, args.epochs+1):
if args.shuffle:
train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
timers('epoch time').start()
iteration, skipped = train_epoch(epoch, model, optimizer,
train_data, lr_scheduler,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment