".github/vscode:/vscode.git/clone" did not exist on "6df103debab07514266b7871b4d6171670903775"
Commit 66719e97 authored by Raul Puri, committed by Mohammad Shoeybi

Faster dataloader merge (#1)

* threaded tf_dl+presplit sentences+shuffled dataset with resume

* elaborate in readme
parent fb4cbdc2
@@ -33,6 +33,7 @@ python pretrain_bert.py \
     --tokenizer-model-type bert-large-uncased \
     --vocab-size 30522 \
     --train-data wikipedia \
+    --presplit-sentences \
     --loose-json \
     --text-key text \
     --split 1000,1,1 \
@@ -79,7 +80,7 @@ This script runs BERT pretraining with a `sentencepiece` tokenizer. If no senten
 # Collecting Wikipedia Training Data
 We recommend following the wikipedia data extraction process specified by google research: "the recommended pre-processing is to download [the latest dump](https://dumps.wikimedia.org/enwiki/latest/enwiki-latest-pages-articles.xml.bz2), extract the text with [WikiExtractor.py](https://github.com/attardi/wikiextractor), and then apply any necessary cleanup to convert it into plain text."
-We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase.
+We recommend using the `--json` argument when using WikiExtractor, which will dump the wikipedia data into loose json format (one json per line), making it more manageable and readily consumable by our codebase. We recommend further preprocessing this json dataset with nltk punctuation standardization and presplitting each document into newline-separated sentences. This can be done with the provided script `./scripts/presplit_sentences_json.py` and will allow for faster data processing during training time. Pretraining with presplit data should be run with the `--presplit-sentences` flag as shown above.
 Once the json dataset is ready make sure to set the path in line 27 of `data_utils/corpora.py`.
......
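For reference, below is a minimal sketch of the presplitting step described in the README. It only approximates the idea behind `./scripts/presplit_sentences_json.py` (read loose-json records, split each document's `text` field into sentences with nltk, and write the sentences back newline-separated); the input/output paths are hypothetical arguments and the actual script in the repository may differ.

```python
import json
import sys

import nltk

nltk.download('punkt')  # sentence tokenizer data used by sent_tokenize

# Illustrative only: approximates the presplitting described in the README.
input_file = sys.argv[1]   # loose json, one {"text": ...} record per line
output_file = sys.argv[2]  # same format, but with newline-separated sentences

with open(input_file, 'r') as fin, open(output_file, 'w') as fout:
    for line in fin:
        document = json.loads(line)
        # Split the document into sentences and rejoin them with newlines so
        # that training can skip sentence tokenization (--presplit-sentences).
        sentences = nltk.tokenize.sent_tokenize(document['text'])
        document['text'] = '\n'.join(sentences)
        fout.write(json.dumps(document) + '\n')
```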
@@ -184,6 +184,9 @@ def add_data_args(parser):
     group = parser.add_argument_group('data', 'data configurations')
+    group.add_argument('--shuffle', action='store_true',
+                       help='Shuffle data. Shuffling is deterministic '
+                            'based on seed and current epoch.')
     group.add_argument('--train-data', nargs='+', required=True,
                        help='Filename (or whitespace separated filenames) '
                             'for training.')
@@ -208,6 +211,9 @@ def add_data_args(parser):
                        help='Use loose json (one json-formatted string per '
                             'newline), instead of tight json (data file is one '
                             'json string)')
+    group.add_argument('--presplit-sentences', action='store_true',
+                       help='Dataset content consists of documents where '
+                            'each document consists of newline separated sentences')
     group.add_argument('--num-workers', type=int, default=2,
                        help="""Number of workers to use for dataloading""")
     group.add_argument('--tokenizer-model-type', type=str,
......
@@ -46,7 +46,7 @@ def make_data_loader(dataset, batch_size, args):
     shuffle = args.shuffle
     if shuffle:
-        sampler = torch.utils.data.RandomSampler(dataset)
+        sampler = data_utils.samplers.RandomSampler(dataset, replacement=True, num_samples=batch_size*args.train_iters)
     else:
         sampler = torch.utils.data.SequentialSampler(dataset)
     world_size = args.world_size
@@ -81,8 +81,10 @@ def make_tfrecord_loaders(args):
                      'max_seq_len': args.seq_length,
                      'max_preds_per_seq': args.max_preds_per_seq,
                      'train': True,
-                     'num_workers': args.num_workers,
-                     'seed': args.seed+args.rank+1}
+                     'num_workers': max(args.num_workers, 1),
+                     'seed': args.seed + args.rank + 1,
+                     'threaded_dl': args.num_workers > 0
+                     }
     train = data_utils.tf_dl.TFRecordDataLoader(args.train_data,
                                                 **data_set_args)
     data_set_args['train'] = False
@@ -140,7 +142,8 @@ def make_loaders(args):
                      'vocab_size': args.vocab_size,
                      'model_type': args.tokenizer_model_type,
                      'cache_dir': args.cache_dir,
-                     'max_preds_per_seq': args.max_preds_per_seq}
+                     'max_preds_per_seq': args.max_preds_per_seq,
+                     'presplit_sentences': args.presplit_sentences}
     eval_set_args = copy.copy(data_set_args)
     eval_set_args['split'] = [1.]
@@ -218,7 +221,6 @@ def configure_data():
                      'rank': -1,
                      'persist_state': 0,
                      'lazy': False,
-                     'shuffle': False,
                      'transpose': False,
                      'data_set_type': 'supervised',
                      'seq_length': 256,
......
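The sampler change above means the shuffled loader now draws indices with replacement for a fixed `num_samples`, so one pass through the loader corresponds to `args.train_iters` batches rather than one epoch over the dataset. A rough standalone sketch of that effect, using the `RandomSampler` added in `data_utils/samplers.py` further down (toy dataset and made-up hyperparameters, assuming the repository's `data_utils` package is importable):

```python
import torch
from torch.utils.data import BatchSampler, TensorDataset

from data_utils.samplers import RandomSampler  # the sampler added in this commit

# Toy stand-ins for the real dataset and arguments.
dataset = TensorDataset(torch.arange(1000))
batch_size, train_iters = 4, 250

# Drawing with replacement and a fixed num_samples decouples the number of
# batches per pass from the dataset length.
sampler = RandomSampler(dataset, replacement=True,
                        num_samples=batch_size * train_iters)
batch_sampler = BatchSampler(sampler, batch_size, drop_last=True)

assert len(batch_sampler) == train_iters  # 1000 sampled indices / 4 per batch
```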
@@ -46,7 +46,7 @@ def get_dataset(path, **kwargs):
     if supported_corpus(path):
         return corpora.NAMED_CORPORA[path](**kwargs)
     ext = get_ext(path)
-    if ext =='.json':
+    if '.json' in ext:
         text = json_dataset(path, **kwargs)
     elif ext in ['.csv', '.tsv']:
         text = csv_dataset(path, **kwargs)
@@ -108,8 +108,10 @@ def make_dataset(path, seq_length, text_key, label_key, lazy=False, process_fn=N
     if should_split(split):
         ds = split_ds(ds, split)
         if ds_type.lower() == 'bert':
-            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length) for d in ds]
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = [bert_sentencepair_dataset(d, max_seq_len=seq_length, presplit_sentences=presplit_sentences) for d in ds]
     else:
         if ds_type.lower() == 'bert':
-            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length)
+            presplit_sentences = kwargs['presplit_sentences'] if 'presplit_sentences' in kwargs else False
+            ds = bert_sentencepair_dataset(ds, max_seq_len=seq_length, presplit_sentences=presplit_sentences)
     return ds, tokenizer
@@ -449,7 +449,7 @@ class bert_sentencepair_dataset(data.Dataset):
         dataset_size (int): number of random sentencepairs in the dataset. Default: len(ds)*(len(ds)-1)
     """
-    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, **kwargs):
+    def __init__(self, ds, max_seq_len=512, mask_lm_prob=.15, max_preds_per_seq=None, short_seq_prob=.01, dataset_size=None, presplit_sentences=False, **kwargs):
         self.ds = ds
         self.ds_len = len(self.ds)
         self.tokenizer = self.ds.GetTokenizer()
@@ -464,6 +464,7 @@ class bert_sentencepair_dataset(data.Dataset):
         self.dataset_size = dataset_size
         if self.dataset_size is None:
             self.dataset_size = self.ds_len * (self.ds_len-1)
+        self.presplit_sentences = presplit_sentences

     def __len__(self):
         return self.dataset_size
@@ -494,7 +495,14 @@ class bert_sentencepair_dataset(data.Dataset):
     def sentence_split(self, document):
         """split document into sentences"""
-        return tokenize.sent_tokenize(document)
+        lines = document.split('\n')
+        if self.presplit_sentences:
+            return [line for line in lines if line]
+        rtn = []
+        for line in lines:
+            if line != '':
+                rtn.extend(tokenize.sent_tokenize(line))
+        return rtn

     def sentence_tokenize(self, sent, sentence_num=0, beginning=False, ending=False):
         """tokenize sentence and get token types"""
......
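To illustrate the effect of the new `sentence_split` logic: with `--presplit-sentences`, documents are assumed to already contain one sentence per line, so splitting reduces to dropping empty lines; without it, each line is still run through nltk's sentence tokenizer at training time. A small standalone sketch mirroring the method above (not importing the repository code):

```python
from nltk import tokenize  # requires the nltk 'punkt' tokenizer data


def sentence_split(document, presplit_sentences=False):
    """Mirror of bert_sentencepair_dataset.sentence_split, for illustration."""
    lines = document.split('\n')
    if presplit_sentences:
        # Presplit data: every non-empty line is already a sentence.
        return [line for line in lines if line]
    sentences = []
    for line in lines:
        if line != '':
            # Fall back to nltk sentence tokenization per line.
            sentences.extend(tokenize.sent_tokenize(line))
    return sentences


doc = "First sentence. Second sentence.\nThird sentence."
print(sentence_split(doc, presplit_sentences=True))   # 2 lines kept as-is
print(sentence_split(doc, presplit_sentences=False))  # 3 sentences via nltk
```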
@@ -21,6 +21,57 @@ import torch
 from torch.utils import data
 import numpy as np

+class RandomSampler(data.sampler.Sampler):
+    r"""
+    Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
+    but this class lets the user set an epoch like DistributedSampler.
+    Samples elements randomly. If without replacement, then sample from a shuffled dataset.
+    If with replacement, then user can specify ``num_samples`` to draw.
+    Arguments:
+        data_source (Dataset): dataset to sample from
+        num_samples (int): number of samples to draw, default=len(dataset)
+        replacement (bool): samples are drawn with replacement if ``True``, default=False
+    """
+
+    def __init__(self, data_source, replacement=False, num_samples=None):
+        self.data_source = data_source
+        self.replacement = replacement
+        self._num_samples = num_samples
+        self.epoch = -1
+
+        if self._num_samples is not None and replacement is False:
+            raise ValueError("With replacement=False, num_samples should not be specified, "
+                             "since a random permute will be performed.")
+
+        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
+            raise ValueError("num_samples should be a positive integer "
+                             "value, but got num_samples={}".format(self.num_samples))
+        if not isinstance(self.replacement, bool):
+            raise ValueError("replacement should be a boolean value, but got "
+                             "replacement={}".format(self.replacement))
+
+    @property
+    def num_samples(self):
+        # dataset size might change at runtime
+        if self._num_samples is None:
+            return len(self.data_source)
+        return self._num_samples
+
+    def __iter__(self):
+        n = len(self.data_source)
+        g = torch.Generator()
+        if self.epoch >= 0:
+            g.manual_seed(self.epoch)
+        if self.replacement:
+            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64, generator=g).tolist())
+        return iter(torch.randperm(n, generator=g).tolist())
+
+    def __len__(self):
+        return self.num_samples
+
+    def set_epoch(self, epoch):
+        self.epoch = epoch
+
+
 class DistributedBatchSampler(data.sampler.BatchSampler):
     """
     similar to normal implementation of distributed sampler, except implementation is at the
......
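The epoch-seeded generator is what makes shuffling deterministic and resumable: calling `set_epoch` with the same value reproduces the same index sequence, which is how `pretrain_bert.py` below uses `epoch + args.seed`. A small usage sketch with a toy dataset (assuming the repository's `data_utils` package is importable):

```python
import torch
from torch.utils.data import TensorDataset

from data_utils.samplers import RandomSampler

dataset = TensorDataset(torch.arange(100))
sampler = RandomSampler(dataset, replacement=True, num_samples=10)

# Same epoch -> same generator seed -> same sample order, which is what
# allows the shuffle order to be reproduced after resuming training.
sampler.set_epoch(3)
first = list(iter(sampler))
sampler.set_epoch(3)
assert list(iter(sampler)) == first

# A different epoch reshuffles.
sampler.set_epoch(4)
assert list(iter(sampler)) != first  # holds with overwhelming probability
```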
@@ -14,12 +14,16 @@
 # limitations under the License.
 """PyTorch DataLoader for TFRecords"""

+import queue
+import threading
+
 import tensorflow as tf
 tf.enable_eager_execution()
 import torch
+import numpy as np

 class TFRecordDataLoader(object):
-    def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1):
+    def __init__(self, records, batch_size, max_seq_len, max_preds_per_seq, train, num_workers=2, seed=1, threaded_dl=False):
         assert max_preds_per_seq is not None, "--max-preds-per-seq MUST BE SPECIFIED when using tfrecords"
         tf.set_random_seed(seed)
         if isinstance(records, str):
@@ -55,8 +59,15 @@ class TFRecordDataLoader(object):
                        'num_parallel_batches': num_workers,
                        'drop_remainder': train}
         self.dataloader = self.dataset.apply(tf.contrib.data.map_and_batch(self.record_converter, **loader_args))
+        self.threaded_dl = threaded_dl
+        self.num_workers = num_workers

     def __iter__(self):
-        data_iter = iter(self.dataloader)
-        for item in data_iter:
-            yield convert_tf_example_to_torch_tensors(item)
+        if self.threaded_dl:
+            data_iter = iter(MultiprocessLoader(self.dataloader, self.num_workers))
+            for item in data_iter:
+                yield item
+        else:
+            data_iter = iter(self.dataloader)
+            for item in data_iter:
+                yield convert_tf_example_to_torch_tensors(item)
@@ -74,14 +85,37 @@ class Record2Example(object):
         return example

 def convert_tf_example_to_torch_tensors(example):
-    item = {k: torch.from_numpy(v.numpy()) for k,v in example.items()}
-    mask = torch.zeros_like(item['input_ids'])
-    mask_labels = torch.ones_like(item['input_ids'])*-1
-    for b, row in enumerate(item['masked_lm_positions'].long()):
+    item = {k: (v.numpy()) for k,v in example.items()}
+    mask = np.zeros_like(item['input_ids'])
+    mask_labels = np.ones_like(item['input_ids'])*-1
+    for b, row in enumerate(item['masked_lm_positions'].astype(int)):
         for i, idx in enumerate(row):
             if item['masked_lm_weights'][b, i] != 0:
                 mask[b, idx] = 1
                 mask_labels[b, idx] = item['masked_lm_ids'][b, i]
-    return {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
-            'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
+    output = {'text': item['input_ids'], 'types': item['segment_ids'],'is_random': item['next_sentence_labels'],
+              'pad_mask': 1-item['input_mask'], 'mask': mask, 'mask_labels': mask_labels}
+    return {k: torch.from_numpy(v) for k,v in output.items()}
+
+class MultiprocessLoader(object):
+    def __init__(self, dataloader, num_workers=2):
+        self.dl = dataloader
+        self.queue_size = 2*num_workers
+
+    def __iter__(self):
+        output_queue = queue.Queue(self.queue_size)
+        output_thread = threading.Thread(target=_multiproc_iter,
+                                         args=(self.dl, output_queue))
+        output_thread.daemon = True
+        output_thread.start()
+
+        while output_thread.is_alive():
+            yield output_queue.get(block=True)
+        else:
+            print(RuntimeError('TF record data loader thread exited unexpectedly'))
+
+def _multiproc_iter(dl, output_queue):
+    data_iter = iter(dl)
+    for item in data_iter:
+        tensors = convert_tf_example_to_torch_tensors(item)
+        output_queue.put(tensors, block=True)
\ No newline at end of file
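Despite its name, `MultiprocessLoader` above is a single background thread feeding a bounded queue, which lets TFRecord decoding and tensor conversion overlap with the training loop. A stripped-down sketch of the same producer/consumer pattern with a generic iterable (not the repository's API) is shown here; unlike the version above it terminates via an explicit sentinel rather than polling `is_alive`, so items still sitting in the queue when the producer finishes are not skipped.

```python
import queue
import threading


def prefetch(iterable, queue_size=4):
    """Yield items from `iterable`, produced by a background thread."""
    buffer = queue.Queue(queue_size)  # bounded: producer blocks when full
    sentinel = object()

    def producer():
        for item in iterable:
            buffer.put(item, block=True)
        buffer.put(sentinel)  # signal end of data

    thread = threading.Thread(target=producer, daemon=True)
    thread.start()

    while True:
        item = buffer.get(block=True)
        if item is sentinel:
            break
        yield item


# Usage: wrap any (possibly slow) iterator, e.g. a TFRecord pipeline.
for batch in prefetch(range(10), queue_size=2):
    pass  # training step would go here
```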
@@ -434,6 +434,8 @@ def main():
             train_data.batch_sampler.start_iter = total_iters % len(train_data)
     # For all epochs.
     for epoch in range(start_epoch, args.epochs+1):
+        if args.shuffle:
+            train_data.batch_sampler.sampler.set_epoch(epoch+args.seed)
         timers('epoch time').start()
         iteration, skipped = train_epoch(epoch, model, optimizer,
                                          train_data, lr_scheduler,
......